Sliding Window Attention

제투아·2025년 3월 26일

Sliding Window Attention 알고리즘 정리

기본 개념

Sliding Window Attention은 Transformer의 attention 계산을 각 토큰 주변의 제한된 윈도우로 한정하여 긴 시퀀스를 효율적으로 처리합니다.

수학적 표현

표준 Self-Attention

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Sliding Window Attention

윈도우 마스크 M을 적용:
Attention(Q,K,V)=softmax(QKTdk+M)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V

여기서 M은 윈도우 바깥의 값은 -∞, 윈도우 내부는 0인 마스크입니다.

구현 예시

class SlidingWindowAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, window_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.window_size = window_size
        
        # 투영 레이어
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, x):
        # 1. Q, K, V 계산
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        
        # 2. 슬라이딩 윈도우 마스크 생성
        seq_len = x.size(1)
        mask = self._create_window_mask(seq_len)
        
        # 3. 어텐션 계산
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.hidden_size ** 0.5)
        attn_scores = attn_scores + mask  # 마스크 적용
        
        # 4. 소프트맥스 및 출력 계산
        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_weights, v)
        
        return self.out_proj(output)
    
    def _create_window_mask(self, seq_len):
        # 윈도우 마스크 생성 함수
        mask = torch.ones(seq_len, seq_len) * float('-inf')
        for i in range(seq_len):
            start = max(0, i - self.window_size // 2)
            end = min(seq_len, i + self.window_size // 2 + 1)
            mask[i, start:end] = 0.0
        return mask

최적화된 계산 방법

각 위치마다 필요한 윈도우 범위만 계산:

def efficient_sliding_attention(query, key, value, window_size):
    batch_size, seq_len = query.shape[0], query.shape[1]
    results = []
    
    for pos in range(seq_len):
        # 현재 위치의 윈도우 범위 계산
        start = max(0, pos - window_size // 2)
        end = min(seq_len, pos + window_size // 2 + 1)
        
        # 윈도우 내 요소만 추출
        q = query[:, pos:pos+1]
        k = key[:, start:end]
        v = value[:, start:end]
        
        # 어텐션 계산
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
        weights = F.softmax(scores, dim=-1)
        context = torch.matmul(weights, v)
        
        results.append(context)
    
    return torch.cat(results, dim=1)

주요 변형 기법

1. 확장된 윈도우(Dilated Window)

일정 간격으로 토큰을 건너뛰어 더 넓은 범위 포착:

# 확장 계수가 2인 마스크 생성
def create_dilated_mask(seq_len, window_size, dilation=2):
    mask = torch.ones(seq_len, seq_len) * float('-inf')
    for i in range(seq_len):
        for j in range(-window_size, window_size+1):
            idx = i + j * dilation
            if 0 <= idx < seq_len:
                mask[i, idx] = 0.0
    return mask

2. 글로벌-로컬 혼합

특정 토큰은 전체 시퀀스와 attention 계산:

def global_local_mask(seq_len, window_size, global_idx=0):
    mask = torch.ones(seq_len, seq_len) * float('-inf')
    
    # 로컬 윈도우
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        mask[i, start:end] = 0.0
    
    # 글로벌 토큰
    mask[global_idx, :] = 0.0  # 글로벌 토큰이 모두 볼 수 있음
    mask[:, global_idx] = 0.0  # 모든 토큰이 글로벌 토큰을 볼 수 있음
    
    return mask

복잡도 비교

어텐션 유형시간 복잡도공간 복잡도
표준 Self-AttentionO(N²)O(N²)
Sliding Window AttentionO(N×W)O(N×W)

여기서 N은 시퀀스 길이, W는 윈도우 크기입니다.

profile
Zero to AGI

0개의 댓글