Distilling the Knowledge in a Neural Network[.,2015]

tjdcjffff·2021년 12월 29일
1

Knowledge-distillation

목록 보기
1/4
post-thumbnail

Model compression방법으로 knowledge distillation를 설명하도록 하겠습니다.

Knowledge distillation은 teacher network와 student network의 ensemble을 기반으로 한 방법이라 설명할 수 있습니다. 여기서 knowledge distillation이란 teacher network(capacity가 큰 모델)를 학습한 이후, 이 모델로부터 student network(teacher network보다 capacity가 작은 모델)에게 knowledge를 transfer함을 의미합니다.

우선 Knowledge Distillation이 어떻게 동작하는지 순차적으로 설명하도록 하겠습니다.

Distillation loss

첫번째, distillation loss에 대해서 설명하도록 하겠습니다.

👉 input → Teacher model → output(logits) → softmax(temperature scaling) : soft labels
   input → Student model → output(logits) → softmax(temperature scaling) : soft predictions

본 연구에서는 기존에 많이 다뤄지는 softmax function에 temperature scaling을 적용하였습니다. temperature scaling에 관하여 간략하게 설명하도록 하겠습니다. 수식은 아래과 같습니다.

Temperature scaling = exp(xi)jexp(xj)exp(xi/T)jexp(xj/T)\frac{\exp(x_i)}{ \sum_{j} \exp(x_j)} \to \frac{\exp(x_i/T)}{ \sum_{j} \exp(x_j/T)}

  • 이때,  T\ T가 커질수록 좀 더 soft한 확률 분포를 얻게 됩니다.

좀 더 쉬운 이해를 돕기위해, 임의의 logit값으로 softmax/softmax with temperature scaling 각각에 대해 시각화를 진행 해보았습니다. 왼쪽 그림을 보시면 정확한 수치를 파악하기 힘들 정도로 값이 작은 그래프가 존재합니다.

그러나 후자의 경우, 값들이 좀더 smooth해짐을 알 수 있습니다. 학습에 더 잘 반영할 수 있도록 (정보를 좀 더 잘 전달하기 위해서) temperature scaling을 적용하였습니다.
Temperature scaling은 모든 class에 대해 single scalar parameter  T\ T를 logit vector xi\mathcal{x_i} 에 나눠주는 방법입니다.
Distillation loss의 경우 teacher model로 학습했을때의 output과 student model로 학습했을때의 output의 확률분포를 최소화 하는 방향으로 학습합니다. 수식은 다음과 같습니다.

Distillation loss = Lce(σ(ZtT),σ(ZsT))Distillation\ loss \ = \ L_{ce}(\sigma({\frac{Z_{t}}{T}}),\sigma(\frac{Z_s}{T}))

σ\sigma는 softmax, ZtZ_{t}/ZsZ_{s}는 각각 teacher/student model의 output logits,  T\ T는 temperature를 의미합니다.

Student loss

두번째, student loss에 대해서 설명하도록 하겠습니다.

👉 input → Student model → output(logits) → softmax(t=1) : hard predictions
   hard label(ground truth : one-hot)

Student loss는 hard predictions, hard label간의 loss를 최소화 하는 방향으로 학습을 진행합니다. 수식은 다음과 같습니다.

Student loss = Lce(σ(Zs),y^)Student \ loss \ = \ L_{ce}(\sigma({{Z_{s}}}),\hat{y})

ZsZ_{s}는 Student model의 output logits을 의미하고 y^\hat{y}는 ground truth를 의미합니다.

Framework

최종적으로 distillation loss, student loss를 더하여 total loss값을 최소화 하는 방향으로 학습을 진행합니다. total loss는 다음과 같습니다.

Total loss = (1α)Lce(σ(ZtT),σ(ZsT))+2αT2Lce(σ(Zs),y^)Total\ loss \ = \ (1-\alpha)L_{ce}(\sigma({\frac{Z_{t}}{T}}),\sigma(\frac{Z_s}{T})) + 2\alpha T^{2}L_{ce}(\sigma({{Z_{s}}}),\hat{y})

α\alpha : distillation loss / student loss에 가하는 weight hyperparameter

Conclusions

  • Knowledge distillation이 어떻게 동작하는지 확인
  • Knowledge distillation이 가져오는 effect : reduce inference time
profile
김성철

0개의 댓글