DeepSeek-V3에서 사용한 Multi-Head Latent Attention (MLA) 에서 Latent는 "잠재적인(hidden or latent) 변수"를 의미합니다. 여기서 잠재 변수(Latent Variable) 는 모델이 직접 관측할 수 없는 내재적인 정보를 학습하는 데 사용됩니다.
Latent Variable은 데이터에서 직접 관측되지 않지만, 모델이 학습 과정에서 유용한 특징(feature)으로 활용하는 변수입니다.
e.g.
사람의 감정을 예측하는 모델에서, "행복도" 같은 값은 직접 측정할 수 없지만(관측 불가능), 얼굴 표정이나 음성 톤 등의 데이터를 이용해 추론할 수 있습니다.
NLP에서는 단어 자체보다 문장의 "의도(intention)" 같은 것을 Latent Variable로 간주할 수 있습니다.
DeepSeek의 MLA는 기존의 Self-Attention 메커니즘을 확장한 개념입니다.
기존 Transformer의 Self-Attention은 입력 토큰 간의 관계를 직접 학습하는 방식입니다. 하지만 MLA는 Latent Variable을 추가로 활용하여 좀 더 풍부한 표현 학습을 가능하게 합니다.
MLA는 Latent Variable 을 추가하여 기존 Self-Attention보다 효율적인 정보를 학습하도록 합니다.
(1) Latent Variable 생성
입력 𝑋 를 Latent Space로 변환하는 과정을 추가합니다.
즉, 을 통해 입력 𝑋를 압축하여 중요한 특징만을 포함하는 잠재 변수 𝐿을 생성합니다.
(2) Latent Variable을 활용한 Attention 연산
기존 Attention 연산에서 Key 𝐾 대신 𝐿을 사용합니다.
import torch
import torch.nn as nn
class HeadLatentAttention(nn.Module):
def __init__(self, embed_dim, num_heads, latent_dim):
super().__init__()
self.num_heads = num_heads
self.embed_dim = embed_dim
self.latent_dim = latent_dim
# Query, Key, Value 변환을 위한 선형 레이어
self.W_q = nn.Linear(embed_dim, latent_dim) # embed_dim -> latent_dim으로 변경
self.W_k = nn.Linear(embed_dim, embed_dim)
self.W_v = nn.Linear(embed_dim, embed_dim)
# Latent Variable 생성
self.latent_proj = nn.Linear(embed_dim, latent_dim)
# 최종 출력 변환
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, embed_dim = x.shape
# Query를 latent_dim으로 투영하도록 변경
Q = self.W_q(x) # (batch_size, seq_len, latent_dim)
K = self.W_k(x) # (batch_size, seq_len, embed_dim)
V = self.W_v(x) # (batch_size, seq_len, embed_dim)
# Latent Variable 생성
L = self.latent_proj(x).mean(dim=1, keepdim=True) # (batch_size, 1, latent_dim)
# Latent Variable과 Query 간의 Attention 계산 (이제 차원이 맞습니다)
attn_scores = torch.matmul(Q, L.transpose(-2, -1)) / (self.latent_dim ** 0.5) # (batch_size, seq_len, 1)
attn_weights = torch.softmax(attn_scores, dim=1) # (batch_size, seq_len, 1)
# Value와 attention weights를 곱합
output = attn_weights.transpose(-2, -1) @ V # (batch_size, 1, embed_dim)
# 최종 출력 변환
output = self.out_proj(output) # (batch_size, 1, embed_dim)
return output
# 테스트 실행
batch_size = 2
seq_len = 5
embed_dim = 32
latent_dim = 16
num_heads = 4
x = torch.randn(batch_size, seq_len, embed_dim) # 입력
mla = MultiHeadLatentAttention(embed_dim, num_heads, latent_dim)
output = mla(x)
print(output.shape) # (batch_size, 1, embed_dim)
import torch
import torch.nn as nn
class MultiHeadLatentAttention(nn.Module):
def __init__(self, embed_dim, num_heads, latent_dim):
super().__init__()
self.num_heads = num_heads
self.embed_dim = embed_dim
self.latent_dim = latent_dim
self.head_dim = latent_dim // num_heads
# Query, Key 투영
self.q_proj = nn.Linear(embed_dim, latent_dim)
self.k_proj = nn.Linear(embed_dim, latent_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
# Latent Variable 생성
self.latent_proj = nn.Linear(embed_dim, latent_dim)
# 출력 투영
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
batch_size, seq_len, _ = x.shape
# Query, Key, Value 투영
q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
v = self.v_proj(x)
# Latent Variable 생성 (시퀀스 차원에서 평균)
latent = self.latent_proj(x).mean(dim=1, keepdim=True) # [batch_size, 1, latent_dim]
latent = latent.view(batch_size, 1, self.num_heads, self.head_dim)
# 차원 재배열 (multi-head attention 계산을 위해)
q = q.transpose(1, 2) # [batch_size, num_heads, seq_len, head_dim]
k = latent.transpose(1, 2) # [batch_size, num_heads, 1, head_dim]
# Scaled Dot-Product Attention
attn_scores = torch.matmul(q, k.transpose(2, 3)) / (self.head_dim ** 0.5) # [batch_size, num_heads, seq_len, 1]
attn_weights = torch.softmax(attn_scores, dim=2) # [batch_size, num_heads, seq_len, 1]
# 원래 시퀀스 차원으로 재배열
attn_weights = attn_weights.transpose(1, 2).mean(dim=2, keepdim=True) # [batch_size, seq_len, 1, 1]
attn_weights = attn_weights.squeeze(-1) # [batch_size, seq_len, 1]
# Weighted sum
output = torch.bmm(attn_weights.transpose(1, 2), v) # [batch_size, 1, embed_dim]
# 출력 투영
output = self.out_proj(output) # [batch_size, 1, embed_dim]
return output
# 테스트 실행
batch_size = 2
seq_len = 5
embed_dim = 32
latent_dim = 16
num_heads = 4
x = torch.randn(batch_size, seq_len, embed_dim)
mla = MultiHeadLatentAttention(embed_dim, num_heads, latent_dim)
output = mla(x)
print(output.shape) # 예상 출력: torch.Size([2, 1, 32])
이 방식은 기존 Transformer보다 계산량이 줄어들면서도, Latent Variable을 활용하여 더 강력한 표현 학습이 가능하도록 설계되었습니다.
DeepSeek-V3에서 사용한 MLA의 개념을 단순화한 구현이며, 실제 모델에서는 Multi-Head 방식이 더 복잡하게 적용될 수 있습니다.