Flash Attention 핵심 정리

제투아·2025년 3월 26일

1. 기본 개념

Flash Attention은 GPU 메모리 계층 구조(SRAM vs HBM)를 활용해 기존 어텐션 메커니즘의 메모리 병목 현상을 해결하는 알고리즘임.

핵심 아이디어:

  • 전체 N×N 어텐션 행렬을 HBM에 저장하지 않음
  • 작은 블록 단위로 SRAM에서 계산 수행
  • 소프트맥스를 재귀적으로 계산

2. 문제 정의와 해결책

표준 어텐션

# O(N²) 메모리 사용 - HBM에 전체 어텐션 행렬 저장
scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d)  # [B,H,N,N]
attn_weights = torch.softmax(scores, dim=-1)                  # [B,H,N,N]
output = torch.matmul(attn_weights, V)                        # [B,H,N,D]

Flash Attention 해결책

소프트맥스를 블록 단위로 계산하는 재귀식 사용

3. 수학적 정의

재귀적 소프트맥스 계산

블록 jj에서 최대값 mjm_j와 합 ljl_j를 추적하면서:

Oi=j=1Nexp(QiKjT/d)Vjj=1Nexp(QiKjT/d)O_i = \frac{\sum_{j=1}^{N} \exp(Q_i K_j^T/\sqrt{d}) V_j}{\sum_{j=1}^{N} \exp(Q_i K_j^T/\sqrt{d})}

재귀적으로:

mi=max(mi,maxbBQiKbT/d)m_i' = \max(m_i, \max_{b \in B} Q_i K_b^T/\sqrt{d})

li=liexp(mimi)+bBexp(QiKbT/dmi)l_i' = l_i \cdot \exp(m_i - m_i') + \sum_{b \in B} \exp(Q_i K_b^T/\sqrt{d} - m_i')

Oi=Oiliexp(mimi)+bBexp(QiKbT/dmi)VbliO_i' = \frac{O_i \cdot l_i \cdot \exp(m_i - m_i') + \sum_{b \in B} \exp(Q_i K_b^T/\sqrt{d} - m_i') V_b}{l_i'}

4. 핵심 구현 코드

def flash_attention(q, k, v):
    """Flash Attention 의사 코드"""
    B, H, N, D = q.shape  # 배치, 헤드, 시퀀스 길이, 차원
    
    # 출력과 소프트맥스 통계 초기화
    o = torch.zeros_like(q)
    l = torch.zeros((B, H, N, 1))  # 누적 합
    m = torch.ones((B, H, N, 1)) * float('-inf')  # 최대값
    
    # 블록 크기 (SRAM에 맞게 조정)
    Br, Bc = 64, 64
    
    # 블록 단위 처리
    for i in range(0, N, Br):
        # Q 블록 로드
        q_block = q[:, :, i:min(i+Br, N), :]
        o_block = torch.zeros_like(q_block)
        m_block = m[:, :, i:min(i+Br, N), :]
        l_block = l[:, :, i:min(i+Br, N), :]
        
        for j in range(0, N, Bc):
            # K, V 블록 로드
            k_block = k[:, :, j:min(j+Bc, N), :]
            v_block = v[:, :, j:min(j+Bc, N), :]
            
            # SRAM에서 블록 어텐션 계산
            s_block = torch.matmul(q_block, k_block.transpose(-1, -2)) / math.sqrt(D)
            
            # 최대값 업데이트
            m_block_new = torch.maximum(m_block, s_block.max(dim=-1, keepdim=True)[0])
            
            # 소프트맥스 통계 업데이트
            m_diff = m_block_new - m_block
            l_block_new = l_block * torch.exp(m_diff) + torch.sum(
                torch.exp(s_block - m_block_new.unsqueeze(-1)), dim=-1, keepdim=True)
            
            # 출력 업데이트
            o_block = o_block * torch.exp(m_diff) + torch.matmul(
                torch.exp(s_block - m_block_new.unsqueeze(-1)), v_block)
            
            # 통계 업데이트
            m_block, l_block = m_block_new, l_block_new
        
        # 정규화된 결과 저장
        o[:, :, i:min(i+Br, N), :] = o_block / l_block
        
    return o

5. 결론

  • 메모리 복잡도: O(N²) → O(N)
  • 계산 효율성: 실제 FLOPs는 같지만 메모리 대역폭 낭비 제거
  • 속도: 시퀀스 길이에 따라 2-7배 속도 향상
  • 최대 길이: 일반 어텐션의 8K 제한에서 64K+ 토큰 처리 가능
profile
Zero to AGI

0개의 댓글