오늘은 딥러닝 학습 방법 중 하나인 Knowledge Distillation, 또는 지식 증류에 대해 알아보겠다. 이름에서 알 수 있듯, 하나의 모델의 지식을 다른 모델로 전달하는 방식이다. 그 구체적인 방법을 알아보고 pytorch로 간단하게 구현해보도록 하겠다.
복잡하고 큰 모델 (Teacher Model)이 학습한 지식을 상대적으로 간단하고 작은 모델 (Student Model)에 전달하는 기법이다. 이 방법은 작은 모델이 큰 모델의 성능을 최대한 따라가도록 훈련할 때 사용된다.

다만 Transfer Learning (전이학습)과는 다른 개념이다. Transfer learning은 parameter를 일부 복사하여 하위 task에 대해 fine-tuning을 하는 방식이지만, knowledge distillation은 parameter 복사 없이 teacher model의 예측값 자체를 label로 사용하여 학습한다.
Teacher model의 좋은 성능과 지식을 copy 하는 것으로, 모델 경량화 기법이라 볼 수 있다.
전이학습과 지식증류를 구체적으로 비교하면 다음과 같다:
| Knowledge Distillation | Transfer Learning | |
|---|---|---|
| 목적 | Teacher 모델의 지식을 Student 모델로 전달 | 기존 모델의 지식을 새로운 작업에 활용 |
| 효과 | 모델 압축 및 경량화 | 새로운 작업에서 빠른 학습 (데이터 부족 시 유용) |
| 모델 | Teacher와 Student 모델 구조가 다를 수 있음 | 기존 모델 구조를 유지하거나 약간 수정 |
| 데이터 | Teacher 모델과 동일한 데이터로 학습 | 새로운 작업에 맞는 데이터 필요 |
이제 구체적으로 knowledge distillation을 어떻게 하는지 알아보자.
Teacher model은 학습 데이터의 복잡한 패턴과 높은 수준의 지식을 학습한다. 이를 직접 Student model에 적용하기 어렵기 때문에, knowledge distillation에서는 다음과 같은 과정이 진행된다:
1) Teacher Model 학습
크고 복잡한 모델(예: ResNet-152, GPT-3 등)을 먼저 학습한다. 이 모델은 높은 정확도와 강력한 추론 능력을 가지고 있다.
2) Soft Target 생성
Teacher Model은 입력 데이터를 예측할 때, 단순히 정답 레이블만 예측하지 않고, 클래스 간 확률 분포 (Soft Target)도 생성한다.
예를 들어, 동물 분류 문제에서 Teacher Model의 출력이
[고양이: 0.8, 개: 0.15, 말: 0.05]
와 같은 확률 분포를 보인다고 가정하자.
이 분포는 단순히
[고양이: 1, 개: 0, 말: 0]
와 같은 Hard Target (데이터셋의 true label) 보다 깊은 의미를 지닌다.
어떤 고양이는 개와 유사한 모습일 수도 있다. 이러한 Hard Target은 그런 정보를 전혀 반영하지 않고 그저 고양이라고 단정적으로 판단한 것이다.
반면 Soft Target은 데이터 간의 미묘한 관계를 파악하여 pure dataset 보다 더 높은 수준의 지식과 insight를 가진다는 것을 알 수 있다.
Student Model의 학습 label은 바로 Teacher Model의 출력, 즉 Soft Target이다. 데이터셋으로만 학습하는 것보다 더 유연하고 깊은 수준의 학습을 할 수 있다.
이는 마치 학생이 그냥 교과서만 보고 공부하는 것이 아니라 선생님의 심도있는 지식과 인사이트를 습득하는 것과 비슷하다!
3) Student Model 훈련
Student Model은 Teacher Model이 제공한 Soft Target과 원래의 정답 레이블(Hard Target)을 조합하여 학습한다.
따라서 distillation loss는 두 개의 loss가 결합된 형태인데,
1) soft loss
2) hard loss
가 일정한 가중치로 더해진다.
먼저 soft loss는 일반적으로 Kullback-Leibler Divergence 함수를 통해 Soft Target과 Student 모델의 예측값 간 차이를 최소화한다. 이때 Soft target은 모델의 출력값을 temperature으로 나눈 뒤 softmax 하여 만들어진다.
** T가 클수록 label이 더 soft 해진다는 것을 알 수 있다.
soft target:
Hard loss는 일반적으로 볼 수 있는 Cross Entropy Loss를 활용한 함수이다. Supervised learning 할 때처럼 Student 모델 예측값과 데이터셋 true label (hard target) 간 차이를 최소화한다.
최종적인 loss는 위 두 loss를 적절한 비율로 더한 것으로, 일반적으로 soft loss 가중치가 높을수록 좋다고 한다.
4) 다양한 방식
앞서 설명한 방식은 Teacher 모델의 output에 기반하여 학습하는 'response-based KD'라 불린다. 그 외에도 Teacher 모델의 지식을 전달하는 확장된 기법이 존재한다.
먼저 'feature-based KD'란 Teacher 모델의 중간 레이어에서 추출한 정보를 Student 모델이 학습하도록 하는 방법이다. Teacher 모델의 특정 레이어의 출력 (Feature Map)과 Student 모델의 해당 레이어 출력을 비교하는데, MSE 또는 L2 Loss로 그 차이를 줄인다.
다음으로 'relation-based KD'는 Teacher 모델이 학습한 '데이터들 간 관계 정보'를 전달하는 것이다. 즉 입력 데이터들 간 관계를 Teacher 모델이 학습했을 때, 이를 Student 모델도 학습하도록 유도한다. Distance-wise, angle-wise loss를 사용하여 학습한다.

이제 실제로 Knowledge Distillation 학습 과정을 구현해 보았다. CIFAR-100 데이터로 사전훈련된 ResNet-50을 Teacher 모델로 하고, 그보다 작은 모델인 ResNet-18을 Student 모델로 삼아 훈련하였다. 마찬가지로 CIFAR-100 데이터로 진행하였으며, 기본적인 response-based KD를 구현했다.
동일한 세팅에서 50 에폭으로 훈련한 결과, KD를 적용한 모델이 그렇지 않은 모델보다 약 6%의 성능 향상을 보여주었다.
**다만, loss 함수가 다르기 때문인지 학습 시간은 상당한 차이가 있었다.
깃허브
먼저 huggingface에 있는 사전훈련된 ResNet-50을 로드해주었다. 약 80%의 accuracy를 보인다고 한다.
import detectors
import timm
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
teacher = timm.create_model("resnet50_cifar100", pretrained=True)
teacher.to(device)
teacher.eval()
이후 ResNet-18을 Student 모델로 정의했는데, CIFAR-100 데이터에 맞춰서 살짝 모델 구조를 변형했다. 자세한 내용은 이 글을 참고하길 바란다.
그 다음 가장 중요한 KD loss를 구현했다. 위에서 설명한 수식을 그대로 반영했는데, Temperature는 4로 두어 더 soft하게 만들었으며 가중치는 0.7로 soft loss의 비중이 더 크도록 했다.
# Knowledge distillation loss
def knowledge_distillation_loss(student_logits, labels, teacher_logits, alpha = 0.3, T = 4):
hard_loss = F.cross_entropy(input=student_logits, target=labels)
soft_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(student_logits/T, dim=1), F.softmax(teacher_logits/T, dim=1)) * (T ** 2)
total_loss = alpha * hard_loss + (1-alpha) * soft_loss
return total_loss
학습 시엔 각 모델의 출력값과 실제 target을 모두 loss function에 전달한다.
for epoch in range(epochs):
n_iter = 0
loss_total = 0
acc_total = 0
for i, data in enumerate(tqdm(cifar100_train_loader)):
inputs, targets = data[0].to(device), data[1].to(device)
teacher_output = teacher(inputs)
student_output = student_model(inputs)
optimizer.zero_grad()
loss = knowledge_distillation_loss(student_output, targets, teacher_output)
loss.backward()
optimizer.step()
상세 코드: https://github.com/tony3ynot/Knowledge_Distillation_CIFAR-100/tree/main
Knowledge Distillation에 대해 정리해보고 실제로 구현하면서 생각보다 간단하지만 강력한 방법이라고 느꼈다. 무엇보다 학부생 입장에서 컴퓨팅 자원과 시간을 많이 쓰기는 어려운데, 이 방법으로 모델을 경량화하면서도 상당히 높은 성능을 낼 수 있어서 좋은 것 같다.
특히 학생이 선생님에게 배우는 방식을 딥러닝에 적용한 점이 흥미로우면서도 이해가 잘된 것 같다!
서울대학교 김태섭 교수님의 MLDL2 강의 자료
Hinton, et al. "Distilling the Knowledge in a Neural Network". 2014.
A Gentle Introduction to Hint Learning & Knowledge Distillation