Knowledge Distillation(KD; 지식 증류)은 고성능의 Teacher 모델로부터 지식을 전달 받아서 Student 모델을 학습 시키는 기법이다. 이는 파라미터 수가 많고 성능은 좋지만 속도가 느린 teacher 모델을 성능은 보통이지만 파라미터 수가 적고 속도가 빠른 student 모델로 가볍게 만들어 메모리 사용을 줄이고 연산 속도를 높일 수 있는 방법이다.
KD는 크게 knowledge와 transparency의 관점에서 접근할 수 있다.

증류할 지식의 종류에 따라 Response-based KD와 Feature-based KD로 구분할 수 있다.

teacher 모델의 중간 레이어의 feature/representation을 배운다. 그러나 보통 teacher 모델의 중간 레이어 feature map과 student 모델의 중간 레이어 feature map은 차원이 다르기 때문에 student 모델에 regressor layer를 도입하여 차원을 맞추게 된다. 이 때 두 feature map이 유사해지도록 MSE loss를 추가한다.
모델 내부 구조/파라미터 열람 가능 여부에 따라 white-box, gray-box, black-box로 구분할 수 있다.
teacher 모델의 내부 구조, 파라미터 등을 알 수 있는 경우
teacher 모델의 output 및 최종 logit 값 등 제한된 정보만을 알 수 있는 경우
teacher 모델의 내부(레이어, 파라미터 등), 추론 과정은 알 수 없고 입력에 따른 결과만 접근 가능할 때, student 모델이 입력에 따른 결과 출력 행동을 관찰하고 이를 모방하여 자신의 정책을 학습하는 방법이다. 모델의 출력을 입력 데이터로 학습하기 때문에 teacher 모델이 오류가 있는 예측을 했을 때 그 영향을 받는다는 단점이 있다. 하지만 데이터 수집 비용이 저렴하고, 이 지식이 인간이 해석가능한 형태라는 장점이 있다.
Step 1. 모방 데이터 수집
teacher 모델에게 할 질문(seed 질문)을 설정하고, 더 유의미한 답변을 얻기 위해 prompt engineering을 수행한다. (ex) 반품 정책은 어떻게 되나요? → 반품 절차가 몇 단계로 이루어져 있는지 각 단계별로 자세히 설명해주세요.)
Step 2. 데이터 전처리
의미 없는 대화, 매우 짧거나 불충분한 답변, hallucination 등을 제거하고 특정 유형의 질문-답변이 지나치게 많거나 적지 않도록 균형을 맞춘다.
Step 3. 모델 학습
수집하여 전처리한 데이터를 이용해 student 모델을 학습시킨다.
코드 부분 실습 참고해서 다시 작성
# Teacher 모델 정의
class Teacher(nn.Module):
pass
teacher = Teacher(num_classes=10)
print(sum(p.numel() for p in teacher.parameters())) # teacher 모델의 파라미터 수
# Student 모델 정의 (Teacher과 유사하지만 shallow)
class Student(nn.Module):
pass
student = Student(num_classes=10)
print(print(sum(p.numel() for p in teacher.parameters())) # student 모델의 파라미터 수
# teacher 모델 학습 & 테스트
# 분류 문제를 Cross-Entropy Loss로 학습 (hard label -> 0 or 1)
train(teacher, train_data, epochs=10, learning_rate=1e3)
test(teacher, test_data)
# student 모델 학습 & 테스트
# 분류 문제를 Cross-Entropy Loss로 학습 (hard label) + teacher 모델의 확률 값을 KL divergence loss로 학습 (soft label)
train(student, train_data, epochs=10, learning_rate=1e3)
test(student, test_data)
def