FlashAttention

HanJu Han·2024년 10월 15일

LLM 최적화

목록 보기
1/16

FlashAttention은 Transformer 모델의 self-attention 메커니즘을 최적화하는 기술입니다. 주요 특징과 장점은 다음과 같습니다:

  1. 메모리 효율성:
    기존 self-attention은 모든 데이터를 메모리에 한 번에 로드했지만, FlashAttention은 데이터를 HBM(High Bandwidth Memory)과 SRAM(Static Random-Access Memory) 사이에서 효율적으로 이동시킵니다. 이를 통해 메모리 사용량을 크게 줄일 수 있습니다.

  2. 속도 향상:
    데이터를 빠르게 이동시키고 처리함으로써, 전체적인 연산 속도가 향상됩니다. 특히 긴 시퀀스를 처리할 때 더욱 효과적입니다.

  3. FlashAttention-2:
    FlashAttention의 개선된 버전으로, forward와 backward 알고리즘을 수정하여 연산 속도를 더욱 개선했습니다. 원본 대비 약 2배의 속도 향상을 달성했다고 합니다.

FlashAttention은 대규모 언어 모델의 학습과 추론 속도를 크게 향상시키며, 더 긴 문장이나 문서를 효율적으로 처리할 수 있게 합니다. 예를 들어, GPT-3와 같은 대형 모델에서 이 기술을 적용하면, 학습 시간을 크게 단축하고 더 긴 컨텍스트를 처리할 수 있게 되어 모델의 성능 향상에 기여할 수 있습니다.


예제

HBM(High Bandwidth Memory)과 SRAM(Static Random-Access Memory)은 컴퓨터 시스템에서 사용되는 두 가지 다른 유형의 메모리입니다.

  1. HBM (High Bandwidth Memory):

    • 고대역폭 메모리로, 대량의 데이터를 매우 빠르게 전송할 수 있습니다.
    • 주로 그래픽 카드나 AI 가속기에서 사용됩니다.
    • 대용량이지만 상대적으로 접근 속도가 느립니다.
  2. SRAM (Static Random-Access Memory):

    • 정적 랜덤 접근 메모리로, 매우 빠른 읽기와 쓰기가 가능합니다.
    • CPU 캐시 메모리로 주로 사용됩니다.
    • 용량이 작지만 접근 속도가 매우 빠릅니다.

실질적인 예제를 통해 FlashAttention의 작동 방식을 설명해보겠습니다:

예: 긴 문서 요약 작업

문서: "인공지능의 발전은 20세기 후반부터 가속화되었습니다. 머신러닝, 딥러닝 등의 기술이 등장하면서 AI는 다양한 분야에 적용되기 시작했습니다. 특히 자연어 처리, 컴퓨터 비전, 로봇공학 등에서 큰 진전을 이루었습니다. 21세기에 들어서면서 AI는 우리의 일상생활에 깊숙이 파고들었습니다. 스마트폰의 음성 비서, 추천 시스템, 자율주행 자동차 등이 그 예입니다."

기존 Attention 방식:
1. 전체 문서를 HBM에 로드합니다.
2. 모든 단어 쌍 간의 관계를 한 번에 계산합니다.
3. 이 과정에서 대량의 메모리를 사용하며, 긴 문서의 경우 처리 속도가 느려집니다.

FlashAttention 방식:
1. 문서를 작은 청크로 나눕니다. 예: ["인공지능의 발전은", "20세기 후반부터 가속화되었습니다.", ...]
2. 첫 번째 청크를 SRAM에 로드하고 처리합니다.
3. 결과를 HBM에 저장하고, 다음 청크를 SRAM에 로드합니다.
4. 이 과정을 반복하며, 필요할 때마다 이전 결과를 HBM에서 SRAM으로 가져와 활용합니다.

이러한 방식으로 FlashAttention은:

  • 메모리 사용을 최적화합니다 (한 번에 SRAM에 작은 청크만 로드).
  • 데이터 이동을 효율적으로 관리하여 전체 처리 속도를 향상시킵니다.
  • 더 긴 문서도 효율적으로 처리할 수 있게 됩니다.

결과적으로, 이 예제의 문서 요약 작업을 더 빠르고 효율적으로 수행할 수 있게 되어, AI 모델의 성능과 처리 능력이 크게 향상됩니다.


FlashAttention 방식은 전체 문장의 유사도 계산이 된다.

FlashAttention은 전체 문장의 유사도를 계산할 수 있으며, 동시에 효율성도 높입니다.

  1. 청크 기반 처리:
    FlashAttention은 입력을 작은 청크로 나누어 처리하지만, 이는 전체 문장의 정보를 무시한다는 의미가 아닙니다.

  2. 누적 계산:
    각 청크를 처리할 때, 이전 청크들의 정보를 누적하여 계산에 포함시킵니다. 이를 통해 전체 문맥을 유지합니다.

  3. 효율적인 메모리 관리:
    HBM과 SRAM을 효율적으로 사용하여 더 큰 컨텍스트 윈도우를 처리할 수 있게 합니다.

  4. 수학적 동등성:
    FlashAttention의 결과는 기존 Attention 메커니즘과 수학적으로 동일합니다. 차이점은 계산 방식과 메모리 사용의 효율성에 있습니다.

예를 들어 설명하겠습니다:

문장: "인공지능은 빠르게 발전하고 있으며 우리의 삶을 변화시키고 있다"

기존 Attention:

  • 모든 단어 쌍의 관계를 한 번에 계산합니다.
  • 메모리 사용량이 크고, 긴 문장에서 비효율적입니다.

FlashAttention:
1. 문장을 청크로 나눕니다: ["인공지능은 빠르게", "발전하고 있으며", "우리의 삶을", "변화시키고 있다"]
2. 첫 번째 청크 "인공지능은 빠르게"를 처리합니다.
3. 결과를 저장하고, 다음 청크 "발전하고 있으며"를 처리할 때 이전 결과를 고려합니다.
4. 이 과정을 반복하며, 각 단계에서 전체 문맥을 고려합니다.

결과적으로:

  • 전체 문장의 유사도와 관계가 정확히 계산됩니다.
  • 메모리 사용이 최적화되어 더 긴 문장도 효율적으로 처리할 수 있습니다.
  • 계산 속도가 향상되어 대규모 언어 모델의 성능이 개선됩니다.

따라서 FlashAttention은 전체 문장의 유사도를 정확히 계산하면서도, 메모리와 계산 효율성을 크게 향상시키는 방법입니다.

profile
시리즈를 기반으로 작성하였습니다.

0개의 댓글