Transformer FFN 완전 정복

Bean·2025년 8월 5일

인공지능

목록 보기
101/182

✨ 들어가며

Transformer 구조를 공부하다 보면 FFN(Feed-Forward Network)이라는 용어가 자주 등장합니다.
이 FFN은 사실상 MLP(Multi-Layer Perceptron)의 특별한 형태이며, 각 토큰 위치에 독립적으로 동일한 MLP를 적용하는 Position-wise MLP라고 불립니다.

이 글에서는 FFN의 구조와 작동 방식, 그리고 일반적인 MLP와의 차이점을 코드와 함께 쉽게 정리해보겠습니다.


1️⃣ FFN = 특수한 MLP

✔️ 수식으로 보는 FFN

FFN(x)=ReLU(xW+b)W+bFFN(x) = ReLU(xW₁ + b₁)W₂ + b₂

Transformer FFN은 2층으로 구성된 단순한 MLP입니다.

연산설명
1층xW+bxW₁ + b차원 확장 (보통 4배)
활성화ReLUReLU비선형성 부여
2층()W+b(·)W₂ + b원래 차원으로 복원

예시:

  • 입력: xRdmodelx ∈ ℝ^{d_{model}}
  • 은닉층: Rdffℝ^{d_{ff}} (예: dff=4×dmodeld_{ff} = 4 × d_{model})
  • 출력: Rdmodelℝ^{d_{model}}

2️⃣ FFN vs 일반적인 MLP

항목FFN (Transformer)일반적인 MLP
위치 의존성❌ 없음 (Position-wise)보통 없음
파라미터 공유✅ 동일 FFN을 모든 토큰에 적용일반적으로 아님
연산 방식각 토큰에 독립적으로 MLP 적용전체 입력에 한 번에 적용
계층 수2층 (Linear → ReLU → Linear)다양함 (2층 이상 가능)

3️⃣ Position-wise란 무엇인가?

💡 MLP 복습

MLP는 입력 벡터를 선형변환 → 활성화 → 선형변환의 과정을 거치는 기본적인 신경망입니다.

🧠 Position-wise MLP 핵심

원문 정의:

"applied to each position separately and identically"

입력 형태

X=[x,x,...,x]Rn×dmodelX = [x₁, x₂, ..., xₙ] ∈ ℝ^{n × d_{model}}
  • xxᵢ: 각 토큰 벡터
  • nn: 시퀀스 길이
  • 각 토큰에 동일한 MLP 적용

적용 방식

FFN(x)=WReLU(Wx+b)+bFFN(xᵢ) = W₂ · ReLU(W₁xᵢ + b₁) + b₂
설명의미
separately위치마다 독립 처리
identically동일한 가중치 사용
레이어 간에는 다름레이어마다 FFN 파라미터는 다름

4️⃣ 일반 MLP vs Position-wise MLP

구분일반 MLPPosition-wise MLP
입력 구조개별 벡터 or 배치시퀀스 (위치 정보 포함)
위치 인식❌ 없음✅ 위치마다 처리
처리 방식전체 입력에 한 번 적용각 토큰에 독립 적용
파라미터 공유입력 샘플 간 공유시퀀스 내 위치 간 공유

5️⃣ Shape 변화와 PyTorch 구현

🔄 Shape 변화 규칙

입력 Shape의미출력 Shape
[batch_size, d_in]벡터(batch 단위)[batch_size, d_out]
[batch_size, seq_len, d_in]시퀀스 (batch × 토큰)[batch_size, seq_len, d_out]

🔧 PyTorch 예제

일반 MLP

mlp = nn.Sequential(
    nn.Linear(d_in, d_hidden),
    nn.ReLU(),
    nn.Linear(d_hidden, d_out)
)
x = torch.randn(batch_size, d_in)
y = mlp(x)  # 출력: [batch_size, d_out]

Transformer FFN

ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),
    nn.ReLU(),
    nn.Linear(d_ff, d_model)
)
x = torch.randn(batch_size, seq_len, d_model)
y = ffn(x)  # 출력: [batch_size, seq_len, d_model]

📌 nn.Linear는 마지막 차원에만 적용되므로, 각 토큰에 자동으로 position-wise로 처리됩니다.


6️⃣ 시각적 예시

Position-wise MLP 적용 과정

입력 시퀀스:
x₁ (512차원)
x₂ (512차원)
x₃ (512차원)

↓ 동일한 MLP 적용

FFN(x₁)
FFN(x₂)
FFN(x₃)

↓ 출력 시퀀스

y₁ (512차원)
y₂ (512차원)
y₃ (512차원)

✅ 요약

항목내용
FFN 구조Linear → ReLU → Linear (2층 MLP)
핵심 개념Position-wise MLP
처리 방식각 토큰에 동일한 FFN 독립 적용
효율성병렬 처리 가능, shape 보존
PyTorch 구현nn.Linear 사용으로 간단 구현 가능

🎯 마무리

Transformer의 FFN은 단순하지만 강력한 구조입니다.
이를 Position-wise MLP로 이해하면, Transformer의 병렬성과 효율성의 핵심도 함께 이해할 수 있습니다.

profile
AI developer

0개의 댓글