Transformers Inference Optimization Toolset

손정기·2024년 12월 6일
0

DNN

목록 보기
9/9

대규모 언어 모델(Large Language Models, LLM)은 인공지능의 경계를 확장하고 있지만, 그 방대한 크기는 상당한 계산적 도전 과제를 안겨줍니다. 이러한 모델이 커질수록, 이를 현대 하드웨어에서 효율적으로 실행할 수 있게 하는 스마트한 최적화 기술의 필요성도 함께 증가합니다.

이 글에서는 LLM을 더 빠르고 메모리 효율적으로 만드는 주요 최적화 전략들을 탐구할 것입니다. 먼저, 이러한 기술의 기초를 형성하는 GPU 메모리 계층 구조(GPU memory hierarchy)를 간략히 살펴본 후, LLM이 정보를 더 빠르게 처리하고 더 긴 문맥을 다룰 수 있게 하는 알고리즘들을 탐구할 것입니다. 이러한 최적화 기법을 이해하면 대규모 언어 모델(Large Language Models)의 잠재력을 완전히 발휘할 수 있는 귀중한 통찰을 얻게 될 것입니다.

이 글의 목적은 트랜스포머(transformer)에 대한 특정 최적화 기법만을 논의하려는 것이 아닙니다. 트랜스포머를 더 빠르게 만들기 위한 수많은 자료가 이미 있으며, 그 중에서 제가 가장 좋아하는 것은 Andrej Karpathy의 "GPT-2 재현하기(Let's reproduce GPT-2)"입니다. 이 글의 주된 목표는 현재 수많은 기사와 논문을 한데 모으지 못하고 있는 연구자들에게 진입 장벽을 낮춰주는 것입니다.

많은 최적화 기법들은 다루지 않을 것입니다. 예를 들어, 양자화 기법(quantization methods)은 상대적으로 다양하며 별도의 글에서 다루어야 할 주제입니다. 또한 주로 트랜스포머 추론(transformer inference)에 대해 논의할 것이며, 혼합 정밀도 훈련(mixed-precision training), 그래디언트 체크포인팅(gradient checkpointing) 또는 시퀀스 패킹(sequence packing)과 같은 훈련 시 사용되는 몇 가지 트릭은 언급하지 않을 것입니다. 그럼에도 불구하고, 이 글에서 다루는 많은 최적화 기법들은 훈련에도 적용할 수 있습니다.

GPU 아키텍처 개요

언어 모델 속도 향상 문제를 해결하려면, 우리가 사용하는 하드웨어의 개념을 먼저 이해할 필요가 있습니다. 구글의 TPU(Tensor Processing Unit)나 애플 실리콘 칩이 부상하고 있지만, NVIDIA의 GPU가 여전히 시장을 지배하고 있으므로 이들을 중점적으로 살펴볼 것입니다.

그래픽 처리 장치(Graphic Processor Unit, GPU)는 여러 스트리밍 멀티프로세서(Streaming Multiprocessors, SM)를 통해 모든 연산을 수행합니다. SM은 CPU의 코어와 유사한 GPU의 기본 구성 요소로, 자체적인 명령 스케줄러와 다양한 명령 실행 파이프라인을 가지고 있습니다. 최신 GPU는 고대역폭 메모리(High Bandwidth Memory, HBM)라는 특수한 오프칩(off-chip) 메모리도 갖추고 있으며, 데이터가 처음 저장되고 최종적으로 기록되는 곳입니다. CPU가 제어하는 동적 랜덤 접근 메모리(Dynamic Random Access Memory, DRAM)는 일반적으로 저지연 접근(low latency access)에 최적화되어 있지만, HBM은 GPU에 물리적으로 적층된 레이어로 연결되어 수천 개의 핀을 통해 설계상 대규모 병렬 데이터 처리량을 제공합니다.

스트리밍 멀티프로세서(SM)는 L2 캐시(cache)를 통해 HBM에서 데이터와 코드를 액세스합니다. L2 캐시는 오프칩 메모리와 온칩 메모리(on-chip memory) 사이의 중간 단계로 작용하며, 여러 SM 간에 공유될 수 있는 데이터를 캐싱합니다. 또한 장치 간 데이터를 이동하는 경로에 위치해 있습니다. 마지막으로, 각 SM은 자체 L1 캐시와 공유 메모리(Shared Memory, SRAM)를 가지고 있으며, 이는 저지연 온칩 메모리 캐시입니다. L1 캐시와 SRAM은 HBM보다 몇 배 더 빠르지만, 크기는 매우 작습니다. L1 캐시는 GPU 하드웨어에 의해 관리되며, SRAM은 NVIDIA 도구를 통해 프로그래머가 명시적으로 관리할 수 있습니다.

GPUNVLink라는 고대역폭 인터커넥트를 통해 서로 통신할 수 있으며, PCIe 버스(모든 마더보드에 흔히 사용되는 고속 버스 표준)를 사용해 외부와 연결되거나 Infiniband라는 특수한 이더넷 대안을 통해 데이터를 전송할 수 있습니다. 일반적으로 8개의 GPU가 하나의 노드에 탑재됩니다. 병렬 장치 훈련(multi-device training)에 대해 더 알고 싶다면 제 병렬화 전략(parallelization strategies) 글을 참조하시기 바랍니다.

![[Pasted image 20241015090831.png]]

이제 GPU 성능에 대해 논의할 때, 우리는 세 가지를 살펴봐야 합니다:

  1. 계산 성능(compute performance)은 초당 수행할 수 있는 수조 개의 부동 소수점 연산(TFLOPS)으로 측정됩니다.
  2. GPU 메모리(GPU memory)는 모델 파라미터(parameters), 히든 액티베이션(hidden activations)캐시 값(cache values)을 저장하는 데 필요하며, 단위는 GB로 측정됩니다. 예를 들어, GPT-3은 1,750억 개의 파라미터(parameters)를 가지고 있으므로, 이를 fp16으로 장치에 저장하려면 350GB의 메모리(memory)가 필요합니다.
  3. 메모리 대역폭(memory bandwidth)GPU에서 처리 장치로 바이트를 이동시키는 속도로, GB/s로 측정됩니다.

GPU 성능은 기하급수적으로 빠르게 증가하고 있습니다. NVIDIA 문서에 따르면, 2018년에 출시된 T4 그래픽 카드(graphics card)는 65 TFLOPS, 각각 64KB의 L1 캐시(cache)를 가진 40개의 스트리밍 멀티프로세서(SM), 1.3TB/s의 대역폭(bandwidth)을 가진 4MB의 L2 캐시(cache), 그리고 300 GB/s 대역폭의 16GB HBM을 갖추고 있었습니다. 불과 2년 후, A100이 출시되었으며, 이는 312 TFLOPS, 192KB의 L1 캐시(cache)를 가진 108개의 SM, 1.55TB/s 대역폭의 80GB HBM과 40MB의 L2 캐시(cache)를 제공했습니다. 이 숫자들을 최신 B100 카드(card)와 비교하면, 1.8 PFLOPS의 성능을 제공하며, 192GB의 HBM과 8TB/s의 대역폭(bandwidth)을 가지고 있습니다.

![[Pasted image 20241015090942.png]]

메모리 접근에 필요한 시간은 장치, 수정 사항, 그리고 인프라 설정에 따라 달라질 수 있습니다. 그러나 중요한 점은 처리량(throughput) 숫자를 비교할 때, 그 차이가 몇 배 이상 나는 경우가 있다는 것입니다:

  • L1 캐시(cache) / SRAM에서 데이터를 읽는 시간: x 나노초(ns).
  • L2 캐시(cache)에서 데이터를 읽는 시간: 2-3x 나노초(ns).
  • HBM 메모리(HBM memory)에서 데이터를 읽는 시간: 10x 나노초(ns).
  • NVLink를 통해 GPU 간에 데이터를 공유하는 시간(양방향): 50-60x 나노초(ns).
  • PCIe 버스(PCIe bus)를 통해 CPU에서 GPU DRAM으로 데이터를 로드하는 시간: 약 300x 나노초(ns).

이 숫자들은 초당 수행되는 연산의 수가 중요하지만, 피연산자(operand)의 위치가 추론 속도(inference speed) 최적화에 있어서 더 중요할 수 있다는 것을 보여줍니다. 느린 메모리가 항상 성능 병목을 지배한다는 점을 명심하세요.

산술 밀도(arithmetic intensity) vs ops:byte

계산과 메모리 접근의 균형에 따라, 연산은 다음과 같이 분류될 수 있습니다:

  • 계산 집중(compute-bound): 산술 연산에 소비된 시간이 메모리 접근과 같은 다른 작업에 소비된 시간을 초과하는 경우. 일반적인 예로는 큰 내부 차원을 가진 선형 레이어(linear layer) 또는 많은 채널을 가진 컨볼루션 레이어(convolutional layer)가 있습니다.
  • 메모리 집중(memory-bound): 메모리 접근에 걸리는 시간이 계산 시간보다 긴 경우. 대부분의 연산이 여기에 해당하며, 예로는 요소별 연산(elementwise operations)(활성화 함수, 드롭아웃) 또는 리덕션 연산(reductions)(합, 소프트맥스, 정규화)이 있습니다.
  • 오버헤드 집중(overhead-bound): 통신, 인터프리터 등 기타 모든 작업이 여기에 속합니다. 이 글에서는 논의하지 않겠지만, Making Deep Learning Go Brrrr From First Principles 블로그 글을 읽어보길 권장합니다. 이 글은 GPU 메커니즘을 이해하는 데 도움을 줄 것이며, 대부분의 경우 병목 현상이 이들과는 관련이 없다는 점을 보여줍니다.

첫 번째 두 가지 사이의 균형은 일반적으로 산술 밀도(arithmetic intensity)로 측정되며, 이는 연산을 수행하는 데 필요한 메모리 접근당 산술 연산의 수입니다. 예를 들어, 입력 텐서(tensor)ReLU 활성화 함수를 적용한다고 가정해 봅시다(반정밀도 조건). 이 경우, 각 텐서(tensor) 요소마다 다음이 필요합니다:

  • 2 바이트를 읽음
  • 1 비교를 수행
  • 2 바이트를 씀

xx의 크기와 상관없이 ReLU의 산술 밀도는 #flops#bytes=14\frac{\# \operatorname{flops}}{\# \operatorname{bytes}}=\frac{1}{4}와 같습니다. 이는 각 연산에 대해 4번의 메모리 접근이 필요하다는 것을 의미합니다.

산술 밀도는 하드웨어에 특화된 ops:byte 비율과 비교되어, 계산 집중 상황인지 메모리 집중 상황인지 확인합니다. 이 작동 방식을 설명하기 위해, A100 GPU에서 선형 레이어(linear layer)의 순방향 패스를 예로 들어보겠습니다. 입력 배치(batch) xRB×dx \in \mathbb{R}^{B \times d}가중치 행렬(weight matrix) WRd×d\mathbf{W} \in \mathbb{R}^{d \times d}(여기서 BB배치 크기(batch size)이고 dd임베딩 차원(embedding dimension))이 주어졌을 때, 선형 레이어(linear layer)는 기본적으로 행렬 곱(matrix multiplication) xWx\mathbf{W}를 나타냅니다. 선형 레이어(linear layer) 계산에는 2Bd22Bd^2 플롭스(flops)가 필요합니다. 따라서 A100에서의 계산 시간은

Tcompute=#flopscomputeperformance=2Bd23121012s.T_{\operatorname{compute}} = \frac{\# \operatorname{flops}}{\operatorname{compute performance}} = \frac{2Bd^2}{312 \cdot 10^{12}} s.

동시에 가중치 행렬(weight matrix) W\mathbf{W}를 로드하기 위해 메모리에서 2d22d^2 바이트를 읽어야 합니다(여전히 fp16/bf16으로 작업하는 조건에서). 또한, 간단히 하기 위해 BdB \ll d라고 가정하고, W\mathbf{W}와 비교했을 때 xx의 로딩 시간을 무시하겠습니다. 모델 파라미터(model parameters)는 보통 HBM에 저장되므로,

Tmemory=#bytesmemorybandwidth=2d21.551012s.T_{\operatorname{memory}} = \frac{\# \operatorname{bytes}}{\operatorname{memory bandwidth}} = \frac{2d^2}{1.55 \cdot 10^{12}} s.

산술 밀도는 #flops#bytes\frac{\# \operatorname{flops}}{\# \operatorname{bytes}}로 정의되며, ops:bytecomputeperformancememorybandwidth\frac{\operatorname{compute performance}}{\operatorname{memory bandwidth}}로 주어집니다. 모델의 병목 현상을 찾기 위해 이 두 항의 비율을 살펴보면

TcomputeTmemoryB200.\frac{T_{\operatorname{compute}}}{T_{\operatorname{memory}}} \approx \frac{B}{200}.

이는 배치 크기(batch size)가 보다 작을 때, 시스템 성능이 메모리 집중(memory-bound) 상태라는 것을 의미합니다. 입력 배치를 보다 큰 값으로 늘리면, 메모리 전송 시간(memory transfer time)은 일정하게 유지되면서 계산 시간이 증가하여 계산 집중(compute-bound) 상황으로 전환됩니다.

ops:byte 비율 분석은 유용하지만, 이는 GPU 작업 부하가 충분히 커서 계산 및 메모리 파이프라인을 완전히 활용할 수 있을 때를 가정합니다. 작업 부하가 충분히 크지 않거나 병렬성이 부족한 경우, 프로세서는 충분히 활용되지 않으며 성능은 지연(latency)에 의해 제한됩니다.

고수준 알고리즘 최적화(High-level algorithmic optimizations)

이제 트랜스포머(transformer) 최적화의 구체적인 내용에 들어갈 준비가 되었습니다. 이전 블로그 게시물에서 트랜스포머(transformer) 아키텍처를 정의한 바 있습니다. 간략하게 다시 상기하자면, 스케일드 닷 프로덕트 어텐션(scaled dot product attention) 연산은 쿼리 Q\mathbf{Q}, 키 K\mathbf{K}, 그리고 값 V\mathbf{V}의 세트를 입력으로 받아 다음과 같은 출력을 생성합니다:

Attention(Q,K,V)=softmax(QKTd)V,\operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \operatorname{softmax} \Big( \frac{\mathbf{QK}^T}{\sqrt{d}} \Big) \cdot \mathbf{V},

여기서 dd쿼리(queries)키(keys)의 숨겨진 차원(hidden dimensionality)을 의미합니다. GPT 기반 모델(GPT-based models)을 사용할 때는 마스킹된 어텐션(masked attention)을 사용합니다. 이때 소프트맥스(softmax)의 입력은 마스크(mask) 텐서(tensor)로 수정되어, 특정 토큰에 주의를 기울이지 않으려면 해당 마스킹된 어텐션(masked attention) 값을 -\infty로 설정합니다. 입력 텐서(tensor)QRL×d\mathbf{Q} \in \mathbb{R}^{L \times d}K,VRM×d\mathbf{K}, \mathbf{V} \in \mathbb{R}^{M \times d}이며, 여기서 LLMM은 시퀀스 길이(sequence lengths)를 나타냅니다.

또한, 멀티 헤드 어텐션 레이어(Multi-head attention layer, MHA)의 정의를 다시 한 번 살펴보겠습니다:

MultiHead(Q,K,V)=[head1;;headk]WO,\operatorname{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V})=[\operatorname{head}_1; \dots; \operatorname{head}_k] \cdot \mathbf{W}^O,

여기서

headi=Attention(QWiQ,KWiK,VWiV),i=1,,h.\operatorname{head}_i = \operatorname{Attention}(\mathbf{QW}_i^Q, \mathbf{KW}_i^K, \mathbf{VW}_i^V), \quad i = 1, \dots, h.

학습 가능한 파라미터(learnable parameters) W1hQ,W1hK,W1hV\mathbf{W}^Q_{1 \dots h}, \mathbf{W}^K_{1 \dots h}, \mathbf{W}^V_{1 \dots h}WO\mathbf{W}^O가 주어집니다. 만약 MHA가 Q=K\mathbf{Q} = \mathbf{K} (보통 =V= \mathbf{V})을 받으면, 이를 멀티 헤드 셀프 어텐션(multi-head self-attention)이라고 부르며, 그렇지 않으면 멀티 헤드 크로스 어텐션(multi-head cross-attention)이라고 부릅니다. 우리는 생성적 대규모 언어 모델(generative LLM)에서 널리 사용되는 셀프 어텐션(self-attention) 메커니즘에 집중할 것입니다.

이제 우리는 핵심 어텐션 메커니즘(attention mechanism)에 초점을 맞출 것입니다. 표기법을 단순화하기 위해 새로운 텐서(tensor) 이름을 도입하겠습니다: 닷 프로덕트(dot product)S:=QKTRL×L\mathbf{S} := \mathbf{QK}^T \in \mathbb{R}^{L \times L}, 정규화된 어텐션 가중치(normalized attention weights)P:=softmax(Smask)RL×L\mathbf{P} := \operatorname{softmax}(\mathbf{S} \otimes \text{mask}) \in \mathbb{R}^{L \times L} (mask\text{mask}S\mathbf{S}로 브로드캐스트 가능함), 그리고 출력(output)을 O:=PVRL×d\mathbf{O} := \mathbf{PV} \in \mathbb{R}^{L \times d}로 정의합니다.

KV 캐시(KV Cache)

GPT와 같은 모델을 사용할 때, 텍스트 생성은 두 단계로 이루어집니다:

  1. 프리필(prefill) - 모델이 대량의 프롬프트 토큰을 병렬로 처리하여 한 번의 패스로 모든 히든 스테이트(hidden states)와 출력을 계산합니다.
  2. 프리필(prefill)이 완료되면 자가 회귀 디코딩(auto-regressive decoding)이 시작됩니다. 디코딩은 순차적 특성 때문에 일반적으로 프리필(prefill)보다 시간이 더 많이 소요됩니다. 응답 토큰이 항상 하나씩 순차적으로 생성되기 때문입니다.

![[Pasted image 20241015094139.png]]
![[Pasted image 20241015094258.png]]
![[Pasted image 20241015094312.png]]
![[Pasted image 20241015094049.png]]
![[Pasted image 20241015094343.png]]
![[Pasted image 20241015094356.png]]
![[Pasted image 20241015094440.png]]
![[Pasted image 20241015094454.png]]
![[Pasted image 20241015094508.png]]
![[Pasted image 20241015094522.png]]
인과적 셀프 어텐션(causal self-attention)의 시퀀스 길이 LL에 따른 텍스트 생성 표현 스케일링 계수 d\sqrt{d}는 생략되었습니다. 프롬프트 토큰(prompt tokens) 전체에 대해 어텐션(attention) 값은 한 번만 계산할 수 있지만, 그 후에는 응답 토큰을 생성하기 위해 순차적으로 한 번에 하나씩 계산해야 합니다.

@jit
def dot_product_attention(query, key, value, mask=None):
    d = query.shape[-1]
    # attn_logits shape is [batch..., num_heads, q_seq_len, kv_seq_len]
    attn_logits = jnp.einsum('...lhd,...mhd->...hlm', query, key)
    attn_logits = attn_logits / jnp.sqrt(d) # normalize logits
    if mask is not None:
        big_neg = jnp.finfo(attn_logits.dtype).min
        attn_logits = jnp.where(mask, big_neg, attn_logits)
    # logits -> weights
    attention = nn.softmax(attn_logits, axis=-1)
    # return weighted sum over values for each query position
    output = jnp.einsum('...hlm,...mhv->...lhv', attention, value)
    return output, attention

MHA(Multi-head attention) 추론이 계산 집중(compute-bound)인지 메모리 집중(memory-bound)인지 확인해 봅시다.

  • 계산 복잡도(computational complexity):
    • 쿼리(query) Q\mathbf{Q}를 계산하기 위해 입력 행렬 xRL×d\boldsymbol{x} \in \mathbb{R}^{L \times d}hh개의 헤드에 걸쳐 행렬(matrix) W1hQRd×dh\mathbf{W}_{1 \ldots h}^Q \in \mathbb{R}^{d \times \frac{d}{h}}로 곱하는데, 이는 O(Ld2)\mathcal{O}\left(L d^2\right) 연산이 필요합니다. K\mathbf{K}V\mathbf{V}도 동일한 계산량이 필요합니다.
    • 어텐션(attention) 계산은 S=QKT\mathbf{S}=\mathbf{Q K}^TO=PV\mathbf{O}=\mathbf{P V} 모두에 대해 O(L2d)\mathcal{O}\left(L^2 d\right)가 필요합니다.

따라서 전체적으로 O(Ld2+L2d)\mathcal{O}\left(L d^2+L^2 d\right) 연산을 수행해야 합니다.

  • 메모리 접근(memory accesses):
    • 입력 xx와 중간 텐서(tensor) Q,K,V,O\mathbf{Q}, \mathbf{K}, \mathbf{V}, \mathbf{O}O(Ld)\mathcal{O}(L d) 바이트를 차지합니다.
    • 어텐션 로짓(attention logits) S\mathbf{S}가중치(weights) P\mathbf{P}는 모든 헤드에 대해 총 O(L2h)\mathcal{O}\left(L^2 h\right) 바이트를 차지합니다.
    • 프로젝션 가중치(projection weights) W1hQ,K,V\mathbf{W}_{1 \ldots h}^{Q, K, V}O(d2)\mathcal{O}\left(d^2\right) 메모리 공간을 필요로 합니다.

접근해야 할 전체 메모리 크기(memory size)는 관련된 모든 텐서(tensor)의 크기의 합과 같으며, 이는 O(Ld+L2h+d2)\mathcal{O}\left(L d+L^2 h+d^2\right) 바이트입니다. 따라서 우리는 산술 밀도가 다음과 비례하게 됩니다:

L2d+Ld2L2h+Ld+d2Lddh\frac{L^2 d+L d^2}{L^2 h+L d+d^2} \xrightarrow[L \gg d]{ } \frac{d}{h}

![[Pasted image 20241015100715.png]]

우리는 이미 현대 GPU 하드웨어가 메모리 대역폭(memory bandwidth)보다 연산 능력이 몇 배 더 높다는 것을 확인했습니다. 그래프가 보여주듯, 시퀀스 길이가 충분히 클 경우 산술 밀도는 항상 어텐션 헤드(attention head)임베딩 차원(embedding dimension)dh\frac{d}{h}보다 크며, 이는 보통 수백에 해당합니다. 따라서 산술 밀도는 ops:byte 비율과 같거나 더 큽니다.

일반적으로 이는 높은 알고리즘 효율성을 의미하지만, 텍스트 생성이라는 두 번째 단계에서는 상황이 다릅니다. 여기서 주목해야 할 첫 번째 점은 생성 시나리오에서는 입력 시퀀스 xx의 각 토큰에 대해 어텐션 출력(attention outputs)을 계산할 필요가 없으며, 다음 토큰 xL+1x_{L+1}을 디코딩하기 위해 마지막 토큰에 대해서만 계산하면 된다는 것입니다. 따라서 전체 쿼리(query) 벡터 Q\mathbf{Q}어텐션 메커니즘(attention mechanism)에 보낼 필요가 없습니다.

두 번째 중요한 점은 이전에 계산된 액티베이션(activations)을 재사용할 수 있다는 것입니다. 즉, 생성 과정(generation process)에서 K\mathbf{K}V\mathbf{V} 값을 캐시할 수 있으므로 이 명칭이 붙은 것입니다. 우리는 KV 캐시(KV cache)를 저장하여 효율성을 높이고, 특히 긴 시퀀스에서 불필요한 계산 요구를 줄일 수 있습니다.

![[Pasted image 20241015101006.png]]
![[Pasted image 20241015101025.png]]
![[Pasted image 20241015101038.png]]
![[Pasted image 20241015101051.png]]
![[Pasted image 20241015101106.png]]
![[Pasted image 20241015101125.png]]
![[Pasted image 20241015101138.png]]
![[Pasted image 20241015101152.png]]
![[Pasted image 20241015101209.png]]
![[Pasted image 20241015101225.png]]
디코딩 단계에서 KV 캐시(KV cache)를 사용하는 인과적 셀프 어텐션(causal self-attention)의 표현입니다. 각 시간 단계에서 마지막 토큰에 대해 계산된 K\mathbf{K}V\mathbf{V}는 캐시에 추가되며, 이후 단계에서 다시 사용됩니다.

각 생성 단계에서 텍스트 생성을 위한 플롭스(flops)메모리 접근(memory accesses) 수를 알아보겠습니다(이 값을 시퀀스 전체에 대해 얻으려면 LL 단계를 곱하면 됩니다).

  • 계산(Compute):
    • 쿼리(query) Q\mathbf{Q}를 계산하기 위해 입력 벡터 xRdx \in \mathbb{R}^dhh개의 헤드에 걸쳐 행렬(matrix) W1..hQRd×dh\mathbf{W}_{1 . . h}^Q \in \mathbb{R}^{d \times \frac{d}{h}}로 곱하는데, 이는 O(d2)\mathcal{O}\left(d^2\right) 연산이 필요합니다. K\mathbf{K}V\mathbf{V}KV 캐시(KV cache)를 저장할 때 동일합니다.
    • 어텐션(attention) 계산은 S=QKT\mathbf{S}=\mathbf{Q} \mathbf{K}^TO=PV\mathbf{O}=\mathbf{P V} 모두에 대해 최대 O(Ld)\mathcal{O}(L d)가 필요합니다.

따라서 각 단계에서 O(d2+Ld)\mathcal{O}\left(d^2+L d\right) 연산을 수행해야 합니다. 프리필(prefill) 단계에서와 마찬가지로 LL 단계에 대한 연산 수는 동일하게 유지됩니다.

  • 메모리:
    • 입력 x\boldsymbol{x}와 중간 텐서(tensor) Q,O\mathbf{Q}, \mathbf{O}O(d)\mathcal{O}(d) 바이트를 차지하지만, 캐시에 있는 K\mathbf{K}V\mathbf{V}O(Ld)\mathcal{O}(L d) 공간이 필요합니다.
    • 어텐션 로짓(attention logits) S\mathbf{S}가중치(weights) P\mathbf{P}는 모든 헤드에 대해 최대 O(Lh)\mathcal{O}(L h) 바이트를 차지합니다.
    • 프로젝션 가중치(projection weights) W1...hQ,K,V\mathbf{W}_{1 . . . h}^{Q, K, V}는 다시 O(d2)\mathcal{O}\left(d^2\right) 바이트가 필요합니다.

따라서 총 O(Ld+Lh+d2)\mathcal{O}\left(L d+L h+d^2\right) 바이트가 필요합니다. 마지막으로

 arithmetic intensity Ld+d2Ld+Lh+d2<1\text { arithmetic intensity } \propto \frac{L d+d^2}{L d+L h+d^2}<1

이는 분명히 ops:byte 비율보다 작으며, 우리는 메모리 집중(memory-bound) 상태에 있게 됩니다. L1L-1개의 쿼리(queries)를 제거하여 연산 수를 LL배 줄였음에도 불구하고, 메모리 접근(memory accesses) 수는 그만큼 줄어들지 않았습니다. 그 이유는 각 단계에서 KV 캐시(KV cache)의 모든 값을 검색해야 하기 때문이며, 이 캐시의 크기는 시퀀스 길이에 비례하여 증가합니다.

이로 인해 또 다른 단점이 발생합니다. KV 캐시(KV cache)를 저장하려면 많은 HBM(high bandwidth memory) 용량이 필요합니다. 예를 들어, 트랜스포머(transformer)에서 nn 개의 레이어로 디코딩을 실행할 때, O(Lnd)\mathcal{O}(L n d) 바이트의 KV 캐시(KV cache)를 저장해야 합니다. 따라서 이를 수용할 충분한 메모리(memory)가 있는지 확인하거나, CPU DRAM에서 로드해야 합니다. 그러나 CPU DRAM에서 읽는 것은 HBM에서 읽는 것보다 10배에서 100배 정도 느립니다.

실제 사례: A100 GPU 80GB HBM을 가진 A100 GPU를 예로 들어보겠습니다. 우리가 GPT-3 모델(n=96n=96, d=12,288d=12,288)로 작업하고, 시퀀스 길이 L=4096L=4096을 맞추려고 한다고 가정해 봅시다. 그러면 추가로 필요한 공간은 다음과 같습니다:

2K/V2 float16 4096296 sequence length 12,288 number of layers =19,327,352,832. embedding dimension \underset{\mathbf{K} / \mathrm{V}}{2} \cdot \underset{\text { float16 }}{2} \cdot \frac{4096}{2} \cdot \underset{\text { sequence length }}{96} \cdot \underset{\text { number of layers }}{12,288}=\underset{\text { embedding dimension }}{19,327,352,832 .}

따라서 단일 시퀀스 샘플에 대한 KV 캐시(KV cache)A100 메모리 공간의 18GB, 즉 22.5%가 필요합니다. 대부분의 GPU 공간이 모델 파라미터(model parameters)로 채워질 것이라는 점을 감안하면, 배치 크기를 늘리지 않아도 금방 메모리(memory)가 부족해질 수 있습니다.

멀티 쿼리 / 그룹 쿼리 어텐션(Multi-query / Grouped-query attention)

표준 어텐션 메커니즘(attention mechanism)에서는 각 쿼리 벡터(query vector)에 대해 KV 쌍(KV pairs)이 독립적으로 계산됩니다. 이는 입력 시퀀스의 각 토큰에 대해 별도의 키-값(key-value) 쌍이 계산되고 캐시에 저장된다는 것을 의미합니다. 하지만 많은 경우, 서로 다른 쿼리 벡터(query vectors)가 유사한 어텐션 패턴(attention patterns)을 공유할 수 있으며, 이 경우 해당 키(key)값(value)을 여러 쿼리(queries)에서 재사용할 수 있습니다. 멀티 쿼리 어텐션(Multi-query attention, MQA)(Shazeer, 2019)은 여러 쿼리(queries)에 걸쳐 캐시된 KV 쌍(KV pairs)을 공유하여 텍스트 생성 중 KV 캐시(KV cache)와 관련된 메모리 요구량을 크게 줄입니다.

MQA는 메모리 소비를 줄일 뿐만 아니라 추론 처리량(inference throughput)도 증가시킵니다. 알고리즘적 관점에서 보면 행렬 곱(matrix multiplication)의 수는 동일하게 유지되고, 단지 서로 다른 헤드에서 K\mathbf{K}V\mathbf{V}를 재사용할 뿐이므로 계산 복잡도는 변하지 않습니다. 그러나 메모리 측면에서는 KV 캐시(KV cache)가 이제 O(Ldh)\mathcal{O}\left(L \frac{d}{h}\right) 공간만 필요하며, 산술 밀도는 다음과 비례하게 됩니다:

Ld+d2Ldh+Lh+d2Lddhd+h2h\frac{L d+d^2}{L \frac{d}{h}+L h+d^2} \xrightarrow[L \gg d]{ } \frac{d h}{d+h^2} \approx h

이는 LL이 증가함에 따라 점차 커지며 몇 배의 플래토(plateau)에 도달하게 됩니다. 대부분의 경우 여전히 메모리 집중(memory-bound) 상태일 수 있지만, 멀티 쿼리 어텐션(Multi-query attention) 기법을 사용하면

  • KV 캐시(KV cache) 크기를 hh배 줄일 수 있고,
  • 디코딩 알고리즘의 속도를 최대 hh배까지 빠르게 만들 수 있습니다.

![[Pasted image 20241015102038.png]]
산술 밀도(arithmetic intensity) vs 시퀀스 길이 LL 멀티 쿼리(Multi-query)멀티 헤드 셀프 어텐션(multi-head self-attention) 메커니즘의 산술 밀도를 시퀀스 길이 LL에 따라 자동 회귀 생성(auto-regressive generation) 동안 분석한 결과로, 임베딩 차원 d=12,288d=12,288 및 헤드 수 h=96h=96 기준입니다.

위에서 설명한 GPT-3 모델의 예시에서 멀티 쿼리 어텐션(MQA)을 사용하면 KV 캐시(KV cache)h=96h=96배 더 작아지며, 필요한 공간은 약 200MB로 줄어들어 A100 GPU 메모리의 단 0.25%만을 차지하게 됩니다.

물론, 이러한 가속과 메모리 절감에는 대가가 따릅니다. 모델 파라미터(model parameters)가 줄어들기 때문에 모델의 잠재적 용량도 감소하게 됩니다. 품질 저하를 방지할 수 있는 한 가지 방법은 멀티 헤드 어텐션(MHA)멀티 쿼리 어텐션(MQA) 사이를 보간하는 기술을 사용하는 것입니다. 그 중 하나가 그룹 쿼리 어텐션(Grouped Query Attention, GQA)입니다(Ainslie et al., 2023). GQA에서는 hh개의 쿼리 헤드(query heads)를 각각 고유한 키(keys)값(values)을 가지는 gg개의 그룹으로 나눕니다. 참고로, g=1g=1일 때 GQA멀티 쿼리 어텐션(MQA)과 동일하며, g=hg=h일 때 GQA멀티 헤드 어텐션(MHA)와 동일합니다. gg의 선택은 메모리 절약과 잠재적 정확도 손실 사이의 트레이드오프를 나타냅니다. 더 큰 그룹 크기는 더 많은 메모리 절약을 가져오지만 어텐션 계산(attention computations)에서 더 큰 근사 오류를 초래할 수 있습니다. 실제로, 최적의 그룹 크기는 특정 모델 아키텍처와 메모리 효율성 및 모델 성능 간의 트레이드오프를 기반으로 경험적으로 결정해야 할 수 있습니다.

![[Pasted image 20241015102532.png]]

@jit
def gqa_dot_product_attention(query, key, value, mask=None):
    num_heads, num_kv_heads = query.shape[-2], key.shape[-2]
    # broadcast K/V heads to match number of Q heads
    num_heads_per_kv = num_heads // num_kv_heads
    key = jnp.repeat(key, num_heads_per_kv, axis=-2)
    value = jnp.repeat(value, num_heads_per_kv, axis=-2)
    return dot_product_attention(query, key, value, mask)

다음은 입력 xRB×L×dx \in \mathbb{R}^{B \times L \times d}와 큰 컨텍스트 크기 LL에 대한 배치 디코딩/추론 알고리즘의 복잡도를 비교한 표입니다. 주목할 점은, 표에 있는 모든 알고리즘의 계산 복잡도는 동일하다는 것입니다! 하지만 실제 효율성은 설정에 따라 크게 달라질 수 있습니다.

![[Pasted image 20241015105957.png]]

청크를 사용한 프리필(prefill with chunking)

KV 캐시(KV cache)가 시퀀스 길이가 증가함에 따라 발생하는 메모리 문제의 유일한 원인은 아닙니다. 프리필(prefill) 단계에서는 모든 출력과 KV\mathbf{K V} 쌍을 한 번에 계산합니다. 이는 어텐션 매트릭스(attention matrix) SRL×L\mathbf{S} \in \mathbb{R}^{L \times L}를 계산해야 하며, 이 값은 컨텍스트 길이에 따라 이차적으로 증가합니다. 만약 프롬프트 크기가 너무 커서 모든 어텐션 가중치(attention weights)를 메모리에 담을 수 없다면 어떻게 할까요? 디코딩 단계에서처럼 토큰을 하나씩 통과시키면서 계산할 수 있지만, 이 절차는 메모리 집중(memory-bound) 상태이기 때문에 훨씬 느립니다. 하지만 미래의 토큰을 미리 알고 있기 때문에, 우리는 이를 청크 단위로 모델에 입력할 수 있습니다:

  • 첫 번째 CC개의 토큰(C<LC<L)을 프롬프트에서 가져와 프리필(prefill) 단계를 거치고, 그들의 KV 캐시(KV cache) 값을 저장합니다. 이때 어텐션 가중치(attention weights)SRC×C\mathbf{S} \in \mathbb{R}^{C \times C}가 됩니다.
  • 그 다음, 다음 CC개의 토큰에 동일한 절차를 적용하지만, 이전 청크에서 저장된 KV 쌍(KV pairs)을 사용하여 해당 토큰들에 주의를 기울입니다. 이때 어텐션 가중치(attention weights)SRC×2C\mathbf{S} \in \mathbb{R}^{C \times 2 C}가 됩니다.
  • 이렇게 전체 프롬프트가 프리필(prefill) 될 때까지 반복합니다. 마지막에 S\mathbf{S}의 최대 크기는 C×LC \times L이 됩니다.
    ![[Pasted image 20241015110151.png]]

청크(chunking)를 사용하면 어텐션 매트릭스(attention matrix) SS의 최대 크기는 LL에 선형적으로 의존하며, 이는 제어 가능한 상수 계수 CC에 의해 곱해집니다.

슬라이딩 윈도우 어텐션(Sliding window attention)

멀티 쿼리 어텐션(multi-query attention)프리필 청크(prefill-chunking)를 사용하더라도, KV 캐시(KV cache)어텐션 가중치(attention weights)프리필 단계(prefill phase)디코딩 단계(decoding phase)에서 계속해서 컨텍스트가 증가함에 따라 커집니다. 각 토큰이 주의를 기울일 수 있는 토큰의 수를 일정한 상수 LwL_w로 제한한다면, 메모리 요구량은 입력 시퀀스 길이에 의존하지 않게 됩니다. 이것이 바로 슬라이딩 윈도우 어텐션(sliding window attention) 기술의 핵심입니다. 이 방법은 어텐션 마스크(attention mask)를 하삼각 행렬에서 대각선 근처의 밴드 행렬로 변경합니다.
![[Pasted image 20241015110310.png]]

이 방식은 어텐션 레이어(attention layer)가 로컬 컨텍스트에만 집중하게 만듭니다. 하지만 토큰은 암묵적으로 이전 nLwn \cdot L_w 토큰에 주의를 기울일 수 있습니다. 여기서 n\boldsymbol{n}트랜스포머(transformer) 모델의 레이어 수입니다. 이는 컨볼루션 네트워크(convolutional networks)에서 수용 영역(receptive field)이 작동하는 방식과 매우 유사합니다. 최적의 윈도우 크기를 선택할 때는 메모리 효율성과 컨텍스트 유지 간의 트레이드오프가 발생합니다. 윈도우 크기가 클수록 더 많은 컨텍스트를 유지할 수 있지만 더 많은 메모리를 요구하고, 윈도우 크기가 작으면 메모리 효율은 높아지지만 일부 컨텍스트를 잃을 수 있습니다.

슬라이딩 윈도우 어텐션(sliding window attention)을 사용할 때, 롤링 KV 캐시(rolling KV cache)도 사용할 수 있습니다. KV 캐시(KV cache)는 이제 주어진 상수 2Lw2 \cdot L_w로 제한되며, 각 단계에서 하나의 KV\mathbf{K V} 쌍만 변경되기 때문에, 더 이상 주의를 기울이지 않는 가장 오래된 토큰과 관련된 쌍을 제거하고, 가장 최신의 쌍으로 교체할 수 있습니다. 실무에서는 쓰기 포인터를 가장 오래된 쌍에 두고, 교체 후에 한 칸씩 이동시킵니다. 버퍼의 끝에 도달하면, 포인터를 처음 위치로 다시 이동시킵니다.

![[Pasted image 20241015110441.png]]
![[Pasted image 20241015110459.png]]
![[Pasted image 20241015110520.png]]
![[Pasted image 20241015110544.png]]
![[Pasted image 20241015110558.png]]
![[Pasted image 20241015110623.png]]
![[Pasted image 20241015110655.png]]
![[Pasted image 20241015110710.png]]
![[Pasted image 20241015110729.png]]
![[Pasted image 20241015110755.png]]

또 다른 슬라이딩 윈도우 어텐션(sliding window attention)의 장점은 프리필(prefill) 단계에서 청크(chunking)와 결합하면 어텐션 매트릭스(attention matrix)의 최대 크기를 일정하게 유지할 뿐만 아니라 (SRC×Lw)\left(\mathbf{S} \in \mathbb{R}^{C \times L_{\mathrm{w}}}\right), 닷 프로덕트(dot-products)의 계산 횟수도 줄일 수 있다는 점입니다.

슬라이딩 윈도우 어텐션(sliding window attention)의 단점은 모든 토큰 간 상호작용을 포착하지 못해 성능 저하가 발생할 수 있다는 것입니다. Xiao et al.(2024)이 발견한 흥미로운 현상은 이를 어텐션 싱크(attention sink)라고 부르며, 시퀀스 시작 부분에 있는 소수의 토큰에 대한 KV를 유지하면 윈도우 어텐션(window attention) 성능을 크게 회복할 수 있다는 것입니다. 그들은 대규모 언어 모델(LLMs)이 초기 토큰에 대해 강한 어텐션 스코어(attention scores)를 출력하는데, 이는 의미적으로 중요하지 않은 토큰임에도 불구하고 "싱크(sink)" 역할을 한다고 관찰했습니다.

선형 어텐션(Linear attention)

선형 어텐션 메커니즘(linear attention mechanism)(Katharopoulos et al.(2020))은 긴 시퀀스에 대한 O(L2)\mathcal{O}\left(L^2\right) 스케일링을 피하기 위한 대안적 방법군입니다. 선형 어텐션(linear attention)표준 어텐션 메커니즘(standard attention mechanism)을 근사하면서 시간과 공간 복잡도가 선형인 결과를 달성합니다. 핵심 아이디어는 어텐션 연산(attention operation)을 행렬 곱셈의 결합 법칙(associative property)커널 함수(kernel functions)를 사용하여 재구성하는 것입니다.

커널 함수(kernel function) K(q,k)\mathcal{K}(q, k)는 입력 qqkk 간의 유사성을 측정하는 지표로, 모든 어텐션 메커니즘(attention mechanism)에서와 동일한 방식으로 작동합니다. 계산을 단순화하기 위해, 커널 함수(kernel function)는 종종 피처 맵(feature map) ϕ\phi의 형태로 표현할 수 있는 방식으로 선택됩니다:

K(q,k)=ϕ(q)Tϕ(k)\mathcal{K}(q, k)=\phi(q)^T \phi(k)

이러한 피처 맵(feature map)을 찾을 수 있다면, 전체 어텐션 매트릭스(attention matrix) QKT\mathbf{Q K}^{\mathbf{T}}를 명시적으로 계산하지 않고도 쿼리(queries)키(keys) 간의 유사성을 암묵적으로 계산할 수 있습니다.

어텐션 메커니즘(attention mechanism)에서 쿼리 임베딩(query embedding) q\mathbf{q}키 임베딩(key embedding) k\mathbf{k} 사이의 정규화되지 않은 유사성은 다음과 같이 측정됩니다: K(q,k)=exp(qTkd)\mathcal{K}(\mathbf{q}, \mathbf{k})=\exp \left(\frac{\mathbf{q}^T \mathbf{k}}{\sqrt{d}}\right). 소프트맥스 마스크 어텐션 매트릭스(softmax masked attention matrix)의 각 요소 P=softmax(\mathbf{P}=\operatorname{softmax}( mask S)\otimes \mathbf{S}), 즉 쿼리 행(query row) Qi\mathbf{Q}_i키 행(key row) Kj(ji)\mathbf{K}_j(j \leq i) 간의 정규화된 유사성은 다음과 같이 표현될 수 있습니다:

Pij=K(Qi,Kj)jiK(Qi,Kj)\mathbf{P}_{i j}=\frac{\mathcal{K}\left(\mathbf{Q}_i, \mathbf{K}_j\right)}{\sum_{j \leq i} \mathcal{K}\left(\mathbf{Q}_i, \mathbf{K}_j\right)}

피처 맵(feature maps)을 사용하여 닷 프로덕트 어텐션(dot-product attention) 출력의 각 행 iiO=PV\mathbf{O}=\mathbf{P V}로 다시 쓸 수 있습니다:

Oi=jiPijVj=jiK(Qi,Kj)VjjiK(Qi,Kj)=jiϕ(Qi)Tϕ(Kj)Vjjiϕ(Qi)Tϕ(Kj)=ϕ(Qi)Tjiϕ(Kj)VjTϕ(Qi)Tjiϕ(Kj)=ϕ(Qi)TUiϕ(Qi)TZi.\begin{aligned} \mathbf{O}_{i} &= \sum_{j \leq i} \mathbf{P}_{ij} \mathbf{V}_j \\ &= \frac{\sum_{j \leq i} \mathcal{K}(\mathbf{Q}_i, \mathbf{K}_j) \cdot \mathbf{V}_{j}}{\sum_{j \leq i} \mathcal{K}(\mathbf{Q}_i, \mathbf{K}_j)} \\ &= \frac{\sum_{j \leq i} \phi(\mathbf{Q}_i)^T \phi(\mathbf{K}_j) \cdot \mathbf{V}_{j}}{\sum_{j \leq i} \phi(\mathbf{Q}_i)^T \phi(\mathbf{K}_j) } \\ &= \frac{ \phi(\mathbf{Q}_i)^T \cdot \color{Salmon}{\sum_{j \leq i} \phi(\mathbf{K}_j) \mathbf{V}^T_{j}} }{ \phi(\mathbf{Q}_i)^T \cdot \color{#007BA7}{\sum_{j \leq i} \phi(\mathbf{K}_j)}} \\ &= \frac{ \phi(\mathbf{Q}_i)^T \cdot \color{Salmon}{\mathbf{U}_i}}{ \phi(\mathbf{Q}_i)^T \cdot \color{#007BA7}{\mathbf{Z}_i} }. \end{aligned}

위 식은 분자를 다음과 같이 벡터화된 형태로 쓰면 더 간단해집니다:

(ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)\left(\phi(\mathbf{Q}) \phi(\mathbf{K})^T\right) \mathbf{V}=\phi(\mathbf{Q})\left(\phi(\mathbf{K})^T \mathbf{V}\right)

시퀀스 길이 LL의 값과 상관없이, 더 이상 이차적으로 증가하는 어텐션 매트릭스(attention matrix)를 저장할 필요가 없으며, 우리는 O(d2)\mathcal{O}\left(d^2\right) 공간만 필요합니다. 이는 UL=ϕ(K)TVRd×d\mathbf{U}_L=\phi(\mathbf{K})^T \mathbf{V} \in \mathbb{R}^{d \times d}입니다.

![[Pasted image 20241015111253.png]]
![[Pasted image 20241015111308.png]]
![[Pasted image 20241015111322.png]]
![[Pasted image 20241015111333.png]]
![[Pasted image 20241015111349.png]]
![[Pasted image 20241015111406.png]]
![[Pasted image 20241015111422.png]]
![[Pasted image 20241015111438.png]]
![[Pasted image 20241015111453.png]]
![[Pasted image 20241015111507.png]]
프리필(prefill) 단계나 디코딩(decoding) 단계에서 선형 어텐션(linear attention)을 사용할 때는 더 이상 O(L2)\mathcal{O}\left(L^2\right) 공간이 필요하지 않습니다. 스칼라 분모 ϕ(QL)TZL\phi\left(\mathbf{Q}_L\right)^T \cdot \mathbf{Z}_L는 여기서 생략되었습니다.

피처 맵(feature maps)의 도입으로 또 하나의 흥미로운 속성이 등장합니다: 선형 어텐션(linear attention) 계산은 반복적으로 표현할 수 있습니다. 우리는 다음과 같은 관계를 가집니다:

Ui=Ui1+ϕ(Ki)ViTZi=Zi1+ϕ(Ki)\begin{aligned} & \mathbf{U}_i=\mathbf{U}_{i-1}+\phi\left(\mathbf{K}_i\right) \mathbf{V}_i^T \\ & \mathbf{Z}_i=\mathbf{Z}_{i-1}+\phi\left(\mathbf{K}_i\right) \end{aligned}

여기서 U0,Z0\mathbf{U}_0, \mathbf{Z}_0는 모두 0 값입니다. 이를 통해 자동 회귀 디코딩(auto-regressive decoding) 중에도 상수 크기의 은닉 상태(hidden states) U\mathbf{U}Z\mathbf{Z}만 유지하면 어텐션을 계산할 수 있으며, 선형적으로 증가하는 입력을 모델에 제공할 필요가 없습니다.

Hedgehog & the Porcupine

피처 맵(feature map) ϕ\phi의 선택은 다음과 같이

ϕ(q)ϕ(k)Texp(qkT)\phi(q) \phi(k)^T \approx \exp \left(q k^T\right)

쉽지 않은 작업입니다. 비록 선형 어텐션(linear attention)이 계산 복잡도를 줄일 수 있지만, 커널 근사(kernel approximation)완전 어텐션(full attention)의 중요한 속성을 포착하지 못하면 모델 성능이 저하될 수 있습니다.

원래 선형 어텐션(linear attention)의 저자들은 실험에서 ϕ(x)=ELU(x)+1\phi(x)=\operatorname{ELU}(x)+1피처 맵(feature map)으로 사용했습니다. 또 다른 옵션은 표준 ReLU 함수를 사용하는 것입니다(하지만 이는 음수 입력에 대해 그래디언트를 0으로 만듭니다). 이러한 선택은 단순하고 효과적인 계산을 가능하게 하지만, Zhang et al.(2024)은 이러한 피처 맵(feature maps)이 소프트맥스 어텐션과 달리 다음과 같은 두 가지 중요한 속성을 잃게 된다고 지적했습니다:

  • 낮은 엔트로피 스파이키함(low-entropy "spikyness"): 직관적으로, 어텐션은 관련된 토큰에만 주의를 기울이고 관련 없는 토큰은 무시해야 합니다.
  • 닷 프로덕트 단조성(dot-product monotonicity): 어텐션 가중치(attention weights)는 해당 쿼리(queries)키(keys)의 닷 프로덕트가 증가함에 따라 증가해야 합니다. 이 단조성이 없으면 훈련 중에 불안정한 그래디언트를 생성할 수 있습니다.

그들은 차수 2의 테일러 근사(Taylor approximation)로 정의된 지수 함수 ϕtaylor (x)=[1,x1,,xd][xixji,j[d]]\phi_{\text {taylor }}(\mathbf{x})=\left[1, x_1, \ldots, x_d\right] \cup\left[x_i \cdot x_j \mid i, j \in[d]\right]dd 차원 벡터 x\mathbf{x}에 대해 스파이키함과 단조성을 모두 유지하며, 소프트맥스 어텐션 성능과 거의 일치하는 성능을 낸다고 관찰했습니다. 하지만 ϕtaylor (x)\phi_{\text {taylor }}(\mathbf{x})R1+d+d2\mathbb{R}^{1+d+d^2} 공간에 매핑되어, O(Ld3)\mathcal{O}\left(L d^3\right) 복잡도를 가지게 되며, 임베딩 차원이 증가함에 따라 계산 비용이 크게 늘어납니다.

이 문제를 해결하기 위해 그들은 Hedgehog를 제안했습니다. 이는 지수 활성화 함수를 가진 학습 가능한 선형 레이어(learnable linear layer)로, 이러한 속성을 학습하고 소프트맥스 어텐션 가중치를 모방하도록 훈련되었습니다:

ϕmlp (x)=exp(xW)\phi_{\text {mlp }}(\mathbf{x})=\exp (\mathbf{x} \mathbf{W})

소프트맥스 근사를 학습하기 위해, 그들은 ϕmlp (x)\phi_{\text {mlp }}(\mathbf{x})를 훈련시켜 계산된 선형 어텐션 가중치(linear attention weights)소프트맥스 마스크 어텐션(softmax masked attention) P\mathbf{P}를 통해 계산된 가중치 간의 교차 엔트로피 손실(cross-entropy loss)을 최소화합니다:

Li=jiPijlogϕmlp(Qi)Tϕmlp(Kj)jiϕmlp(Qi)Tϕmlp(Kj)\mathcal{L}_i=-\sum_{j \leq i} \mathbf{P}_{i j} \cdot \log \frac{\phi_{\mathrm{mlp}}\left(\mathbf{Q}_i\right)^T \phi_{\mathrm{mlp}}\left(\mathbf{K}_j\right)}{\sum_{j \leq i} \phi_{\mathrm{mlp}}\left(\mathbf{Q}_i\right)^T \phi_{\mathrm{mlp}}\left(\mathbf{K}_j\right)}

하드웨어 저수준 최적화(Low-level hardware optimizations)

CUDA 커널 융합(CUDA kernel fusion)

앞서 말했듯이, 대역폭 비용(bandwidth cost)(메모리 내에서 데이터를 이동하는 비용)이 성능과 관련하여 가장 중요한 요소임을 관찰했습니다. 왜 그런지 이해하기 위해 GPU 하드웨어(GPU hardware)를 다시 한번 자세히 살펴보겠습니다.

GPU스레드(thread)라고 불리는 수천 개의 간단한 작업을 병렬로 수행하도록 설계되었습니다. 각 스레드는 GPU에서 가장 빠른 메모리인 레지스터(register)를 가지고 있습니다. 동일한 스트리밍 멀티프로세서(SM, streaming multiprocessors)에서 실행되는 스레드는 스레드 블록(thread block)으로 그룹화되어 동시에 실행될 수 있습니다(한 블록 내 최대 스레드 수는 아키텍처에 의해 제한되며, 보통 1024입니다). 스레드 블록 내 스레드는 전역 메모리(글로벌 메모리(global memory), HBM)에서 데이터를 공유 메모리(shared memory, SRAM)로 로드하고, 공유 메모리는 스레드 간 통신을 위해 사용됩니다. 그런 다음 계산을 수행하고 결과를 다시 글로벌 메모리에 기록합니다.

SM에 하나 이상의 스레드 블록이 할당되면, 이를 워프(warp)로 나눕니다. 워프(warp)는 32개의 스레드(thread)로 구성되며, 워프 내 모든 스레드는 동일한 명령을 실행합니다. 마지막으로 여러 스레드 블록이 결합되어 그리드(grid)를 형성합니다. 동일한 그리드 내 모든 블록은 동일한 스레드 수를 가집니다. 스레드 블록의 스레드 수가 제한되므로, 그리드는 병렬로 작동해야 하는 많은 스레드 블록이 필요한 계산에 사용됩니다.

계산은 커널(kernel)이라고 불리는 작은 C++ 함수로 정의되며, 이 함수는 여러 스레드에 의해 병렬로 여러 번 실행됩니다(일반적인 C++ 함수처럼 한 번만 실행되지 않음). 커널(kernel)그리드(grid)로 실행되며, 서로 다른 커널은 각기 다른 그리드 및 블록 구성(grid and block configurations)을 가질 수 있습니다.

다시 A100을 예로 들어보겠습니다. A100은 108개의 SM을 가지고 있으며, 각 SM은 최대 2048개의 스레드(thread)를 실행할 수 있습니다. 각 SM공유 메모리(shared memory) 용량은 최대 192KB입니다. 이는 큰 행렬(최대 몇 MB)을 작은 청크로 나누어 A100 레지스터(register)SRAM에 맞추고, 초당 18TB의 속도로 행렬 곱셈(matrix multiplication)을 수행할 수 있음을 의미합니다. 이로 인해 GPUCPU보다 행렬 곱셈에서 훨씬 더 효율적입니다.

하지만 딥러닝(deep learning)에서는 행렬 곱셈 외에도 정규화 레이어, 활성화 함수, 드롭아웃과 같은 다양한 연산이 많이 발생합니다. 이 연산들은 총 FLOP의 작은 비율만 차지하지만, GPU에서 훨씬 느리게 실행됩니다. 첫째, 텐서 코어(Tensor Cores)라는 행렬 곱셈에 특화된 유닛이 있어, 행렬 곱셈(matmul)의 처리량이 비행렬 연산보다 최대 16배 높을 수 있습니다. 예를 들어, A100에서 312 TFLOPs의 행렬 곱셈(matmul) 처리량과 19.5 TFLOPs의 비행렬 처리량을 비교할 수 있습니다. 둘째, 이러한 연산들은 매우 메모리 집중(memory-bound) 상태일 수 있습니다.

어텐션(attention)에서 P=softmax(S)RL×L\mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{L \times L}을 생각해봅시다. 시퀀스 길이 LL이 크면, 텐서 S\mathbf{S}는 너무 커서 SRAM에 적재할 수 없기 때문에 HBM에 있어야 합니다. 소프트맥스(softmax)의 가장 단순한 구현은 다음을 요구합니다:

  • S\mathbf{S}L2L^2 플로트를 읽고, exp(S)\exp (\mathbf{S})를 계산한 후 L2L^2 플로트를 다시 HBM에 기록.
  • 행 축을 따라 exp(S)\exp (\mathbf{S})의 합을 구함: L2L^2를 읽고 LL을 기록.
  • exp(S)\exp (\mathbf{S})를 분모로 나누기: L2+LL^2+L을 읽고 L2L^2를 기록.

결과적으로 우리는 5L2+L5 L^2+L 플로트를 이동해야 합니다. 만약 S\mathbf{S}L2L^2 플로트만 읽고 P\mathbf{P}L2L^2 플로트를 기록할 수 있다면, 이 작업을 2.5배 이상 빠르게 수행할 수 있습니다.

여기서 커널 융합(kernel fusion)이 등장합니다. 저대역폭 글로벌 메모리(low-bandwidth global memory)에 계산 출력 y=f(x)y=f(x)를 기록한 후 다시 읽어서 z=g(y)z=g(y)를 구하는 대신, 여러 계산을 한 번에 수행하는 커널(kernel)을 구현할 수 있습니다. 즉, 추가 메모리 액세스 없이 z=(gf)(x)z=(g \circ f)(x)를 수행하는 것입니다. Jax의 XLA 컴파일러는 간단한 융합을 수행할 수 있지만, 프로그래머는 Triton이나 Pallas로 사용자 정의 CUDA 커널을 작성할 수도 있습니다.

메모리 효율적인 어텐션(memory-efficient attention)

온라인 소프트맥스(Online softmax)

함수 y=softmax(x)\mathbf{y}=\operatorname{softmax}(\mathbf{x})는 다음과 같이 정의됩니다:

yi=exijexj.\mathbf{y}_i=\frac{e^{\mathbf{x}_{\mathbf{i}}}}{\sum_j e^{\mathbf{x}_j}} .
@jit
def naive_softmax(logits):
    exp_logits = jnp.exp(logits)
    return exp_logits / exp_logits.sum()

소프트맥스(softmax)의 단순한 구현은 x\mathbf{x}를 두 번 스캔합니다. 첫 번째는 정규화 항을 계산하고, 두 번째는 출력 벡터 y\mathbf{y}를 계산하는 과정입니다. 그러나 실제 하드웨어에서 이러한 구현에는 심각한 결함이 있습니다: xi89\mathbf{x}_i \geq 89일 때, bf16fp32에서 지수 함수 계산 결과가 무한대가 됩니다. 이를 방지하기 위한 트릭이 있습니다. 모든 상수 mm에 대해 다음을 참고하세요:

softmax(x)i=exijexj=exijexjemem=eximjexjm=softmax(xm)i\begin{aligned} \operatorname{softmax}(\mathbf{x})_i & =\frac{e^{\mathbf{x}_i}}{\sum_j e^{\mathbf{x}_j}} \\ & =\frac{e^{\mathbf{x}_i}}{\sum_j e^{\mathbf{x}_j}} \cdot \frac{e^{-m}}{e^{-m}} \\ & =\frac{e^{\mathbf{x}_i-m}}{\sum_j e^{\mathbf{x}_j-m}} \\ & =\operatorname{softmax}(\mathbf{x}-m)_i \end{aligned}

우리가 m(x)=maxixim(\mathbf{x})=\max _i \mathbf{x}_i로 설정하고, (x)=jexjm(x)\ell(\mathbf{x})=\sum_j e^{\mathbf{x}_j-m(\mathbf{x})}로 계산하면

yi=softmax(xm(x))i=exim(x)(x)\mathbf{y}_i=\operatorname{softmax}(\mathbf{x}-m(\mathbf{x}))_i=\frac{e^{\mathbf{x}_i-m(\mathbf{x})}}{\ell(\mathbf{x})}

우리는 소프트맥스(softmax)의 수치적으로 안정된 버전을 구현할 수 있습니다. 이것을 안전한 소프트맥스(safe softmax)라고 부르기도 합니다.

@jit
def safe_softmax(logits):
    exp_logits = jnp.exp(logits - logits.max())
    return exp_logits / exp_logits.sum()

하지만 안정성을 확보하는 대가로 효율성이 떨어지게 됩니다. 이제 우리는 x\mathbf{x}에 대해 m(x)\boldsymbol{m}(\mathbf{x})를 계산하기 위해 한 번 더 스캔해야 하기 때문입니다. 결과적으로 벡터의 각 요소에 대해 총 4번의 메모리 접근(3번의 로드와 1번의 저장)이 발생하며, 이를 개선하고자 합니다.

두 벡터 x1\mathbf{x}^1, x2\mathbf{x}^2에 대해, 연결된 벡터 x=[x1,x2]\mathbf{x}=\left[\mathbf{x}^1, \mathbf{x}^2\right]의 통계를 다음과 같이 분해할 수 있습니다:

  • m(x)=max(m(x1),m(x2))m(\mathbf{x})=\max \left(m\left(\mathbf{x}^1\right), m\left(\mathbf{x}^2\right)\right)
  • (x)=em(x1)m(x)(x1)+em(x2)m(x)(x2)\ell(\mathrm{x})=e^{m\left(\mathrm{x}^1\right)-m(\mathrm{x})} \ell\left(\mathrm{x}^1\right)+e^{m\left(\mathrm{x}^2\right)-m(\mathrm{x})} \ell\left(\mathrm{x}^2\right)

이 성질을 기반으로 Milakov와 Gimelshein(2018)온라인 소프트맥스(online softmax)를 제시했으며, 이를 통해 m(x)m(\mathbf{x})(x)\ell(\mathbf{x})를 한 번의 스캔으로 계산할 수 있게 했습니다. 초기값을 m0=m_0=-\infty0=0\ell_0=0으로 설정한 후, 각 반복에서 i=1,,Li=1, \ldots, L에 대해 다음을 수행합니다:

  • mimax(mi1,xi)m_i \leftarrow \max \left(m_{i-1}, \mathbf{x}_i\right),
  • ii1emi1mi+eximi\ell_i \leftarrow \ell_{i-1} e^{m_{i-1}-m_i}+e^{\mathbf{x}_i-m_i}.

이 알고리즘은 최대값

mi=m(x1,,xi)m_i=m\left(\left|\mathbf{x}_1, \ldots, \mathbf{x}_i\right|\right)

과 정규화 항

i=([x1,,xi])\ell_i=\ell\left(\left[\mathbf{x}_1, \ldots, \mathbf{x}_i\right]\right)

을 입력 배열의 각 요소를 반복하면서 유지합니다. 각 반복에서 정규화 항을 새 최대값 mim_i에 맞춰 조정한 후, i\ell_i에 새 값을 추가합니다.

또한 장치를 완전히 활용할 수 있는 병렬 버전도 존재합니다:

[m(x)(x)]=[x11][x21][xL1]\left[\begin{array}{c} m(\mathbf{x}) \\ \ell(\mathbf{x}) \end{array}\right]=\left[\begin{array}{c} \mathbf{x}_1 \\ 1 \end{array}\right] \oplus\left[\begin{array}{c} \mathbf{x}_2 \\ 1 \end{array}\right] \oplus \cdots \oplus\left[\begin{array}{c} \mathbf{x}_L \\ 1 \end{array}\right]

여기서 이항 연산(binary operation) :R2×R2R2\oplus: \mathbb{R}^2 \times \mathbb{R}^2 \rightarrow \mathbb{R}^2는 다음과 같이 정의됩니다:

[mii][mjj]=[max(mi,mj)iemimax(mi,mj)+jemjmax(mi,mj)]\left[\begin{array}{c} m_i \\ \ell_i \end{array}\right] \oplus\left[\begin{array}{c} m_j \\ \ell_j \end{array}\right]=\left[\begin{array}{c} \max \left(m_i, m_j\right) \\ \ell_i e^{m_i-\max \left(m_i, m_j\right)}+\ell_j e^{m_j-\max \left(m_i, m_j\right)} \end{array}\right]

연산 \oplus결합 법칙(associative)교환 법칙(commutative)을 따르기 때문에 병렬로 효율적인 평가가 가능합니다.

@jit
def online_softmax(logits):
    
    def reducer(x, y):
        m_i, l_i = x
        m_j, l_j = y
        m = jnp.maximum(m_i, m_j)
        l = l_i * jnp.exp(m_i - m) + l_j * jnp.exp(m_j - m)
        return (m, l)
    
    m, l = jax.lax.reduce(
        (logits, jnp.ones_like(logits)), 
        (-jnp.inf, 0.), 
        reducer, 
        (0,)
    )
    exp_logits = jnp.exp(logits - m)
    return exp_logits / l

이 작은 테스트 스크립트를 실행하여 각 구현의 효율성을 평가할 수 있습니다:

# create large random vector
logits = jax.random.uniform(random.PRNGKey(42), shape=(1_000_000,))

# one warmup run for each function to compile
naive_softmax(logits)
safe_softmax(logits)
online_softmax(logits)

print('Naive:')
%timeit naive_softmax(logits).block_until_ready()
print('\nSafe:')
%timeit safe_softmax(logits).block_until_ready()
print('\nOnline:')
%timeit online_softmax(logits).block_until_ready()

이것은 TPU-v3에서 실행된 스크립트의 출력 결과입니다:

Naive:
194 μs ± 15.4 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Safe:
254 μs ± 17.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Online:
199 μs ± 22.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

원본 논문을 참조하여 더 많은 세부 정보와 softmax + top-k 융합 알고리즘을 확인하세요.

Lazy softmax

이제 다시 어텐션 연산(attention operation)의 계산으로 돌아가 봅시다. 쿼리(query), 키(key) 및 값(value) 텐서 Q,K,VRL×d\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{L \times d}가 주어졌을 때, 목표는 다음을 계산하는 것입니다(간단히 하기 위해 마스킹과 1d\frac{1}{\sqrt{d}} 정규화를 생략합니다):

S=QKT,P=softmax(S),O=PV\mathbf{S}=\mathbf{Q K}^{\mathbf{T}}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}), \quad \mathbf{O}=\mathbf{P V}

직접적인 구현은 다음과 같습니다:

Sij=QiTKj,Pij=eSijl=1LeSil,Oi=l=1LPilVli,j=1,,L\mathbf{S}_{i j}=\mathbf{Q}_i^T \mathbf{K}_j, \quad \mathbf{P}_{i j}=\frac{e^{\mathbf{S}_{i j}}}{\sum_{l=1}^L e^{\mathbf{S}_{i l}}}, \quad \mathbf{O}_i=\sum_{l=1}^L \mathbf{P}_{i l} \mathbf{V}_l \quad \forall i, j=1, \ldots, L

위의 구현 문제는 모든 jj에 대해 Sij\mathbf{S}_{i j}를 먼저 계산하고 저장해야 한다는 점에서, 각 쿼리에 대해 선형 시간 및 메모리 복잡도를 요구하며, 전체 시간 및 공간 복잡도는 O(L2)\mathcal{O}\left(L^2\right)입니다. Rabe와 Staats (2022)는 분배 법칙을 사용하여 어텐션 연산의 마지막에 정규화 항의 나눗셈을 이동하는 것을 제안했습니다:

Oi=l=1LVleSill=1LeSdi=1,,L\mathbf{O}_i=\frac{\sum_{l=1}^L \mathbf{V}_l e^{\mathbf{S}_{i l}}}{\sum_{l=1}^L e^{\mathbf{S}_d}} \quad \forall i=1, \ldots, L

이 구현은 Lazy softmax라 불리며, 각 쿼리에 대해 상수 메모리로 계산할 수 있습니다. v0Rd\mathbf{v}_0 \in \mathbb{R}^d 벡터와 0\ell_0 스칼라를 0으로 초기화한 상태에서 j=1,,Lj=1, \ldots, L에 대해 키/값 쌍을 순차적으로 처리할 때 다음과 같이 업데이트합니다:

vjvj1+Vjesijjj1+eSij\begin{aligned} & \mathbf{v}_j \leftarrow \mathbf{v}_{j-1}+\mathbf{V}_j e^{\mathbf{s}_{i j}} \\ & \ell_j \leftarrow \ell_{j-1}+e^{\mathbf{S}_{i j}} \end{aligned}

모든 키와 값을 처리한 후, vLlL\frac{\mathbf{v}_L}{l_L}을 나누어 최종 결과를 얻습니다.

이 알고리즘이 Softmax의 단순 구현과 동일한 수치적 문제를 가지고 있다는 것을 알 수 있습니다: 즉, 지수화된 점수(및 값)의 합을 점진적으로 계산하는 것입니다. 일반적인 Safe-softmax 트릭은 여기에서 적용할 수 없는데, 그 이유는 최대값이 시퀀스의 마지막 점수에 따라 달라질 수 있기 때문입니다. 또한, 점수들은 지수화된 후에야 누적합에 더할 수 있기 때문에 빼기를 지연할 수도 없습니다.

이 문제를 해결하기 위해 저자들은 Online softmax에서와 같이 추가적인 스칼라 mm을 도입하여, 점진적인 알고리즘이 지금까지 본 최대 점수를 추적하고 필요할 때마다 지수화된 값의 합을 재정규화합니다:

mjmax(m,Sij)vjvj1emj1mj+VjeSijmjjj1+emj1mj\begin{aligned} m_j & \leftarrow \max \left(m, \mathbf{S}_{i j}\right) \\ \mathbf{v}_j & \leftarrow \mathbf{v}_{j-1} e^{m_{j-1}-m_j}+\mathbf{V}_j e^{\mathbf{S}_{i j}-m_j} \\ \ell_j & \leftarrow \ell_{j-1}+e^{m_{j-1}-m_j} \end{aligned}

저자들은 또한 대규모 병렬 처리를 활용하여 메모리 효율적인 병렬 알고리즘을 위한 코드를 Jax로 제공했습니다. 여기에서 주목할 점은 summarize_chunk 함수에서 jax.checkpoint 데코레이터를 사용하는 이유입니다. 이는 이 알고리즘이 순방향 패스 동안 어텐션 매트릭스(attention matrix)의 일부를 순차적으로 요약함으로써 메모리를 절약하기 때문에 이미 요약한 어텐션 매트릭스 부분을 잊어버릴 수 있게 하는 것입니다. 만약 순방향 패스 동안 모든 중간 결과를 저장해야 한다면, 알고리즘은 메모리 이점을 완전히 잃게 될 것입니다. 따라서 저자들은 개별 청크를 요약하는 함수에 그레디언트 체크포인팅(gradient checkpointing)을 적용할 것을 제안했습니다. 이로 인해 순방향 패스 동안 중간 결과를 잊어버릴 수 있고, 역전파 과정에서 다시 계산될 수 있습니다.

표준 어텐션 알고리즘(attention algorithm)에 체크포인팅을 적용하는 것만으로는 이러한 결과를 얻을 수 없습니다. 표준 어텐션 알고리즘에서 체크포인팅은 어텐션 매트릭스가 생성된 후 이를 잊어버리게 하지만, 쿼리 청크 어텐션 알고리즘(query chunk attention algorithm)은 처음부터 전체 어텐션 매트릭스를 형성하지 않습니다.

FlashAttention

FlashAttention은 현재 가장 인기 있는 어텐션 메커니즘(attention mechanism) 구현 방식 중 하나일 수 있습니다. 실제로는 표준 어텐션(attention)보다 더 많은 FLOP를 수행하지만, GPU 메모리 레벨 간의 읽기 및 쓰기를 고려하여 어텐션 알고리즘(attention algorithm)을 최적화함으로써 최대 3배 더 빠르게 실행됩니다. 이 글의 첫 번째 섹션에서 GPU 아키텍처(GPU architecture)SRAM에서 텐서(tensor)를 이동하는 것이 현대 GPU에서 HBM에서 이동하는 것보다 10배 빠르다고 설명한 것을 기억하시나요?

표준 어텐션(forward pass)는 다음과 같이 진행됩니다:

  • HBM에서 블록 단위로 Q,K\mathbf{Q}, \mathbf{K}를 로드하고, S=QKT\mathbf{S}=\mathbf{Q K}^{\mathbf{T}}를 계산한 후, S\mathbf{S}HBM에 기록합니다.
  • HBM에서 S\mathbf{S}를 읽어 소프트맥스(softmax)P=softmax(S)\mathbf{P}=\operatorname{softmax}(\mathbf{S})를 계산한 후, P\mathbf{P}HBM에 기록합니다.
  • HBM에서 블록 단위로 P\mathbf{P}V\mathbf{V}를 로드하고, O=PV\mathbf{O}=\mathbf{P V}를 계산한 후, O\mathbf{O}HBM에 기록합니다.
  • O 반환

Dao et al. (2022)LL에 대한 HBM 접근을 서브-제곱(Sub-quadratic)으로 줄이기 위해 두 가지 기법을 도입하여 이를 수정했습니다:

  1. 타일링(tiling): 입력 Q,K,V\mathbf{Q}, \mathbf{K}, \mathbf{V}SRAM에 맞도록 블록으로 나눈 후, 각 블록에 대한 온라인 소프트맥스(online softmax)를 계산하여 어텐션(attention) 점수를 얻고, 그 결과를 모두 결합합니다. 타일링을 사용하면 전체 어텐션 메커니즘을 하나의 CUDA 커널(CUDA kernel)에서 구현할 수 있습니다:

    • HBM에서 입력을 로드합니다.
    • 어텐션 계산의 모든 단계를 수행합니다: QK T{ }^T 행렬 곱, 소프트맥스(softmax), 선택적으로 마스킹(masking)드롭아웃(dropout), PV 행렬 곱.
    • 그리고 나서야 결과를 HBM에 기록합니다.
  2. 재계산 (Recomputation):
    훈련(Training) 과정에서는 역전파(backward pass) 동안 Q,K,V\mathbf{Q}, \mathbf{K}, V에 대한 그라디언트(gradients)를 계산하기 위해 S,PRL×L\mathbf{S}, \mathbf{P} \in \mathbb{R}^{L \times L}과 같은 중간 출력 값을 저장해야 합니다. 훈련 중 메모리 사용량을 줄이는 표준 메커니즘은 그라디언트 체크포인팅(gradient checkpointing)을 사용하는 것입니다. 이는 순전파(forward pass) 중에 이러한 출력 값을 잊어버리고 역전파에서 다시 계산하는 방법입니다. 하지만 이 방식은 메모리를 절약하는 대신 속도를 희생해야 합니다.

    저자들은 선택적인 그라디언트 체크포인팅을 제안합니다. 이 방법은 출력 O\mathbf{O}소프트맥스(softmax) 정규화 통계( m,m, \ell )만 저장하여, SRAM에 있는 Q,K,V\mathbf{Q}, \mathbf{K}, \mathbf{V} 블록으로부터 역전파 시 어텐션 매트릭스(attention matrices) S\mathbf{S}P\mathbf{P}를 쉽게 재계산할 수 있도록 합니다.

FlashAttention 순전파(FlashAttention forward pass):

  • 블록 크기(block sizes)를 설정합니다: BKV=M4d,BQ=min(M4d,d)B_{\mathrm{KV}}=\left\lceil\frac{M}{4 d}\right\rceil, B_{\mathbf{Q}}=\min \left(\left\lceil\frac{M}{4 d}\right\rceil, d\right), 여기서 MM온칩 SRAM(on-chip SRAM) 크기입니다.
  • ORL×d\color{#E86456}{\mathbf{O} \in \mathbb{R}^{L \times d}}, RL\color{#E86456}{\ell \in \mathbb{R}^{L}} 을 0으로 초기화하고, mRL\color{#E86456}{m \in \mathbb{R}^L}-\infty로 초기화하여 모두 HBM에 저장합니다.
  • 시퀀스 축(sequence axis)을 따라 Q\color{#E86456}{\mathbf{Q}}TQ=LBQT_{\mathbf{Q}}=\left\lceil\frac{L}{B_{\mathbf{Q}}}\right\rceil 블록으로, K,V\color{#E86456}{\mathbf{K}, \mathbf{V}}TKV=LBKvT_{\mathbf{K V}}=\left\lceil\frac{L}{B_{\mathrm{Kv}}}\right\rceil 블록으로 나눕니다.
  • 시퀀스 축을 따라 O,,m\color{#E86456}{\mathbf{O}, \ell, m}TQT_{\mathbf{Q}} 블록으로 나눕니다.
  • FOR j=1,,TKVj=1, \ldots, T_{\mathrm{KV}}에 대해:
    - Kj,Vj\color{#65AD69}{\mathbf{K}_j, \mathbf{V}_j}HBM에서 SRAM으로 로드합니다.
    - i=1,,TQi=1, \ldots, T_{\mathbf{Q}}에 대해:
    - Qi,Oi,i,mi\color{#65AD69}{\mathbf{Q}_i, \mathbf{O}_i, \ell_i, m_i}HBM에서 SRAM으로 로드합니다.
    - 정규화되지 않은 어텐션 점수(attention scores)를 계산합니다:
Sij=QiKjTRBQ×BKV.\color{#65AD69}{\mathbf{S}_{ij} = \mathbf{Q}_i \mathbf{K}^T_j \in \mathbb{R}^{B_{\mathbf{Q}} \times B_{\mathbf{KV}}}}.
	- 계속해서 계산합니다.
m~ij=rowmax(Sij)RBQ,P~ij=exp(Sijm~ij)RBQ×BKV,~ij=rowsum(P~ij)RBQ\color{#65AD69}{ \begin{aligned} \tilde{m}_{ij} & = \operatorname{rowmax}(\mathbf{S}_{ij}) \in \mathbb{R}^{B_{\mathbf{Q}}}, \\ \tilde{\mathbf{P}}_{ij} & = \exp ( \mathbf{S}_{ij} - \tilde{m}_{ij}) \in \mathbb{R}^{B_{\mathbf{Q}} \times B_{\mathbf{KV}}}, \\ \tilde{\ell}_{ij} & = \operatorname{rowsum}(\tilde{\mathbf{P}}_{ij}) \in \mathbb{R}^{B_{\mathbf{Q}}} \end{aligned}}
	- 통계 갱신
minew=max(mi,m~ij)RBQ,inew=emiminewi+em~ijminew~ijRBQ\color{#65AD69}{ \begin{aligned} m_i^{\text{new}} & = \max(m_i, \tilde{m}_{ij}) \in \mathbb{R}^{B_{\mathbf{Q}}}, \\ \ell_{i}^{\text{new}} & = e^{m_i-m_i^{\text{new}}} \ell_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}}\tilde{\ell}_{ij} \in \mathbb{R}^{B_{\mathbf{Q}}} \end{aligned}}
 -  $\color{#E86456}{\mathbf{O}_i} \color{#EDA137}{ \leftarrow } \color{#65AD69}{\operatorname{diag}(\ell_i^{\text{new}})^{-1} \big( \operatorname{diag}(\ell_i) e^{m_i-m_i^{\text{new}}} \mathbf{O}_i + e^{\tilde{m}_{ij} - m_i^{\text{new}}} \tilde{\mathbf{P}}_{ij} \mathbf{V}_j \big) }$을 HBM에 기록합니다.
 - $\color{#E86456}{m_i} \color{#EDA137}{ \leftarrow } \color{#65AD69}{m_{i}^{\text{new}}}$을 HBM에 기록합니다.
  • O\mathbf{O}를 반환합니다.

결국, FlashAttention 알고리즘은 O=softmax(QKT)V\mathbf{O}=\operatorname{softmax}\left(\mathbf{Q K}^T\right) \mathbf{V}O(L2d)\mathcal{O}\left(L^2 d\right) FLOPs로 반환하며, 입력 및 출력 이외에 추가로 O(L)\mathcal{O}(L) 메모리가 필요합니다. 메모리 접근 측면에서는, FlashAttentionO(L2d2M1)\mathcal{O}\left(L^2 d^2 M^{-1}\right) HBM 접근을 요구하는데, 여기서 dMLdd \leq M \leq L d이고, 표준 어텐션(attention)과 비교하면 O(Ld+L2)\mathcal{O}\left(L d+L^2\right)보다 적은 접근이 필요합니다.

![[Pasted image 20241015115922.png]]

플래시어텐션(FlashAttention) **순전파(forward pass)**가 수행되는 과정의 개략도입니다. 여기서 $\mathbf{Q}$는 $T_{\mathrm{Q}}=1$ 크기 $B_{\mathrm{Q}} \times d$의 블록으로 분할되고, $B_{\mathbf{Q}}=3$입니다. $\mathbf{K}$와 $\mathbf{V}$는 각각 $B_{\mathrm{KV}}=2$인 크기 $B_{\mathrm{KV}} \times d$의 블록으로 나뉘어 $T_{\mathrm{KV}}=2$ 블록으로 분할됩니다. 이때 $\ell_1=\sum e^{\mathrm{s}_{11}}, \ell_2=\ell_1+\sum e^{\mathrm{S}_{12}}$가 됩니다. 소프트맥스(softmax)에서 $m$을 빼는 단계는 간소화를 위해 생략되었습니다.

플래시어텐션(FlashAttention) 저자들은 이를 쿼리 청크 어텐션 알고리즘(query chunk attention algorithm)과 비교하며 세 가지 주요 차이점을 언급했습니다:

  1. 메모리 효율적인 어텐션(memory-efficient attention)은 메모리 사용량을 줄이는 데 중점을 두는 반면, 플래시어텐션(FlashAttention)은 HBM(HBM: High Bandwidth Memory) 접근을 줄이는 데 초점을 맞춥니다.
  2. 메모리 효율적인 어텐션(memory-efficient attention)은 각 블록을 처리할 때 그 임시 출력과 함께 소프트맥스 정규화 통계를 요약합니다. 반면에 플래시어텐션(FlashAttention)은 각 블록을 처리한 후 출력을 점진적으로 업데이트하므로 출력의 복사본이 하나만 필요합니다.
  3. 메모리 효율적인 어텐션(memory-efficient attention)은 그라디언트 체크포인팅(gradient checkpointing)을 사용하여 어텐션 행렬과 각 블록의 임시 출력을 다시 계산합니다. 반면 플래시어텐션(FlashAttention)은 어텐션 행렬만 다시 계산하며 각 블록의 임시 출력을 다시 계산하지 않습니다.

"플래시어텐션(FlashAttention)" 논문은 매우 잘 작성된 논문으로, 대규모 언어 모델을 훈련하는 사람이라면 꼭 읽어볼 가치가 있습니다. 논문에는 플래시어텐션의 역전파(backward pass), 이론적 증명 및 다른 최적화 알고리즘과의 비교에 대한 더 많은 세부 사항이 포함되어 있습니다.

플래시어텐션(FlashAttention) + 병렬 처리

플래시어텐션(FlashAttention)은 어텐션 계산 속도를 크게 향상시키고, 시퀀스 길이에 대한 메모리 사용을 이차적으로 줄여 1차적으로 만듭니다. 대부분의 경우에 효과적으로 작동하지만, 긴 시퀀스와 작은 배치 크기 또는 적은 수의 어텐션 헤드를 사용하는 경우에는 병렬 처리가 충분하지 않아 최적화되지 않았습니다.

첫 번째 버전의 플래시어텐션(FlashAttention) 커널은 하나의 어텐션 헤드당 하나의 스레드 블록을 사용하며, 총 Bh\boldsymbol{B h} 스레드 블록을 실행합니다 (여기서 BB는 배치 크기, hh는 어텐션 헤드의 수를 의미합니다). 각 스레드 블록은 스트리밍 멀티프로세서(SM)에서 실행되며, 이러한 스케줄링은 BhB h가 SM의 수와 유사할 때만 (예: A100 GPU에서 108개의 SM) 효율적으로 컴퓨팅 리소스를 사용할 수 있습니다. 현대 병렬 처리 기법(modern parallelism techniques)을 사용하여 LLM을 훈련할 때, 배치 크기는 데이터 병렬 처리(DP)로, 어텐션 헤드의 수는 모델 병렬 처리(TP)로 인해 줄어듭니다.

플래시어텐션(FlashAttention) 원작자인 Tri Dao는 GPU의 멀티프로세서를 더 잘 활용하기 위해 시퀀스 축을 따라 추가 병렬 처리를 적용했습니다. 이러한 방식에서, 순방향 패스에서는 이제 하나의 어텐션 헤드를 여러 스레드 블록이 처리하며, 각 블록은 어텐션 행렬의 자신의 행 부분을 담당합니다. 어텐션 행렬의 행은 서로 의존하지 않으므로, 블록 간의 통신이 필요하지 않습니다.

역방향 패스에서는 이제 각 스레드 블록이 어텐션 행렬의 열 부분을 담당하게 됩니다. 열을 따라 병렬화하는 것이 행을 따라 병렬화하는 것보다 더 빠른데, 이는 작업자 간의 통신이 줄어들기 때문입니다 (열을 따라 병렬화하면 쿼리의 그라디언트만 집계하면 되지만, 행을 따라 병렬화하면 키와 값의 그라디언트를 집계해야 합니다).

플래시어텐션-2(FlashAttention-2)

플래시어텐션(FlashAttention)의 새로운 버전(Tri Dao, 2023)은 긴 시퀀스에서 스레드 블록 간의 병렬 처리를 포함하여 점유율을 높였습니다. 이와 더불어 두 가지 수치가 감소되었습니다:

  • 비행렬 곱셈(non-matmul) 연산의 FLOPs 수,
  • SRAM을 통한 통신 횟수.

첫 번째 개선점은 온라인 소프트맥스(online softmax) 트릭을 다시 작성하여, 출력은 변경하지 않으면서 리스케일링(rescaling) 작업과 경계 검사(boundchecking), 인과 마스킹(causal masking)의 수를 줄이는 것입니다(행렬 곱셈 처리량은 비행렬 곱셈 처리량보다 몇 배 더 높을 수 있음을 기억해야 합니다).

두 번째 개선점은 워프(warp) 간의 작업 분할을 최적화하는 것입니다. 워프는 함께 작업하는 스레드 그룹을 의미합니다. 첫 번째 플래시어텐션(FlashAttention)에서는 K\mathbf{K}V\mathbf{V}가 각 스레드 블록 내에서 4개 또는 8개의 워프로 나뉘어졌습니다. 각 워프는 QKT\mathbf{Q K}^T의 일부분을 계산한 후, 그 결과를 다시 V\mathbf{V}의 일부분과 곱하여 결과를 더하기 위해 공유 메모리로 기록하고 동기화해야 했습니다. 하지만, 이 방식은 비효율적이었습니다. 모든 워프가 중간 결과를 공유 메모리에 기록하고, 동기화한 후에 다시 결과를 더해야 했기 때문입니다.

플래시어텐션-2(FlashAttention-2)에서는 K\mathbf{K}V\mathbf{V}는 모든 워프에서 접근 가능하도록 유지하면서, Q\mathbf{Q}가 4개의 워프로 나뉘었습니다. 각 워프는 QKT\mathbf{Q} \mathbf{K}^{\boldsymbol{T}}의 일부분을 계산한 후, 해당 결과를 V\mathbf{V}와 곱하여 결과를 산출하므로 워프 간의 통신이 필요하지 않습니다. 이로 인해 SRAM 읽기/쓰기가 줄어들어 속도가 향상됩니다.

두 번째 버전에서는 또한 더 큰 헤드 차원( d=128256d=128 \rightarrow 256 )을 지원하고, MQA와 GQA에 대한 지원도 도입되었습니다.

플래시어텐션-3(FlashAttention-3)

가장 최근 버전인 플래시어텐션-3(FlashAttention-3, 2024)은 H100 GPU에 특화된 최적화에 중점을 두었습니다. 이전 버전에서는 호퍼(Hopper) 아키텍처에서 최대 35%의 활용도에 그쳤기 때문에, 이번 업데이트에서는 새로운 기능을 활용했습니다. 그중 하나가 텐서 메모리 가속기(Tensor Memory Accelerator, TMA)로, 이 하드웨어는 비동기 주소 생성을 수행하고 메모리 접근을 가속화할 수 있는 새로운 장치입니다. 주요 속도 향상 기법은 다음과 같습니다:

  1. 전체 연산과 데이터 이동을 중첩(Overlap)
  2. 블록 단위의 행렬 곱셈(matmul)과 소프트맥스(softmax) 연산을 교차(Interleave) 수행
  3. FP8 저정밀도에 대한 하드웨어 지원을 활용한 블록 양자화(Block quantization) 및 비동기 처리

이러한 최적화를 통해 플래시어텐션-3은 H100 GPU의 성능을 크게 향상시켰습니다.

링 어텐션(Ring Attention)

플래시 어텐션(FlashAttention)을 사용하더라도 메모리 복잡도는 여전히 LL에 대해 선형적으로 증가하기 때문에, 시퀀스 길이를 확장하는 데는 메모리 용량의 한계가 있습니다. 우리는 장치의 수 NN에 맞춰 문맥 크기를 확장할 수 있으며, 입력을 NN 개의 부분으로 나누어 병렬로 연산을 수행한 다음 결과를 모을 수 있습니다. 그러나 어텐션은 Q\mathbf{Q}K\mathbf{K}, V\mathbf{V} 행렬의 모든 요소에 접근해야 하기 때문에, 이때 대규모 행렬을 장치 간에 전송하는 것은 엄청난 통신 오버헤드를 발생시킬 수 있습니다 (예: A100의 NVLink를 통한 처리량은 600 GB/s600 \mathrm{~GB} / \mathrm{s}인 반면, PCIe는 64GB/s에 불과합니다).

링 어텐션(Ring Attention) (Lie et al. (2023))은 이러한 문제를 해결하며, 매우 큰 문맥 시나리오에서 통신 오버헤드를 계산 뒤로 숨기는 아이디어를 탐구합니다. 이 알고리즘의 절차는 다음과 같습니다:

  • 입력 시퀀스를 NN 개의 블록으로 나누고, 시퀀스 축을 따라 각 장치가 크기 C=LNC=\left\lceil\frac{L}{N}\right\rceil인 입력 블록 하나씩을 저장하도록 합니다. 각 ii-번째 장치는 자신에게 할당된 입력 블록에 대해 Qi\mathbf{Q}_i, Ki\mathbf{K}_i, Vi\mathbf{V}_i를 계산합니다.
  • iter =0,,N1=0, \ldots, N-1 동안:
    • ii-번째 장치에서 병렬로:
      • j=(j=( iter +i)modN+i) \bmod N으로 설정합니다.
      • 로컬 Qi\mathbf{Q}_i, Kj\mathbf{K}_j, Vj\mathbf{V}_j 블록을 사용하여 메모리 효율적인 어텐션을 점진적으로 계산합니다.
      • 동시에, Kj\mathbf{K}_j, Vj\mathbf{V}_j 블록을 다음 장치로 전송하고, 이전 장치로부터 새 블록을 수신합니다.

![[Pasted image 20241015121832.png]]
![[Pasted image 20241015121848.png]]
![[Pasted image 20241015121937.png]]
![[Pasted image 20241015121953.png]]

GPU는 링 형태로 배열되며, 각 GPU는 $\mathbf{Q}$의 일부를 보유합니다. 링 어텐션(Ring Attention) 과정에서 GPU들은 서로에게 $\mathbf{K}$, $\mathbf{V}$ 블록을 전달합니다.

각 반복(iteration)에서 로컬 어텐션을 계산하는 데는 O(C2d)\mathcal{O}\left(C^2 d\right) 연산이 필요하며, 각 장치는 Kj,VjRC×d\mathbf{K}_j, \mathbf{V}_j \in \mathbb{R}^{C \times d} 텐서를 전송해야 하므로 O(Cd)\mathcal{O}(\mathcal{C} d) 바이트의 데이터를 전송해야 합니다. 따라서 통신 오버헤드를 효과적으로 숨기기 위한 최소 블록 크기 하한은 다음과 같이 계산됩니다:

C=LNcomputeperformancecommunicationbandwidth.C = \big \lceil \frac{L}{N} \big \rceil \geq \frac{\operatorname{compute performance}}{\operatorname{communication bandwidth}}.

저자들은 또한 JAX에서 구현된 코드를 공유했습니다.

스트라이프 어텐션(Stripe Attention)

Brandon et al. 2023은 원인적 트랜스포머(causal transformer)에 대한 링 어텐션(Ring Attention)의 성능을 연구했으며, 삼각형 구조의 원인적 어텐션 계산 때문에 작업량이 크게 불균형하게 분포된다는 사실을 발견했습니다. 이를 해결하기 위해 간단한 확장 방법인 스트라이프 어텐션(Stripe Attention)을 제안하여 약 1.5배의 속도 향상을 달성했습니다.

링 어텐션의 문제는 첫 번째 반복을 제외한 모든 반복에서 일부 장치의 작업량은 완전히 필요한 작업(마스크되지 않음)인 반면, 다른 장치의 작업량은 최종 출력에 불필요한 작업(마스크됨)이라는 점입니다. 따라서 불필요한 작업을 계산할 필요가 없습니다. 총 지연 시간(latency)은 각 반복에서 참여하는 장치 중 가장 오랜 시간을 소요한 장치에 의해 결정됩니다. 그 결과, 각 장치에 대한 최적화 여부와 상관없이, 반복당 지연 시간은 완전히 마스크되지 않은 작업을 계산하는 데 걸리는 시간과 동일하게 유지됩니다. 따라서 이론적으로 절반의 연산만 필요함에도 불구하고, 링 어텐션은 마스크되지 않은 작업을 모두 처리하는 것처럼 실행됩니다.

![[Pasted image 20241015122346.png]]
![[Pasted image 20241015122400.png]]
![[Pasted image 20241015122417.png]]
![[Pasted image 20241015122437.png]]

링 주의(Ring Attention, 왼쪽)와 스트라이프 주의(Stripe Attention)의 작업 부하 분배. 정사각형 행렬은 가능한 모든 쿼리/키 쌍 상호작용을 나타내며, 행 인덱스는 쿼리에 대응하고 열 인덱스는 키에 대응합니다. 대각선 위의 모든 셀은 인과 마스크에 의해 마스킹 처리되어 건너뛸 수 있습니다. 색상은 각 장치가 담당하는 계산 부분을 나타냅니다. 링 주의(Ring Attention)에서 일부 장치는 완전히 마스킹 처리된 작업을 담당하고 있는 반면, 스트라이프 주의(Stripe Attention)는 모든 장치에 걸쳐 균형 잡힌 작업 부하를 유지하고 있습니다.

링 주의(Ring Attention)처럼 연속된 블록으로 토큰을 분할하는 대신, 스트라이프 주의(Stripe Attention)는 잔여물(residues) 모듈로 NN을 기준으로 일정한 간격으로 스트라이프로 토큰을 분할합니다. 즉, ii번째 토큰은 imodNi \bmod N 장치에 할당됩니다. 실제로는 모델의 첫 임베딩 레이어 전에 입력 토큰을 퍼뮤테이션(permute)하여 이 분할 방식을 구현한 다음, 링 주의(Ring Attention)처럼 연속된 블록으로 분할할 수 있습니다. 분할이 끝나면 스트라이프 주의(Stripe Attention)와 링 주의(Ring Attention) 알고리즘은 거의 동일하게 진행됩니다.

또 다른 최근 개선 사항으로 트리 주의(Tree Attention, Shyam 외. 2024)가 있는데, 다수의 GPU에서 디코딩할 때 통신 비용을 줄이기 위해 트리-리덕션(tree-reduction) 토폴로지를 활용하여 링 주의(Ring Attention) 대비 비대칭적으로 8×8 \times 속도 향상을 달성했습니다.

PagedAttention / vLLM

KV 캐시는 상당한 양의 메모리를 차지합니다. 단순한 구현에서는 요청의 KV 캐시를 다른 텐서와 마찬가지로 연속된 메모리 공간에 저장할 수 있습니다. 그러나 전통적인 딥러닝 작업에서의 텐서와 달리, KV 캐시는 모델이 새로운 토큰을 생성함에 따라 동적으로 증가하고 감소하며, 그 수명과 길이는 미리 알 수 없습니다.

연속된 공간에 요청의 KV 캐시를 저장하려면 요청의 최대 길이에 맞춘 연속된 메모리 블록을 미리 할당해야 하며, 요청의 실제 길이는 훨씬 짧을 수 있습니다. 또 다른 메모리 비효율성은 빔 서치(beam search) 또는 병렬 샘플링(parallel sampling)과 같은 고급 디코딩 기술을 사용할 때 발생하는데, 이러한 경우 하나의 요청이 부분적으로 KV 캐시를 공유하는 여러 시퀀스로 구성될 수 있습니다. 그러나 KV 캐시가 연속된 별도의 공간에 저장되는 경우 메모리 공유가 불가능합니다.

이러한 한계를 해결하기 위해, 권 외(2023)는 운영체제(OS)의 메모리 단편화(fragmentation) 및 공유 문제 해결 방법인 가상 메모리(paging)를 차용한 PagedAttention 알고리즘을 제안했습니다. PagedAttention은 요청의 KV 캐시를 고정된 수의 토큰을 포함하는 블록으로 나누며, 이 블록들은 반드시 연속된 공간에 저장될 필요가 없습니다. 우리는 KV 캐시를 보다 유연한 방식으로 관리할 수 있습니다.

핵심 아이디어는 KV 캐시를 논리적 KV 블록 시리즈로 표현하는 것입니다. 새로운 토큰이 생성됨에 따라 왼쪽에서 오른쪽으로 채워지는 이 논리적 KV 블록들은 GPU 메모리의 비연속적인 물리적 KV 블록들로 나뉘며, 이는 블록 엔진에 의해 할당됩니다. KV 블록 관리자는 각 요청의 논리적 KV 블록과 물리적 KV 블록 간의 매핑을 유지하는 블록 테이블을 관리합니다. 논리적 KV 블록과 물리적 KV 블록을 분리함으로써 vLLM은 모든 위치에 대한 메모리를 미리 예약하지 않고도 KV 캐시 메모리를 동적으로 확장할 수 있습니다.

profile
AI Developer를 꿈꾸는 늦깎이 개발자

0개의 댓글