행렬곱 Matrix Multiplication

박찬호·2025년 11월 17일

정의: 행렬곱(Matrix Multiplication)은 행렬을 곱해서 새로운 행렬을 만드는 연산

행렬 곱의 조건

두 행렬 A(m x n), B(n x p)가 있을 때, 곱 A x B가 성립하려면,

  • 앞 행렬의 열(column)과 뒤 행렬의 행(row)의 수가 같아야함
  • 행렬곱 결과의 크기는 m x p

    예시)
    A: 3 x 4
    B: 2 x 6
    A x B= 3 x 6

행렬곱이 어떻게 실제로 계산되는가

결과 행렬 C = A x B의 각 원소는

C𝑖, 𝑗i, j = A의 i번째 행 · B의 j번째 열 (내적)

A (2×3)
1 2 3
4 5 6

B (3×2)
| 1 | 4 |
| 2 | 5 |
| 3 | 6 |

C = A × B (2×2)
A는 행방향으로 B는 열방향으로 이동하며 곱함.

각 원소는 이렇게 계산됨:
C[0,0] = 1×1 + 2×2 + 3×3 = 14
C[0,1] = 1×4 + 2×5 + 3×6 = 32
C[1,0] = 4×1 + 5×2 + 6×3 = 32
C[1,1] = 4×4 + 5×5 + 6×6 = 77

결과:
14 32
32 77

Transformer의 Q K V

일반적인 Shape을 예시로

  • Query: (batch, num_heads, num_tokens, head_dim)
  • Key: (batch, num_heads, num_tokens, head_dim)

이 상황에서 Key의 num_tokens와 head_dim의 위치를 바꿈

  • token과 head를 서로 행렬곱 시켜서 연관성을 파악하기 위함. 바꾸지 않으면 head끼리, token끼리 곱해짐.
attn_scores = queries @ keys.transpose(2, 3)

Q: num_tokens x head_dim
Kᵀ: head_dim x num_tokens
Q x Kᵀ: num_tokens x num_tokens
-> Attention Score Matrix

이로 인해 Q x Kᵀ는 Q(토큰i)가 K(토큰j)를 얼마나 보고 있는가를 뜻한다.
Q[i] · K[j]

profile
Velog.

0개의 댓글