Teacher Forcing 기법이란?
Teacher Forcing은 RNN(Recurrent Neural Network) 및 Transformer 모델의 디코더 학습 과정에서 적용되는 기법으로, 이전 시점의 예측값 대신 실제 정답(ground truth)을 디코더 입력으로 사용하는 방법이다.
✅ Teacher Forcing의 목적
(1) Teacher Forcing을 적용하지 않은 경우 (Autoregressive 방식)
모델이 이전 시점의 자신의 예측값을 디코더 입력으로 사용하며, 반복적으로 다음 시점을 예측한다.
t = 1: 디코더 입력 → 모델이 예측한 값 ŷ₁ → t = 2의 입력으로 사용
t = 2: 디코더 입력 → 모델이 예측한 값 ŷ₂ → t = 3의 입력으로 사용
✅ 문제점
(2) Teacher Forcing을 적용한 경우
학습 중에는 실제 정답(ground truth)을 디코더 입력으로 사용하여 모델이 더 안정적으로 학습할 수 있도록 한다.
t = 1: 디코더 입력 → 실제 정답 y₁ → t = 2의 입력으로 사용
t = 2: 디코더 입력 → 실제 정답 y₂ → t = 3의 입력으로 사용
✅ 장점
❌ 단점
def train_model(model, train_loader, device, epochs, learning_rate=0.001, teacher_forcing_ratio=0.5):
model.train()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
for epoch in range(epochs):
epoch_loss = 0
for src, tgt in train_loader:
src = src.to(device) # [batch, seq_length, input_dim]
tgt = tgt.to(device) # [batch, pred_length, input_dim]
src = src.transpose(0, 1) # [seq_length, batch, input_dim] <- Transformer 입력 형태로 reshape
batch_size = src.size(1) # decode 함수에서 사용하기 위해 batch_size 저장
input_dim = src.size(2) # decode 함수에서 사용하기 위해 input_dim 저장
pred_length = tgt.size(1) # decode 함수에서 사용하기 위해 pred_length 저장
# teacher forcing 적용 여부 결정
use_teacher_forcing = True if np.random.rand() < teacher_forcing_ratio else False
if use_teacher_forcing:
# 디코더 입력: 시작 토큰(0벡터) + 타겟 시퀀스의 앞쪽 토큰들
start_token = torch.zeros(1, batch_size, input_dim).to(device)
tgt_transposed = tgt.transpose(0, 1) # [pred_length, batch, input_dim]
decoder_input = torch.cat([start_token, tgt_transposed[:-1]], dim=0) # [pred_length + 1, batch, input_dim]
# Slicing tgt_transposed[:-1] is equivalent to removing the last element of tgt_transposed
# 슬라이싱 문법 [start:end]을 사용하면, 기본적으로 첫 번째 차원(시간 차원)이 조작됩니다.
# tgt_transposed : [pred_length, batch, input_dim]
# tgt_transposed[:-1] : [pred_length-1, batch, input_dim]
# decoder_input : start_token(0벡터) + tgt_transposed[:-1] = [pred_length, batch, input_dim]
# decoder_input : [1, batch, input_dim] + [pred_length-1, batch, input_dim] = [pred_length, batch, input_dim]
else:
# teacher forcing 없이 0벡터만 사용
decoder_input = torch.zeros(pred_length, batch_size, input_dim).to(device)
# decoder_input : [pred_length, batch, input_dim]
optimizer.zero_grad()
output = model(src, decoder_input) # [pred_length, batch, input_dim]
loss = criterion(output, tgt.transpose(0, 1))
loss.backward()
optimizer.step()
epoch_loss += loss.item()
scheduler.step()
current_lr = scheduler.get_last_lr()[0]
print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(train_loader):.6f}, LR: {current_lr:.6f}")
배치가 실행될때 마다 디코더에 Start_token을 입력하는 구조로 되어있으며, use_teacher_forcing가 true인지, false인지에 따라 Teacher Forcing을 사용할지 결정된다.
for epoch in range(epochs):
epoch_loss = 0
for src, tgt in train_loader: # batch가 돌아감 (src, tgt)
src = src.to(device) # [batch, seq_length, input_dim]
tgt = tgt.to(device) # [batch, pred_length, input_dim]
src = src.transpose(0, 1) # [seq_length, batch, input_dim]
batch_size = src.size(1)
input_dim = src.size(2)
pred_length = tgt.size(1)
# Teacher Forcing 결정
use_teacher_forcing = True if np.random.rand() < teacher_forcing_ratio else False
if use_teacher_forcing:
start_token = torch.zeros(1, batch_size, input_dim).to(device)
tgt_transposed = tgt.transpose(0, 1) # [pred_length, batch, input_dim]
decoder_input = torch.cat([start_token, tgt_transposed[:-1]], dim=0) # 배치마다 디코더 입력에 start_token을 추가
else:
decoder_input = torch.zeros(pred_length, batch_size, input_dim).to(device) # 배치마다 디코더 입력에 start_token을 추가
optimizer.zero_grad()
output = model(src, decoder_input) # [pred_length, batch, input_dim]
loss = criterion(output, tgt.transpose(0, 1))
loss.backward()
optimizer.step()
epoch_loss += loss.item()
✅ Transformer 디코더는 자기회귀적(Autoregressive) 방식으로 학습되기 때문에, 학습 시 디코더 입력과 타겟이 1 step 차이를 가져야 한다.
(1) 일반적인 시퀀스-투-시퀀스 학습 구조
Decoder Input (입력) → [START, y₀, y₁, y₂, ..., yₙ₋₂]
Target (정답) → [y₀, y₁, y₂, ..., yₙ₋₁]
💡 즉, 디코더의 입력(Decoder Input)은 타겟보다 한 스텝 앞서 있어야 한다.
(2) Time Step을 Shift하지 않는 경우 (잘못된 학습 구조)
Decoder Input (입력) → [y₀, y₁, y₂, ..., yₙ₋₁]
Target (정답) → [y₀, y₁, y₂, ..., yₙ₋₁]
❌ 문제점
(3) Time Step을 1씩 Shift하는 경우 (올바른 학습 구조)
start_token = torch.zeros(1, batch_size, input_dim).to(device)
tgt_transposed = tgt.transpose(0, 1) # [pred_length, batch, input_dim]
decoder_input = torch.cat([start_token, tgt_transposed[:-1]], dim=0) # [pred_length, batch, input_dim]
📌 설명
✅ 결과:
✅ Teacher Forcing은 학습 시 디코더의 입력으로 실제 정답(ground truth)을 사용하여 모델이 빠르고 안정적으로 학습하도록 도와주는 기법이다.
✅ Transformer 디코더는 자기회귀적(Autoregressive)이므로, 디코더 입력과 타겟을 1 step 차이 나도록 맞춰야 한다.
✅ 이를 위해, 디코더 입력을 만들 때 첫 입력(시작 토큰)은 0 벡터를 사용하고, 타겟 시퀀스를 한 칸씩 Shift하여 사용한다.
✅ 이 과정을 통해 학습된 모델은 예측 시에도 Autoregressive 방식으로 동작할 수 있도록 학습된다.