정의: 행렬곱(Matrix Multiplication)은 행렬을 곱해서 새로운 행렬을 만드는 연산
두 행렬 A(m x n), B(n x p)가 있을 때, 곱 A x B가 성립하려면,
예시)
A: 3 x 4
B: 2 x 6
A x B= 3 x 6
결과 행렬 C = A x B의 각 원소는
C𝑖, 𝑗i, j = A의 i번째 행 · B의 j번째 열 (내적)
A (2×3)
1 2 3
4 5 6
B (3×2)
| 1 | 4 |
| 2 | 5 |
| 3 | 6 |
C = A × B (2×2)
A는 행방향으로 B는 열방향으로 이동하며 곱함.
각 원소는 이렇게 계산됨:
C[0,0] = 1×1 + 2×2 + 3×3 = 14
C[0,1] = 1×4 + 2×5 + 3×6 = 32
C[1,0] = 4×1 + 5×2 + 6×3 = 32
C[1,1] = 4×4 + 5×5 + 6×6 = 77
결과:
14 32
32 77
일반적인 Shape을 예시로
이 상황에서 Key의 num_tokens와 head_dim의 위치를 바꿈
attn_scores = queries @ keys.transpose(2, 3)
Q: num_tokens x head_dim
Kᵀ: head_dim x num_tokens
Q x Kᵀ: num_tokens x num_tokens
-> Attention Score Matrix
이로 인해 Q x Kᵀ는 Q(토큰i)가 K(토큰j)를 얼마나 보고 있는가를 뜻한다.
Q[i] · K[j]