머신러닝 - Teacher Forcing

Ann Jongmin·2025년 2월 7일

머신러닝

목록 보기
1/1

Teacher Forcing 기법이란?

Teacher Forcing은 RNN(Recurrent Neural Network) 및 Transformer 모델의 디코더 학습 과정에서 적용되는 기법으로, 이전 시점의 예측값 대신 실제 정답(ground truth)을 디코더 입력으로 사용하는 방법이다.

✅ Teacher Forcing의 목적

  • 학습을 더 빠르게 수렴시킨다.
  • 초기 학습 과정에서 모델이 틀린 예측을 연속적으로 반복하는 문제를 방지한다.

Teacher Forcing의 동작 방식

(1) Teacher Forcing을 적용하지 않은 경우 (Autoregressive 방식)
모델이 이전 시점의 자신의 예측값을 디코더 입력으로 사용하며, 반복적으로 다음 시점을 예측한다.

t = 1: 디코더 입력 → 모델이 예측한 값 ŷ₁ → t = 2의 입력으로 사용
t = 2: 디코더 입력 → 모델이 예측한 값 ŷ₂ → t = 3의 입력으로 사용

✅ 문제점

  • 초기 학습 단계에서 잘못된 예측이 누적되어 학습이 어려워질 가능성이 있다.
  • 학습 속도가 느려질 수 있다.

(2) Teacher Forcing을 적용한 경우
학습 중에는 실제 정답(ground truth)을 디코더 입력으로 사용하여 모델이 더 안정적으로 학습할 수 있도록 한다.

t = 1: 디코더 입력 → 실제 정답 y₁ → t = 2의 입력으로 사용
t = 2: 디코더 입력 → 실제 정답 y₂ → t = 3의 입력으로 사용

✅ 장점

  • 빠른 수렴 (학습 속도 증가)
  • 초기 학습 안정화 (잘못된 예측이 누적되지 않음)

❌ 단점

  • 학습과 실제 테스트(Inference) 환경이 다름 → 테스트 시에는 실제 정답이 없으므로, 모델이 적응하지 못할 수 있다.
  • 이를 해결하기 위해 Teacher Forcing 비율을 점진적으로 줄이는 기법(Scheduled Sampling)을 사용하기도 함.






학습 전체 코드

def train_model(model, train_loader, device, epochs, learning_rate=0.001, teacher_forcing_ratio=0.5):
    model.train()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.5)
    
    for epoch in range(epochs):
        epoch_loss = 0
        for src, tgt in train_loader:
            src = src.to(device)   # [batch, seq_length, input_dim]
            tgt = tgt.to(device)   # [batch, pred_length, input_dim]
            
            src = src.transpose(0, 1)  # [seq_length, batch, input_dim] <- Transformer 입력 형태로 reshape
            batch_size = src.size(1) # decode 함수에서 사용하기 위해 batch_size 저장
            input_dim = src.size(2) # decode 함수에서 사용하기 위해 input_dim 저장
            pred_length = tgt.size(1) # decode 함수에서 사용하기 위해 pred_length 저장
            
            # teacher forcing 적용 여부 결정
            use_teacher_forcing = True if np.random.rand() < teacher_forcing_ratio else False
            
            if use_teacher_forcing:
                # 디코더 입력: 시작 토큰(0벡터) + 타겟 시퀀스의 앞쪽 토큰들
                start_token = torch.zeros(1, batch_size, input_dim).to(device)
                tgt_transposed = tgt.transpose(0, 1)  # [pred_length, batch, input_dim]
                decoder_input = torch.cat([start_token, tgt_transposed[:-1]], dim=0)  # [pred_length + 1, batch, input_dim]

                # Slicing tgt_transposed[:-1] is equivalent to removing the last element of tgt_transposed
                # 슬라이싱 문법 [start:end]을 사용하면, 기본적으로 첫 번째 차원(시간 차원)이 조작됩니다.
                # tgt_transposed : [pred_length, batch, input_dim]
                # tgt_transposed[:-1] : [pred_length-1, batch, input_dim]

                # decoder_input : start_token(0벡터) + tgt_transposed[:-1] = [pred_length, batch, input_dim]
                # decoder_input : [1, batch, input_dim] + [pred_length-1, batch, input_dim] = [pred_length, batch, input_dim]
            else:
                # teacher forcing 없이 0벡터만 사용
                decoder_input = torch.zeros(pred_length, batch_size, input_dim).to(device)

                # decoder_input : [pred_length, batch, input_dim]
            
            optimizer.zero_grad()
            output = model(src, decoder_input)  # [pred_length, batch, input_dim]
            loss = criterion(output, tgt.transpose(0, 1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(train_loader):.6f}, LR: {current_lr:.6f}")






코드에서 Teacher Forcing을 어떻게 적용하는지

배치가 실행될때 마다 디코더에 Start_token을 입력하는 구조로 되어있으며, use_teacher_forcing가 true인지, false인지에 따라 Teacher Forcing을 사용할지 결정된다.

for epoch in range(epochs):
        epoch_loss = 0
        for src, tgt in train_loader: # batch가 돌아감 (src, tgt)
            src = src.to(device)   # [batch, seq_length, input_dim]
            tgt = tgt.to(device)   # [batch, pred_length, input_dim]
            
            src = src.transpose(0, 1)  # [seq_length, batch, input_dim]
            batch_size = src.size(1)
            input_dim = src.size(2)
            pred_length = tgt.size(1)
            
            # Teacher Forcing 결정
            use_teacher_forcing = True if np.random.rand() < teacher_forcing_ratio else False
            if use_teacher_forcing:
                start_token = torch.zeros(1, batch_size, input_dim).to(device)
                tgt_transposed = tgt.transpose(0, 1)  # [pred_length, batch, input_dim]
                decoder_input = torch.cat([start_token, tgt_transposed[:-1]], dim=0) # 배치마다 디코더 입력에 start_token을 추가
            else:
                decoder_input = torch.zeros(pred_length, batch_size, input_dim).to(device) # 배치마다 디코더 입력에 start_token을 추가
            
            optimizer.zero_grad()
            output = model(src, decoder_input)  # [pred_length, batch, input_dim]
            loss = criterion(output, tgt.transpose(0, 1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()






왜 Time Step을 1씩 Shift할까

✅ Transformer 디코더는 자기회귀적(Autoregressive) 방식으로 학습되기 때문에, 학습 시 디코더 입력과 타겟이 1 step 차이를 가져야 한다.

(1) 일반적인 시퀀스-투-시퀀스 학습 구조

Decoder Input (입력) → [START, y₀, y₁, y₂, ..., yₙ₋₂]
Target (정답)       → [y₀, y₁, y₂, ..., yₙ₋₁]

💡 즉, 디코더의 입력(Decoder Input)은 타겟보다 한 스텝 앞서 있어야 한다.

(2) Time Step을 Shift하지 않는 경우 (잘못된 학습 구조)

Decoder Input (입력) → [y₀, y₁, y₂, ..., yₙ₋₁]
Target (정답)       → [y₀, y₁, y₂, ..., yₙ₋₁]

❌ 문제점

  • 모델이 미래 시점 정보를 미리 알게 되어, 학습이 비정상적으로 이루어질 수 있다.
  • 모델이 올바르게 예측하는 것이 아니라, 정답을 그대로 베끼는 형태가 될 가능성이 있다.

(3) Time Step을 1씩 Shift하는 경우 (올바른 학습 구조)

start_token = torch.zeros(1, batch_size, input_dim).to(device)
tgt_transposed = tgt.transpose(0, 1)  # [pred_length, batch, input_dim]
decoder_input = torch.cat([start_token, tgt_transposed[:-1]], dim=0)  # [pred_length, batch, input_dim]

📌 설명

  • start_token: 디코더의 첫 입력은 항상 0 벡터(시작 토큰)이어야 함.
  • tgt_transposed[:-1]: 타겟 시퀀스에서 마지막 스텝을 제거하여 Shifted Target을 생성.
  • torch.cat([start_token, tgt_transposed[:-1]], dim=0): 이전 타겟 시퀀스를 한 칸씩 밀어서 디코더 입력을 생성.

✅ 결과:

  • 디코더 입력(Decoder Input)과 타겟(Target)의 크기가 동일하게 맞춰짐.
  • 모델이 현재 시점의 입력을 보고 다음 시점의 출력을 학습하도록 유도할 수 있음.

결론

✅ Teacher Forcing은 학습 시 디코더의 입력으로 실제 정답(ground truth)을 사용하여 모델이 빠르고 안정적으로 학습하도록 도와주는 기법이다.
✅ Transformer 디코더는 자기회귀적(Autoregressive)이므로, 디코더 입력과 타겟을 1 step 차이 나도록 맞춰야 한다.
✅ 이를 위해, 디코더 입력을 만들 때 첫 입력(시작 토큰)은 0 벡터를 사용하고, 타겟 시퀀스를 한 칸씩 Shift하여 사용한다.
✅ 이 과정을 통해 학습된 모델은 예측 시에도 Autoregressive 방식으로 동작할 수 있도록 학습된다.

profile
AI Study

0개의 댓글