Flash Attention은 GPU 메모리 계층 구조(SRAM vs HBM)를 활용해 기존 어텐션 메커니즘의 메모리 병목 현상을 해결하는 알고리즘임.
핵심 아이디어:
# O(N²) 메모리 사용 - HBM에 전체 어텐션 행렬 저장
scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d) # [B,H,N,N]
attn_weights = torch.softmax(scores, dim=-1) # [B,H,N,N]
output = torch.matmul(attn_weights, V) # [B,H,N,D]
소프트맥스를 블록 단위로 계산하는 재귀식 사용
블록 에서 최대값 와 합 를 추적하면서:
재귀적으로:
def flash_attention(q, k, v):
"""Flash Attention 의사 코드"""
B, H, N, D = q.shape # 배치, 헤드, 시퀀스 길이, 차원
# 출력과 소프트맥스 통계 초기화
o = torch.zeros_like(q)
l = torch.zeros((B, H, N, 1)) # 누적 합
m = torch.ones((B, H, N, 1)) * float('-inf') # 최대값
# 블록 크기 (SRAM에 맞게 조정)
Br, Bc = 64, 64
# 블록 단위 처리
for i in range(0, N, Br):
# Q 블록 로드
q_block = q[:, :, i:min(i+Br, N), :]
o_block = torch.zeros_like(q_block)
m_block = m[:, :, i:min(i+Br, N), :]
l_block = l[:, :, i:min(i+Br, N), :]
for j in range(0, N, Bc):
# K, V 블록 로드
k_block = k[:, :, j:min(j+Bc, N), :]
v_block = v[:, :, j:min(j+Bc, N), :]
# SRAM에서 블록 어텐션 계산
s_block = torch.matmul(q_block, k_block.transpose(-1, -2)) / math.sqrt(D)
# 최대값 업데이트
m_block_new = torch.maximum(m_block, s_block.max(dim=-1, keepdim=True)[0])
# 소프트맥스 통계 업데이트
m_diff = m_block_new - m_block
l_block_new = l_block * torch.exp(m_diff) + torch.sum(
torch.exp(s_block - m_block_new.unsqueeze(-1)), dim=-1, keepdim=True)
# 출력 업데이트
o_block = o_block * torch.exp(m_diff) + torch.matmul(
torch.exp(s_block - m_block_new.unsqueeze(-1)), v_block)
# 통계 업데이트
m_block, l_block = m_block_new, l_block_new
# 정규화된 결과 저장
o[:, :, i:min(i+Br, N), :] = o_block / l_block
return o