Teacher 모델의 레이어 특징값(feature)을 사용.


Logit-based: KL Loss 추가
Feature-based: MSE Loss 추가
: Teacher의 중간 레이어 feature
: Student의 중간 레이어 feature
두 feature 는 차원이 다를 수 있기 때문에 Student 가 regressor layer 사용함.
이 때 Loss 를 추가함
=


import torch
import torch.nn as nn
import torch.nn.functional as F
# 예시 Teacher 모델
class Teacher(nn.Module):
def __init__(self, num_classes=10):
super(Teacher, self).__init__()
# CNN 부분
self.features = nn.Sequential(
nn.Conv2d(3, 128, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
# MLP 부분
self.classifier = nn.Sequential(
nn.Linear(32 * 8 * 8, 512), # 입력 크기는 입력 이미지 크기에 따라 달라집니다(여기선 예시로 32 x 8 x 8).
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(512, num_classes)
)
def forward(self, x):
# CNN 특징 추출
features = self.features(x) # shape: (N, 32, H/4, W/4)
# 분류기 통과를 위해 펼치기
# 예시: 입력 이미지가 32x32 라고 가정하면,
# features의 크기는 (N, 32, 8, 8) -> 펼친 크기 (N, 32*8*8) = (N, 2048)
flattened = features.view(features.size(0), -1)
# 최종 로짓 계산
logits = self.classifier(flattened)
# logit, 중간 feature 모두 반환
return logits, features
# 예시 Student 모델
class Student(nn.Module):
def __init__(self, num_classes=10):
super(Student, self).__init__()
# CNN 부분
self.features = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
# MLP 부분
self.classifier = nn.Sequential(
nn.Linear(16 * 8 * 8, 256), # 위와 마찬가지로 입력 이미지 크기에 맞춰 조정
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(256, num_classes)
)
# 중간 피처 회귀(regression) 부분
# Teacher의 중간 피처 채널 수(32)와 Student의 피처 채널 수(16)를 맞춰주기 위해 16->32로 변환
self.regressor = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=3, padding=1),
)
def forward(self, x):
# Student CNN 특징 추출
features = self.features(x) # shape: (N, 16, H/4, W/4)
# 로짓 계산
flattened = features.view(features.size(0), -1)
logits = self.classifier(flattened)
# 중간 피처를 Teacher 피처와 형상이 맞도록 변환
regressed_features = self.regressor(features) # shape: (N, 32, H/4, W/4)
# 로짓, 변환된 중간 피처 반환
return logits, regressed_features
if __name__ == "__main__":
# 예시 사용
teacher = Teacher(num_classes=10)
student = Student(num_classes=10)
# 가상의 입력(32x32 RGB 이미지, 배치 크기=2) 예시
x = torch.randn(2, 3, 32, 32)
# Teacher forward
t_logits, t_features = teacher(x)
print(f"Teacher logits shape: {t_logits.shape}")
print(f"Teacher features shape: {t_features.shape}")
# Student forward
s_logits, s_regressed = student(x)
print(f"Student logits shape: {s_logits.shape}")
print(f"Student regressed features shape: {s_regressed.shape}")
위 코드의 flowchart
Teacher model flow
┌─────────────────────┐
│ (Input) x │
└─────────┬───────────┘
│ 3채널 입력
▼
┌───────────────────────────────────────┐
│ Teacher CNN (3 → 128 → 64 → ... → 32)│
└───────────────────────────────────────┘
│ (N, 32, H, W)
▼ (중간 피처, features)
┌───────────────────────────┐
│ Flatten (e.g. Nx2048) │
└───────────────────────────┘
▼
┌───────────────────────────┐
│ Classifier (Linear+ReLU) │
└───────────────────────────┘
│ (N, num_classes)
▼
(logits = Teacher의 예측)
return → (logits, features)
Student model flow
┌─────────────────────┐
│ (Input) x │
└─────────┬───────────┘
│ 3채널 입력
▼
┌───────────────────────────────────────┐
│Student CNN (3 → 16 → 16 → ... → 16ch)│
└───────────────────────────────────────┘
│ (N, 16, H, W)
▼ (중간 피처)
┌───────────────────────────┐
│ Flatten (e.g. Nx1024) │
└───────────────────────────┘
▼
┌───────────────────────────┐
│ Classifier (Linear+ReLU) │
└───────────────────────────┘
│ (N, num_classes)
▼
(logits = Student의 예측)
│
│ (지식 증류를 위해 Teacher의 32채널과 맞춰야 함)
▼
┌─────────────────────────────────┐
│ Regressor (Conv2d 16→32) │
└─────────────────────────────────┘
│ (N, 32, H, W)
▼
(regressed_features)
return → (logits, regressed_features)
Black-box 란, 모델의 내부(레이어, 파라미터 등) 추론 과정 등은 알 수 없고, 입력에 따른 결과만 접근 가능한 모델
다른 에이전트의 행동을 관찰하고 이를 모방하여 자신의 정책을 학습하는 기계 학습 방법

(1) 모방 데이터 수집

Seed 질문 설정: Teacher 모델에게 어떠한 질문을 할 것인지?
지식 추출: 어떻게 Teacher 모델로부터 더 유의미한 답변(지식)을 추출할지?
(2) 데이터전처리
(3) 모델 학습

품질 검증이 끝난 데이터를 활용한 모델 재학습


