Self-Attention과 KV 캐시

HanJu Han·2024년 12월 18일

LLM 최적화

목록 보기
14/16

Self-Attention과 KV 캐시 설명

쿼리가 전체 키에 대한 정보를 계산하는 이유

Self-Attention의 핵심은 각 Query(QQ)가 전체 Key(KK)와의 연관성을 계산해, 적절한 Value(VV)를 선택하도록 하는 것입니다. 이를 통해 모델은 문맥(Context)을 이해하고 중요한 정보를 가중치로 반영할 수 있습니다.

Attention Score 계산 (정확히 다시 설명)

Score=Q[K1,K2,,Kn]Td\text{Score} = \frac{Q [K_1, K_2, \ldots, K_n]^T}{\sqrt{d}}
  • QQ: 현재 토큰의 Query
  • [K1,K2,,Kn][K_1, K_2, \ldots, K_n]: 이전 모든 토큰의 Key
  • KiTK_i^T: Key를 전치해 Query와 내적 가능하도록 준비

Query가 전체 Key와의 연관성을 계산하는 것은, 문장에서 특정 토큰이 다른 모든 토큰과 어떻게 연결되는지를 반영하기 위함입니다.

수식과 예시를 통해 다시 설명

모델 가정

  • 토큰: ["ChatGPT", "는", "뛰어나다"]
  • Key, Value, Query 차원: d=2d = 2
  • 가중치 동일 (WKW_K, WVW_V, WQW_Q)

1. 첫 번째 토큰: "ChatGPT"

Key, Value, Query 계산

K1=[12],V1=[0.51],Q1=[31]K_1 = \begin{bmatrix}1 \\ 2\end{bmatrix}, \quad V_1 = \begin{bmatrix}0.5 \\ 1\end{bmatrix}, \quad Q_1 = \begin{bmatrix}3 \\ -1\end{bmatrix}

Self-Attention

첫 번째 토큰은 Key가 하나이므로:

Score1=Q1K1T2=[31][12]2=3×1+(1)×22=12\text{Score}_1 = \frac{Q_1 K_1^T}{\sqrt{2}} = \frac{\begin{bmatrix}3 & -1\end{bmatrix} \begin{bmatrix}1 \\ 2\end{bmatrix}}{\sqrt{2}} = \frac{3 \times 1 + (-1) \times 2}{\sqrt{2}} = \frac{1}{\sqrt{2}}

Softmax 결과는 1이 되므로,

Attention1=V1=[0.51]\text{Attention}_1 = V_1 = \begin{bmatrix}0.5 \\ 1\end{bmatrix}

2. 두 번째 토큰: "는"

Key, Value, Query 계산

K2=[01],V2=[00.5],Q2=[11]K_2 = \begin{bmatrix}0 \\ 1\end{bmatrix}, \quad V_2 = \begin{bmatrix}0 \\ 0.5\end{bmatrix}, \quad Q_2 = \begin{bmatrix}1 \\ -1\end{bmatrix}

Attention Score 계산

Query(Q2Q_2)는 모든 Key(K1,K2K_1, K_2)에 대해 Score를 계산:

Score2=Q2[K1,K2]T2=[11][1021]2=[1,1]2\text{Score}_2 = \frac{Q_2 [K_1, K_2]^T}{\sqrt{2}} = \frac{\begin{bmatrix}1 & -1\end{bmatrix} \begin{bmatrix}1 & 0 \\ 2 & 1\end{bmatrix}}{\sqrt{2}} = \frac{\begin{bmatrix}-1, -1\end{bmatrix}}{\sqrt{2}}

Softmax 적용:

softmax(1,1)=[0.5,0.5]\text{softmax}(-1, -1) = \begin{bmatrix}0.5, 0.5\end{bmatrix}

Attention Output

Attention2=softmax[V1,V2]=[0.5,0.5][0.5010.5]=[0.250.75]\text{Attention}_2 = \text{softmax} \cdot [V_1, V_2] = \begin{bmatrix}0.5, 0.5\end{bmatrix} \begin{bmatrix}0.5 & 0 \\ 1 & 0.5\end{bmatrix} = \begin{bmatrix}0.25 \\ 0.75\end{bmatrix}

3. 세 번째 토큰: "뛰어나다"

Key, Value, Query 계산

K3=[21],V3=[10.5],Q3=[31]K_3 = \begin{bmatrix}2 \\ 1\end{bmatrix}, \quad V_3 = \begin{bmatrix}1 \\ 0.5\end{bmatrix}, \quad Q_3 = \begin{bmatrix}3 \\ 1\end{bmatrix}

Attention Score 계산

Query(Q3Q_3)는 모든 Key(K1,K2,K3K_1, K_2, K_3)에 대해 Score를 계산:

Score3=Q3[K1,K2,K3]T2=[31][102211]2=[5,1,7]2\text{Score}_3 = \frac{Q_3 [K_1, K_2, K_3]^T}{\sqrt{2}} = \frac{\begin{bmatrix}3 & 1\end{bmatrix} \begin{bmatrix}1 & 0 & 2 \\ 2 & 1 & 1\end{bmatrix}}{\sqrt{2}} = \frac{\begin{bmatrix}5, 1, 7\end{bmatrix}}{\sqrt{2}}

Softmax 적용:

softmax(52,12,72)=[0.119,0.006,0.875]\text{softmax}\left(\frac{5}{\sqrt{2}}, \frac{1}{\sqrt{2}}, \frac{7}{\sqrt{2}}\right) = \begin{bmatrix}0.119, 0.006, 0.875\end{bmatrix}

Attention Output

Attention3=softmax[V1,V2,V3]=[0.119,0.006,0.875][0.50110.50.5]=[0.9420.557]\text{Attention}_3 = \text{softmax} \cdot [V_1, V_2, V_3] = \begin{bmatrix}0.119, 0.006, 0.875\end{bmatrix} \begin{bmatrix}0.5 & 0 & 1 \\ 1 & 0.5 & 0.5\end{bmatrix} = \begin{bmatrix}0.942 \\ 0.557\end{bmatrix}

KV 캐시의 역할

  • Key와 Value를 계속 저장해, 새 Query만 전체 Key에 대한 Attention Score를 계산.
  • 이전 Key/Value를 다시 계산할 필요 없음.
profile
시리즈를 기반으로 작성하였습니다.

0개의 댓글