[논문 리뷰] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

smj·2026년 3월 31일

review

목록 보기
7/30

한줄 요약: 셀프 어텐션의 진짜 병목은 연산량이 아니라 GPU 메모리 읽기/쓰기(IO)이며, 타일링+online softmax로 메모리 O(n)에 2-4배 속도 향상을 정확하게 달성했다.

항목내용
저자Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
소속Stanford University, University at Buffalo
발표NeurIPS 2022
링크arxiv.org/abs/2205.14135
키워드Attention, IO-Awareness, GPU Optimization, Memory Efficiency

1. 문제 정의

셀프 어텐션은 O(n²) 연산과 O(n²) 메모리를 필요로 한다. 많은 연구가 연산량을 줄이는 근사 어텐션(Sparse, Linear 등)에 집중했다. 하지만 이 논문의 핵심 관찰: 현대 GPU에서 어텐션의 실제 병목은 연산이 아니라 GPU HBM(High Bandwidth Memory) ↔ SRAM 사이의 데이터 전송(IO) 이다.

GPU 메모리 계층:
  SRAM (on-chip): ~20MB, ~19 TB/s 대역폭 — 빠르지만 작다
  HBM (off-chip):  ~40-80GB, ~2 TB/s 대역폭 — 크지만 느리다

표준 어텐션의 IO 패턴:
  1. Q, K를 HBM에서 읽음 → S = QK^T 계산 → S를 HBM에 저장
  2. S를 HBM에서 읽음 → P = softmax(S) → P를 HBM에 저장
  3. P, V를 HBM에서 읽음 → O = PV 계산 → O를 HBM에 저장
  → 중간 행렬 S, P가 HBM에 왕복 → IO가 병목!

2. 제안 방법

핵심 기법 1: 타일링 (Tiling)

Q, K, V를 블록 단위로 SRAM에 로드:
  for each Q_block:
    for each K_block, V_block:
      S_block = Q_block @ K_block.T     ← SRAM에서 계산
      P_block = softmax(S_block)         ← SRAM에서 계산
      O_block += P_block @ V_block       ← SRAM에서 누적
  → 중간 행렬 S, P를 HBM에 저장하지 않음!
  → HBM 접근: O(n²/M) (M = SRAM 크기) vs 표준 O(n²)

핵심 기법 2: Online Softmax

타일링의 난제: softmax는 전체 행을 봐야 계산 가능. 블록 단위로 어떻게?

Milakov & Gimelshein의 Online Softmax:
  블록 1: max₁ = max(S₁), sum₁ = Σexp(S₁ - max₁)
  블록 2: max₂ = max(max₁, max(S₂))
          → 이전 결과를 새 max로 보정
          sum₂ = sum₁ × exp(max₁ - max₂) + Σexp(S₂ - max₂)
  ...
  → 최종 결과는 표준 softmax와 수학적으로 정확히 동일!
  → 근사가 아님 (exact)

결과: IO 복잡도

HBM 읽기/쓰기추가 메모리
표준 어텐션O(n²)O(n²) (S, P 저장)
Flash AttentionO(n²/M)O(n) (softmax 통계만)

3. 실험 결과

3.1 속도 벤치마크

GPT-2 어텐션 (시퀀스 길이 1024, A100):

구현속도 (TFLOPs/s)배율
PyTorch 표준41.61.0x
Megatron-LM55.01.3x
Flash Attention124.83.0x

3.2 메모리 사용량

시퀀스 길이표준 어텐션Flash Attention
1K428 MB39 MB
4K6.8 GB155 MB
16KOOM620 MB

→ 메모리 ~10-40배 절약, 기존에 OOM이던 긴 시퀀스도 처리 가능

3.3 End-to-End 학습 속도

모델표준 대비 Flash Attention
BERT-large15% 빠름
GPT-2 medium3x 빠름 (long sequence)
Long Range Arena2.4x 빠름

4. 한계점

  • CUDA 커널 수준 구현 필요: Python만으로는 불가, 하드웨어 전문 지식 필요
  • 역전파(backward pass) 최적화가 추가로 필요: 체크포인팅으로 해결하지만, 구현 복잡
  • GPU 아키텍처 특화: A100 기준 최적화 → 다른 GPU에서는 조정 필요
  • 블록 크기 선택: SRAM 크기에 따라 최적 블록 크기가 달라짐 → 자동 튜닝 필요
  • Causal masking 등 변형 지원: 추가 구현 복잡도 (Flash Attention 2에서 해결)
  • 근사 어텐션과 결합 어려움: Flash Attention은 정확한 어텐션이므로, 근사 기법과의 시너지 제한적

5. 후속 발전

버전핵심 개선
Flash Attention 2 (2023)워크 파티셔닝 최적화, 시퀀스 병렬, causal masking 효율화 → 추가 2x
Flash Attention 3 (2024)H100 Hopper 최적화, FP8 지원, 비동기 실행 → 추가 1.5-2x

6. 의의와 영향

  • 사실상 모든 현대 LLM의 학습/추론에 사용: PyTorch 2.0에 기본 통합 (torch.nn.functional.scaled_dot_product_attention)
  • 정확한(exact) 어텐션이면서 빠르다: 근사 어텐션의 품질 손실 없이 속도/메모리 개선
  • Long context의 기술적 기반: Flash Attention 없이 128K 컨텍스트 실현 불가
  • "IO-awareness"라는 최적화 관점 확립: 알고리즘 설계 시 하드웨어 메모리 계층을 고려
  • Ring Attention, PagedAttention 등 후속 연구의 빌딩 블록

7. 💬 리뷰어 코멘트

이 논문의 가장 놀라운 점은 "어텐션을 근사하지 않는다"는 것이다. 기존 효율적 어텐션 연구(Sparse Transformer, Linformer, Performer 등)는 모두 O(n²)을 O(n) 또는 O(n√n)으로 줄이기 위해 근사를 도입했다. Flash Attention은 연산량 자체는 O(n²)이지만, IO를 O(n²/M)으로 줄여 실제 속도를 2-4배 높인다.

이 접근이 성공한 이유: 현대 GPU는 연산은 넘치지만(compute-rich) 메모리 대역폭이 부족하다(memory-bound). 표준 어텐션은 GPU 연산력의 ~5-20%만 활용하고 나머지는 메모리에서 데이터가 오기를 기다린다. Flash Attention은 이 유휴 연산력을 활용하여 "같은 FLOPs인데 더 빠른" 결과를 달성한다.

시스템 논문이 AI 연구 전체에 미치는 영향의 극적인 사례다. Tri Dao 한 명의 시스템 최적화가 전 세계 LLM 학습 비용을 수십% 절감했다.


관련 논문: Ring Attention, PagedAttention (vLLM), Mamba

0개의 댓글