Sliding Window Attention은 Transformer의 attention 계산을 각 토큰 주변의 제한된 윈도우로 한정하여 긴 시퀀스를 효율적으로 처리합니다.
윈도우 마스크 M을 적용:
여기서 M은 윈도우 바깥의 값은 -∞, 윈도우 내부는 0인 마스크입니다.
class SlidingWindowAttention(nn.Module):
def __init__(self, hidden_size, num_heads, window_size):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.window_size = window_size
# 투영 레이어
self.q_proj = nn.Linear(hidden_size, hidden_size)
self.k_proj = nn.Linear(hidden_size, hidden_size)
self.v_proj = nn.Linear(hidden_size, hidden_size)
self.out_proj = nn.Linear(hidden_size, hidden_size)
def forward(self, x):
# 1. Q, K, V 계산
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
# 2. 슬라이딩 윈도우 마스크 생성
seq_len = x.size(1)
mask = self._create_window_mask(seq_len)
# 3. 어텐션 계산
attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.hidden_size ** 0.5)
attn_scores = attn_scores + mask # 마스크 적용
# 4. 소프트맥스 및 출력 계산
attn_weights = F.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_weights, v)
return self.out_proj(output)
def _create_window_mask(self, seq_len):
# 윈도우 마스크 생성 함수
mask = torch.ones(seq_len, seq_len) * float('-inf')
for i in range(seq_len):
start = max(0, i - self.window_size // 2)
end = min(seq_len, i + self.window_size // 2 + 1)
mask[i, start:end] = 0.0
return mask
각 위치마다 필요한 윈도우 범위만 계산:
def efficient_sliding_attention(query, key, value, window_size):
batch_size, seq_len = query.shape[0], query.shape[1]
results = []
for pos in range(seq_len):
# 현재 위치의 윈도우 범위 계산
start = max(0, pos - window_size // 2)
end = min(seq_len, pos + window_size // 2 + 1)
# 윈도우 내 요소만 추출
q = query[:, pos:pos+1]
k = key[:, start:end]
v = value[:, start:end]
# 어텐션 계산
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
weights = F.softmax(scores, dim=-1)
context = torch.matmul(weights, v)
results.append(context)
return torch.cat(results, dim=1)
일정 간격으로 토큰을 건너뛰어 더 넓은 범위 포착:
# 확장 계수가 2인 마스크 생성
def create_dilated_mask(seq_len, window_size, dilation=2):
mask = torch.ones(seq_len, seq_len) * float('-inf')
for i in range(seq_len):
for j in range(-window_size, window_size+1):
idx = i + j * dilation
if 0 <= idx < seq_len:
mask[i, idx] = 0.0
return mask
특정 토큰은 전체 시퀀스와 attention 계산:
def global_local_mask(seq_len, window_size, global_idx=0):
mask = torch.ones(seq_len, seq_len) * float('-inf')
# 로컬 윈도우
for i in range(seq_len):
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2 + 1)
mask[i, start:end] = 0.0
# 글로벌 토큰
mask[global_idx, :] = 0.0 # 글로벌 토큰이 모두 볼 수 있음
mask[:, global_idx] = 0.0 # 모든 토큰이 글로벌 토큰을 볼 수 있음
return mask
| 어텐션 유형 | 시간 복잡도 | 공간 복잡도 |
|---|---|---|
| 표준 Self-Attention | O(N²) | O(N²) |
| Sliding Window Attention | O(N×W) | O(N×W) |
여기서 N은 시퀀스 길이, W는 윈도우 크기입니다.