한줄 요약: 셀프 어텐션의 진짜 병목은 연산량이 아니라 GPU 메모리 읽기/쓰기(IO)이며, 타일링+online softmax로 메모리 O(n)에 2-4배 속도 향상을 정확하게 달성했다.
| 항목 | 내용 |
|---|---|
| 저자 | Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré |
| 소속 | Stanford University, University at Buffalo |
| 발표 | NeurIPS 2022 |
| 링크 | arxiv.org/abs/2205.14135 |
| 키워드 | Attention, IO-Awareness, GPU Optimization, Memory Efficiency |
셀프 어텐션은 O(n²) 연산과 O(n²) 메모리를 필요로 한다. 많은 연구가 연산량을 줄이는 근사 어텐션(Sparse, Linear 등)에 집중했다. 하지만 이 논문의 핵심 관찰: 현대 GPU에서 어텐션의 실제 병목은 연산이 아니라 GPU HBM(High Bandwidth Memory) ↔ SRAM 사이의 데이터 전송(IO) 이다.
GPU 메모리 계층:
SRAM (on-chip): ~20MB, ~19 TB/s 대역폭 — 빠르지만 작다
HBM (off-chip): ~40-80GB, ~2 TB/s 대역폭 — 크지만 느리다
표준 어텐션의 IO 패턴:
1. Q, K를 HBM에서 읽음 → S = QK^T 계산 → S를 HBM에 저장
2. S를 HBM에서 읽음 → P = softmax(S) → P를 HBM에 저장
3. P, V를 HBM에서 읽음 → O = PV 계산 → O를 HBM에 저장
→ 중간 행렬 S, P가 HBM에 왕복 → IO가 병목!
Q, K, V를 블록 단위로 SRAM에 로드:
for each Q_block:
for each K_block, V_block:
S_block = Q_block @ K_block.T ← SRAM에서 계산
P_block = softmax(S_block) ← SRAM에서 계산
O_block += P_block @ V_block ← SRAM에서 누적
→ 중간 행렬 S, P를 HBM에 저장하지 않음!
→ HBM 접근: O(n²/M) (M = SRAM 크기) vs 표준 O(n²)
타일링의 난제: softmax는 전체 행을 봐야 계산 가능. 블록 단위로 어떻게?
Milakov & Gimelshein의 Online Softmax:
블록 1: max₁ = max(S₁), sum₁ = Σexp(S₁ - max₁)
블록 2: max₂ = max(max₁, max(S₂))
→ 이전 결과를 새 max로 보정
sum₂ = sum₁ × exp(max₁ - max₂) + Σexp(S₂ - max₂)
...
→ 최종 결과는 표준 softmax와 수학적으로 정확히 동일!
→ 근사가 아님 (exact)
| HBM 읽기/쓰기 | 추가 메모리 | |
|---|---|---|
| 표준 어텐션 | O(n²) | O(n²) (S, P 저장) |
| Flash Attention | O(n²/M) | O(n) (softmax 통계만) |
GPT-2 어텐션 (시퀀스 길이 1024, A100):
| 구현 | 속도 (TFLOPs/s) | 배율 |
|---|---|---|
| PyTorch 표준 | 41.6 | 1.0x |
| Megatron-LM | 55.0 | 1.3x |
| Flash Attention | 124.8 | 3.0x |
| 시퀀스 길이 | 표준 어텐션 | Flash Attention |
|---|---|---|
| 1K | 428 MB | 39 MB |
| 4K | 6.8 GB | 155 MB |
| 16K | OOM | 620 MB |
→ 메모리 ~10-40배 절약, 기존에 OOM이던 긴 시퀀스도 처리 가능
| 모델 | 표준 대비 Flash Attention |
|---|---|
| BERT-large | 15% 빠름 |
| GPT-2 medium | 3x 빠름 (long sequence) |
| Long Range Arena | 2.4x 빠름 |
| 버전 | 핵심 개선 |
|---|---|
| Flash Attention 2 (2023) | 워크 파티셔닝 최적화, 시퀀스 병렬, causal masking 효율화 → 추가 2x |
| Flash Attention 3 (2024) | H100 Hopper 최적화, FP8 지원, 비동기 실행 → 추가 1.5-2x |
이 논문의 가장 놀라운 점은 "어텐션을 근사하지 않는다"는 것이다. 기존 효율적 어텐션 연구(Sparse Transformer, Linformer, Performer 등)는 모두 O(n²)을 O(n) 또는 O(n√n)으로 줄이기 위해 근사를 도입했다. Flash Attention은 연산량 자체는 O(n²)이지만, IO를 O(n²/M)으로 줄여 실제 속도를 2-4배 높인다.
이 접근이 성공한 이유: 현대 GPU는 연산은 넘치지만(compute-rich) 메모리 대역폭이 부족하다(memory-bound). 표준 어텐션은 GPU 연산력의 ~5-20%만 활용하고 나머지는 메모리에서 데이터가 오기를 기다린다. Flash Attention은 이 유휴 연산력을 활용하여 "같은 FLOPs인데 더 빠른" 결과를 달성한다.
시스템 논문이 AI 연구 전체에 미치는 영향의 극적인 사례다. Tri Dao 한 명의 시스템 최적화가 전 세계 LLM 학습 비용을 수십% 절감했다.
관련 논문: Ring Attention, PagedAttention (vLLM), Mamba