Using nn.MultiheadAttention (the most common approach in practice)

```python
import torch
import torch.nn as nn

seq_len = 10
batch_size = 2
embed_dim = 32
num_heads = 4

# Default input layout is (seq_len, batch_size, embed_dim)
x = torch.randn(seq_len, batch_size, embed_dim)

mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
attn_output, attn_weights = mha(x, x, x)  # self-attention: Q = K = V
```

attn_output: [seq_len, batch_size, embed_dim]
attn_weights: [batch_size, seq_len, seq_len] (averaged over the heads by default; pass average_attn_weights=False to get [batch_size, num_heads, seq_len, seq_len])
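
In practice you usually also want batch-first tensors and a padding mask. A minimal sketch of that usage, assuming the last two positions of every sequence are padding (the mask values below are made up for illustration):

```python
import torch
import torch.nn as nn

batch_size, seq_len, embed_dim, num_heads = 2, 10, 32, 4

# batch_first=True switches the layout to (batch_size, seq_len, embed_dim)
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
x = torch.randn(batch_size, seq_len, embed_dim)

# key_padding_mask: True marks key positions that attention should ignore
key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True  # pretend the last two tokens are padding

attn_output, attn_weights = mha(x, x, x, key_padding_mask=key_padding_mask)
print(attn_output.shape)   # torch.Size([2, 10, 32])
print(attn_weights.shape)  # torch.Size([2, 10, 10])
```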

Implementing the same computation by hand shows what happens inside: project the input to Q, K, V, then apply scaled dot-product attention.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 32)
W_Q = torch.randn(32, 32)
W_K = torch.randn(32, 32)
W_V = torch.randn(32, 32)
Q = x @ W_Q
K = x @ W_K
V = x @ W_V

# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
scores = Q @ K.transpose(-2, -1) / (32 ** 0.5)
weights = F.softmax(scores, dim=-1)
output = weights @ V  # shape (2, 10, 32)
```
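
As a sanity check, PyTorch 2.0+ exposes the same computation as F.scaled_dot_product_attention (it applies the same 1/sqrt(d_k) scaling internally); a quick comparison, assuming that version is available:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 32)
W_Q, W_K, W_V = (torch.randn(32, 32) for _ in range(3))
Q, K, V = x @ W_Q, x @ W_K, x @ W_V

# Manual scaled dot-product attention
manual = F.softmax(Q @ K.transpose(-2, -1) / (32 ** 0.5), dim=-1) @ V

# Built-in implementation (PyTorch 2.0+)
builtin = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```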

Self Attention: Q = K = V = the same input
Cross Attention: Q = the decoder input; K, V = the encoder output

```python
# Cross-attention example: queries from the decoder attend over the encoder output
attn_output, attn_weights = mha(decoder_input, encoder_output, encoder_output)
```
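
Concretely, with made-up decoder/encoder tensors (the names and sizes below are just for illustration; shapes follow the default (seq_len, batch_size, embed_dim) layout):

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 32, 4
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

decoder_input = torch.randn(20, 2, embed_dim)   # (tgt_seq_len, batch_size, embed_dim)
encoder_output = torch.randn(10, 2, embed_dim)  # (src_seq_len, batch_size, embed_dim)

# Q comes from the decoder; K and V come from the encoder
attn_output, attn_weights = mha(decoder_input, encoder_output, encoder_output)
print(attn_output.shape)   # torch.Size([20, 2, 32])
print(attn_weights.shape)  # torch.Size([2, 20, 10])
```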

Encoder: Self Attention
Decoder: (masked) Self Attention plus Cross Attention over the encoder output

PyTorch's nn.Transformer handles the Cross Attention automatically.

```python
transformer = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6
)

src = torch.randn(10, 32, 512)  # (src_seq_len, batch_size, d_model)
tgt = torch.randn(20, 32, 512)  # (tgt_seq_len, batch_size, d_model)
output = transformer(src, tgt)  # (tgt_seq_len, batch_size, d_model)
```

src: the encoder input (source sequence)
tgt: the decoder input (target sequence)
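
In a real decoder you usually also pass a causal mask for tgt so each position can only attend to earlier positions. A short sketch using the built-in mask helper (sizes are arbitrary):

```python
import torch
import torch.nn as nn

transformer = nn.Transformer(d_model=512, nhead=8,
                             num_encoder_layers=6, num_decoder_layers=6)

src = torch.randn(10, 32, 512)  # (src_seq_len, batch_size, d_model)
tgt = torch.randn(20, 32, 512)  # (tgt_seq_len, batch_size, d_model)

# Float mask with -inf above the diagonal: position i cannot attend to positions > i
tgt_mask = transformer.generate_square_subsequent_mask(tgt.size(0))

output = transformer(src, tgt, tgt_mask=tgt_mask)
print(output.shape)  # torch.Size([20, 32, 512])
```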

nn.Transformer controls the number of layers with the following parameters:

```python
nn.Transformer(
    num_encoder_layers=6,   # how many encoder layers are stacked
    num_decoder_layers=6,   # how many decoder layers are stacked
)
```

You can also build the stack yourself by holding several layers in nn.ModuleList, for example:
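
A minimal sketch of that pattern (the class name and sizes are arbitrary; each layer is a standard nn.TransformerEncoderLayer):

```python
import torch
import torch.nn as nn

class StackedEncoder(nn.Module):
    """Stack several encoder layers with nn.ModuleList."""
    def __init__(self, d_model=512, nhead=8, num_layers=6):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
             for _ in range(num_layers)]
        )

    def forward(self, x):
        # Each layer consumes the previous layer's output
        for layer in self.layers:
            x = layer(x)
        return x

encoder = StackedEncoder()
src = torch.randn(10, 2, 512)   # (seq_len, batch_size, d_model)
print(encoder(src).shape)       # torch.Size([10, 2, 512])
```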