Transformer의 핵심, Attention과 Cross Attention 쉽게 구현하기

Bean·2025년 5월 25일
0

인공지능

목록 보기
41/123

1. PyTorch에서 Attention 구현하기

1-1. nn.MultiheadAttention 사용 (실전에서 가장 많이 사용됨)

import torch
import torch.nn as nn

seq_len = 10
batch_size = 2
embed_dim = 32
num_heads = 4

x = torch.randn(seq_len, batch_size, embed_dim)

mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

attn_output, attn_weights = mha(x, x, x)  # Q=K=V
  • attn_output: [seq_len, batch_size, embed_dim]
  • attn_weights: [batch_size * num_heads, seq_len, seq_len]

1-2. 직접 구현한 단일 헤드 Attention

import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 32)

W_Q = torch.randn(32, 32)
W_K = torch.randn(32, 32)
W_V = torch.randn(32, 32)

Q = x @ W_Q
K = x @ W_K
V = x @ W_V

scores = Q @ K.transpose(-2, -1) / (32 ** 0.5)
weights = F.softmax(scores, dim=-1)
output = weights @ V

2. Cross Attention이란?

  • Self Attention: Q = K = V = 동일 입력

  • Cross Attention:

    • Q = 디코더 입력
    • K, V = 인코더 출력
# Cross Attention 예제
attn_output, attn_weights = mha(decoder_input, encoder_output, encoder_output)

3. Transformer에서 Cross Attention 위치

  • Encoder: Self Attention

  • Decoder:

    1. Masked Self Attention
    2. Cross Attention (Q=Decoder, K/V=Encoder)
    3. FeedForward

PyTorch의 nn.Transformer는 Cross Attention을 자동으로 처리합니다.

transformer = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6
)

output = transformer(src, tgt)
  • src: 인코더 입력
  • tgt: 디코더 입력
  • 내부적으로 Cross Attention이 포함됨 (직접 지정 불필요)

4. Transformer Layer 수 조절

  • PyTorch의 nn.Transformer는 다음 파라미터로 층 수 조절:
nn.Transformer(
    num_encoder_layers=6,
    num_decoder_layers=6,
)
  • 직접 구현 시에는 nn.ModuleList로 여러 층 쌓아 사용

profile
AI developer

0개의 댓글