DeepSeek의 MLA (Multi-Head Latent Attention)

Jayce_97·2025년 3월 24일
0

DeepSeek

목록 보기
4/4

DeepSeek-V3에서 사용한 Multi-Head Latent Attention (MLA) 에서 Latent는 "잠재적인(hidden or latent) 변수"를 의미합니다. 여기서 잠재 변수(Latent Variable) 는 모델이 직접 관측할 수 없는 내재적인 정보를 학습하는 데 사용됩니다.

1. Latent이란?

Latent Variable은 데이터에서 직접 관측되지 않지만, 모델이 학습 과정에서 유용한 특징(feature)으로 활용하는 변수입니다.

e.g.

  • 사람의 감정을 예측하는 모델에서, "행복도" 같은 값은 직접 측정할 수 없지만(관측 불가능), 얼굴 표정이나 음성 톤 등의 데이터를 이용해 추론할 수 있습니다.

  • NLP에서는 단어 자체보다 문장의 "의도(intention)" 같은 것을 Latent Variable로 간주할 수 있습니다.

2. DeepSeek의 MLA (Multi-Head Latent Attention)

DeepSeek의 MLA는 기존의 Self-Attention 메커니즘을 확장한 개념입니다.
기존 Transformer의 Self-Attention은 입력 토큰 간의 관계를 직접 학습하는 방식입니다. 하지만 MLA는 Latent Variable을 추가로 활용하여 좀 더 풍부한 표현 학습을 가능하게 합니다.

기존 Self-Attention

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V

  • QWQXQ - W_QX
  • KWKXK - W_KX
  • VWVXV - W_VX
    즉, 입력 𝑋 를 선형 변환하여 Q, K, V를 생성한 후, Query-Key 유사도를 기반으로 가중치를 계산하고 Value에 적용하는 방식입니다.
    여기서 Query (Q), Key (K), Value (V) 는 주어진 입력 토큰에서 계산됩니다.

Multi-Head Latent Attention (MLA)

MLA는 Latent Variable LL을 추가하여 기존 Self-Attention보다 효율적인 정보를 학습하도록 합니다.

(1) Latent Variable 생성

입력 𝑋 를 Latent Space로 변환하는 과정을 추가합니다.

L=f(WLX)L = f(W_L X)

즉, WLW_L을 통해 입력 𝑋를 압축하여 중요한 특징만을 포함하는 잠재 변수 𝐿을 생성합니다.

(2) Latent Variable을 활용한 Attention 연산
기존 Attention 연산에서 Key 𝐾 대신 𝐿을 사용합니다.

Attention(Q,K,L)=softmax(QLTdk)LAttention(Q, K, L) = \text{softmax} \left( \frac{QL^T}{\sqrt{d_k}} \right) L

  • 직접적인 토큰 간의 관계가 아닌 Latent Variable을 통해 학습하여 더 좋은 일반화 성능을 보입니다.
  • 계산량을 줄이면서도 중요한 정보만 추출하는 역할을 합니다.

3. PyTorch 코드 구현

단일

import torch
import torch.nn as nn

class HeadLatentAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, latent_dim):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        
        # Query, Key, Value 변환을 위한 선형 레이어
        self.W_q = nn.Linear(embed_dim, latent_dim)  # embed_dim -> latent_dim으로 변경
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        
        # Latent Variable 생성
        self.latent_proj = nn.Linear(embed_dim, latent_dim)
        
        # 최종 출력 변환
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Query를 latent_dim으로 투영하도록 변경
        Q = self.W_q(x)  # (batch_size, seq_len, latent_dim)
        K = self.W_k(x)  # (batch_size, seq_len, embed_dim)
        V = self.W_v(x)  # (batch_size, seq_len, embed_dim)
        
        # Latent Variable 생성
        L = self.latent_proj(x).mean(dim=1, keepdim=True)  # (batch_size, 1, latent_dim)

        # Latent Variable과 Query 간의 Attention 계산 (이제 차원이 맞습니다)
        attn_scores = torch.matmul(Q, L.transpose(-2, -1)) / (self.latent_dim ** 0.5)  # (batch_size, seq_len, 1)
        attn_weights = torch.softmax(attn_scores, dim=1)  # (batch_size, seq_len, 1)
        
        # Value와 attention weights를 곱합
        output = attn_weights.transpose(-2, -1) @ V  # (batch_size, 1, embed_dim)
        
        # 최종 출력 변환
        output = self.out_proj(output)  # (batch_size, 1, embed_dim)

        return output


# 테스트 실행
batch_size = 2
seq_len = 5
embed_dim = 32
latent_dim = 16
num_heads = 4

x = torch.randn(batch_size, seq_len, embed_dim)  # 입력

mla = MultiHeadLatentAttention(embed_dim, num_heads, latent_dim)
output = mla(x)

print(output.shape)  # (batch_size, 1, embed_dim)

코드 설명

  1. Query, Key, Value를 계산
  • 입력 𝑋를 선형 변환하여 Q, K, V를 생성
  1. Latent Variable 𝐿생성
  • 입력 𝑋를 latent_proj를 통해 변환 후, 평균 연산을 적용하여 잠재 변수 𝐿을 만듦
  1. Latent Variable을 이용한 Attention 계산
  • 와 𝐿 이의 유사도(attention score) 계산
  • Softmax를 적용하여 가중치(attention weight) 생성
  1. Latent Variable을 활용한 최종 출력 생성
  • V와 Attention Weight를 곱해 Latent Context를 계산
  • 기존 입력의 평균과 결합하여 최종 출력을 만듦

멀티 헤드

import torch
import torch.nn as nn

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, latent_dim):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.head_dim = latent_dim // num_heads
        
        # Query, Key 투영
        self.q_proj = nn.Linear(embed_dim, latent_dim)
        self.k_proj = nn.Linear(embed_dim, latent_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        
        # Latent Variable 생성
        self.latent_proj = nn.Linear(embed_dim, latent_dim)
        
        # 출력 투영
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # Query, Key, Value 투영
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(x)
        
        # Latent Variable 생성 (시퀀스 차원에서 평균)
        latent = self.latent_proj(x).mean(dim=1, keepdim=True)  # [batch_size, 1, latent_dim]
        latent = latent.view(batch_size, 1, self.num_heads, self.head_dim)
        
        # 차원 재배열 (multi-head attention 계산을 위해)
        q = q.transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        k = latent.transpose(1, 2)  # [batch_size, num_heads, 1, head_dim]
        
        # Scaled Dot-Product Attention
        attn_scores = torch.matmul(q, k.transpose(2, 3)) / (self.head_dim ** 0.5)  # [batch_size, num_heads, seq_len, 1]
        attn_weights = torch.softmax(attn_scores, dim=2)  # [batch_size, num_heads, seq_len, 1]
        
        # 원래 시퀀스 차원으로 재배열
        attn_weights = attn_weights.transpose(1, 2).mean(dim=2, keepdim=True)  # [batch_size, seq_len, 1, 1]
        attn_weights = attn_weights.squeeze(-1)  # [batch_size, seq_len, 1]
        
        # Weighted sum
        output = torch.bmm(attn_weights.transpose(1, 2), v)  # [batch_size, 1, embed_dim]
        
        # 출력 투영
        output = self.out_proj(output)  # [batch_size, 1, embed_dim]
        
        return output


# 테스트 실행
batch_size = 2
seq_len = 5
embed_dim = 32
latent_dim = 16
num_heads = 4

x = torch.randn(batch_size, seq_len, embed_dim)
mla = MultiHeadLatentAttention(embed_dim, num_heads, latent_dim)
output = mla(x)

print(output.shape)  # 예상 출력: torch.Size([2, 1, 32])

결과

이 방식은 기존 Transformer보다 계산량이 줄어들면서도, Latent Variable을 활용하여 더 강력한 표현 학습이 가능하도록 설계되었습니다.
DeepSeek-V3에서 사용한 MLA의 개념을 단순화한 구현이며, 실제 모델에서는 Multi-Head 방식이 더 복잡하게 적용될 수 있습니다.

profile
AI (ML/DL) 학습

0개의 댓글

관련 채용 정보