AI_Tech부스트캠프 week16...[5] Knowledge distillation(2)

Leejaegun·2024년 12월 27일

AI_tech_CV트랙 여정

목록 보기
64/74

1. Feature-based KD

Teacher 모델의 레이어 특징값(feature)을 사용.

1.1 Feature-based

logit-based 와 다른점


Logit-based: KL Loss 추가
Feature-based: MSE Loss 추가

Regressor Layer 추가

fTf_T : Teacher의 중간 레이어 feature
fSf_S : Student의 중간 레이어 feature

두 feature 는 차원이 다를 수 있기 때문에 Student 가 regressor layer RR 사용함.
R:RdT    RdS\quad R: \mathbb{R}^{d_T} \;\to\; \mathbb{R}^{d_S}

이 때 Loss 를 추가함
LMSE\mathcal L_{\text MSE} = MSE(fT,R(fS))\text MSE(f_T, R(f_S))

  • Teacher Loss : LCE\mathcal L_{\text CE}
  • Student Loss : LCEλCE\mathcal L_{\text CE}\lambda_{\text CE} +LMSEλMSE\mathcal L_{\text MSE}\lambda_{\text MSE}

1.2 Pytorch Tutorial

import torch
import torch.nn as nn
import torch.nn.functional as F

# 예시 Teacher 모델
class Teacher(nn.Module):
    def __init__(self, num_classes=10):
        super(Teacher, self).__init__()
        # CNN 부분
        self.features = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # MLP 부분
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 512),  # 입력 크기는 입력 이미지 크기에 따라 달라집니다(여기선 예시로 32 x 8 x 8).
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        # CNN 특징 추출
        features = self.features(x)  # shape: (N, 32, H/4, W/4)
        
        # 분류기 통과를 위해 펼치기
        # 예시: 입력 이미지가 32x32 라고 가정하면,
        # features의 크기는 (N, 32, 8, 8) -> 펼친 크기 (N, 32*8*8) = (N, 2048)
        flattened = features.view(features.size(0), -1)
        
        # 최종 로짓 계산
        logits = self.classifier(flattened)
        
        # logit, 중간 feature 모두 반환
        return logits, features

# 예시 Student 모델
class Student(nn.Module):
    def __init__(self, num_classes=10):
        super(Student, self).__init__()
        # CNN 부분
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # MLP 부분
        self.classifier = nn.Sequential(
            nn.Linear(16 * 8 * 8, 256),  # 위와 마찬가지로 입력 이미지 크기에 맞춰 조정
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )
        # 중간 피처 회귀(regression) 부분
        # Teacher의 중간 피처 채널 수(32)와 Student의 피처 채널 수(16)를 맞춰주기 위해 16->32로 변환
        self.regressor = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Student CNN 특징 추출
        features = self.features(x)  # shape: (N, 16, H/4, W/4)
        
        # 로짓 계산
        flattened = features.view(features.size(0), -1)  
        logits = self.classifier(flattened)
        
        # 중간 피처를 Teacher 피처와 형상이 맞도록 변환
        regressed_features = self.regressor(features)  # shape: (N, 32, H/4, W/4)
        
        # 로짓, 변환된 중간 피처 반환
        return logits, regressed_features


if __name__ == "__main__":
    # 예시 사용
    teacher = Teacher(num_classes=10)
    student = Student(num_classes=10)

    # 가상의 입력(32x32 RGB 이미지, 배치 크기=2) 예시
    x = torch.randn(2, 3, 32, 32)

    # Teacher forward
    t_logits, t_features = teacher(x)
    print(f"Teacher logits shape: {t_logits.shape}")
    print(f"Teacher features shape: {t_features.shape}")

    # Student forward
    s_logits, s_regressed = student(x)
    print(f"Student logits shape: {s_logits.shape}")
    print(f"Student regressed features shape: {s_regressed.shape}")

위 코드의 flowchart

               
               Teacher model flow
               ┌─────────────────────┐
               │   (Input) x        │
               └─────────┬───────────┘
                         │ 3채널 입력
                         ▼
     ┌───────────────────────────────────────┐
     │ Teacher CNN (3 → 128 → 64 → ... → 32)│
     └───────────────────────────────────────┘
                         │ (N, 32, H, W)(중간 피처, features)
               ┌───────────────────────────┐
               │ Flatten (e.g. Nx2048)    │
               └───────────────────────────┘
                         ▼
               ┌───────────────────────────┐
               │ Classifier (Linear+ReLU) │
               └───────────────────────────┘
                         │ (N, num_classes)(logits = Teacher의 예측)

return → (logits, features)



								Student model flow
               ┌─────────────────────┐
               │   (Input) x        │
               └─────────┬───────────┘
                         │ 3채널 입력
                         ▼
     ┌───────────────────────────────────────┐
     │Student CNN (3 → 16 → 16 → ... → 16ch)│
     └───────────────────────────────────────┘
                         │ (N, 16, H, W)(중간 피처)
               ┌───────────────────────────┐
               │ Flatten (e.g. Nx1024)    │
               └───────────────────────────┘
                         ▼
               ┌───────────────────────────┐
               │ Classifier (Linear+ReLU) │
               └───────────────────────────┘
                         │ (N, num_classes)(logits = Student의 예측)
                         │
                         │  (지식 증류를 위해 Teacher의 32채널과 맞춰야 함)
                         ▼
         ┌─────────────────────────────────┐
         │ Regressor (Conv2d 16→32)       │
         └─────────────────────────────────┘
                         │ (N, 32, H, W)(regressed_features)

return → (logits, regressed_features)

2. Imitation Learning

2.1 Black-box KD

Black-box 란, 모델의 내부(레이어, 파라미터 등) 추론 과정 등은 알 수 없고, 입력에 따른 결과만 접근 가능한 모델

2.2 Imitation learning

다른 에이전트의 행동을 관찰하고 이를 모방하여 자신의 정책을 학습하는 기계 학습 방법

(1) 모방 데이터 수집

Seed 질문 설정: Teacher 모델에게 어떠한 질문을 할 것인지?
지식 추출: 어떻게 Teacher 모델로부터 더 유의미한 답변(지식)을 추출할지?

(2) 데이터전처리

  • 의미 없는 대화, 매우 짧거나 불충분한 답변은 제거.
  • Hallucination이 존재하는지 검증 (RAG 활용 등…)
  • 특정 유형의 질문-답변 페어가 지나치게 많거나 적지 않도록 균형.
  • 다양한 검증 방법 존재

(3) 모델 학습

품질 검증이 끝난 데이터를 활용한 모델 재학습

3. Other Variants

3.1 Variants

Mutil-teacher

  • 여러개의 Teacher로부터 평균적으로 학습
  • Ensemble method의 일종

Cross-modal


profile
Lee_AA

0개의 댓글