[Simple Review] Knowledge distillation: A good teacher is patient and consistent

Hyungseop Lee·2024년 9월 19일
0

Beyer, Lucas, et al. "Knowledge distillation: A good teacher is patient and consistent." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.


Abstract

  • computer vision 분야에서는 SOTA를 달성하는 large-scale model과 practical applications에서 사용가능한 model 간의 괴리가 점점 커지고 있다.
    이 논문에서는 이러한 문제를 해결하고 두 유형의 모델 간의 격차를 크게 줄이고자 한다.

  • 새로운 방법을 제안하려는 것이 아니라,
    large scale models을 실제로 사용할 수 있을 정도로 robust and effective recipe를 확인하는 것을 목표
    로한다.

  • 우리의 실험적 조사에서 KD가 올바르게 수행될 경우 large scale model의 크기를 줄이면서도 성능을 유지할 수 있는 강력한 도구가 될 수 있음을 입증한다.
    특히, 우리는 KD 효과에 크게 영향을 미칠 수 있는 몇 가지 암묵적인 design choices가 있다는 것을 밝혀냈다.

  • 우리의 주요한 contribution은 기존 문헌에서 명확하게 설명되지 않았던 이러한 design choices를 명시적으로 식별한 것이다.
    우리는 종합적인 실험 연구를 통해 우리의 결과를 뒷받침하고,
    다양한 vision dataset에서 강력한 결과를 입증하여,
    특히 ImageNet에서 ResNet-50 model로 SOTA를 달성하여 82.8%의 top-1 accuracy를 기록했다.


1. Introduction

  • practitioners들은 ResNet-50 또는 MobileNet과 같이, 일반적으로 much smaller models를 사용한다.
    smallest ResNet-50은 larger ResNet-50보다 download횟수가 훨씬 많다.
    결과적으로, vision에서 많은 최신의 연구 및 개발들은 real-world applications에서 그대로 이행되지 않는다.

  • 이 문제를 해결하기 위해, 우리는 다음의 task에 집중하였다 :
    특정 application과 그 app에서 잘 수행되고 있는 large model이 주어졌을 때,
    성능을 저하시키지 않으면서 model을 더 작고 효율적인 구조로 압축하는 것을 목표로 한다.
    이 작업을 목표로 하는 두 가지 널리 사용되는 paradigm은 Model PruningKnowledge Distillation이다.

    • Model Pruning은 model의 일부를 제거하여 large scale model의 size를 줄이는 방식이다.
      하지만 이 방식은 실무에서 제약이 있을 수 있다.
      첫째, ResNet에서 MobileNet으로 model family를 변경하는 것과 같은 구조 변경을 허용하지 못한다.
      둘째, 만약 model이 Group Normalization을 사용하는 경우, channel을 줄이면 channel group을 동적으로 재조정해야 하는 architecture-dependent 문제가 발생할 수 있다.
  • 대신, 우리는 이러한 단점이 없는 KD 방법에 집중한다.
    KD는 큰 model(혹은 여러 model로 구성된 ensemble model)을 작고 효율적인 student model로 "distill"하는 개념이다.
    이는 student model의 prediction(또는 internal activations)을 teacher model과 일치시키도록 강제하는 방식으로, 자연스럽게 model compression 과정에서 model family를 변경할 수 있게 해준다.
    우리는 [13]에서 제안된 original KD setup에 따르며, 이것이 올바르게 수행될 경우 매우 효과적이라는 것을 발견했다.

  • 우리는 KD를 teacher model과 student model이 구현하는 함수를 일치시키는 작업으로 해석하며, 이는 Figure 2.에서 설명된다.
    이러한 해석을 통해, 우리는 model compression을 위한 KD의 두 가지 원칙을 발견했다.

    1. 첫째, teacher와 student model은 정확히 same input image view, 구체적으로는 same crop and augmentations을 처리해야 한다.
    2. 둘째, 좋은 generalization을 위해 많은 support points에서 함수가 일치해야 한다.
      우리는 mixup의 강력한 variant를 사용하여 원래 image manifold 밖에서 support points를 생성할 수 있다.


      이를 바탕으로, 우리는 실험적으로
      consistent image views, aggressive augmentations and very long training schedule이 KD를 통한 model compression을 효과적으로 만드는 핵심이라는 것을 입증
      한다.
  • 우리의 발견이 겉보기에 단순해 보일지라도, 연구자(및 실무자)가 우리가 제안하는 design choices를 따르지 못하게 하는 여러 이유가 있을 수 있다.
    첫째, 특히 매우 큰 teacher의 경우, 계산을 절약하기 위해 image에 대한 teacher model의 activation 값을 offline으로 한 번 미리 계산해두고 싶어지는 유혹이 있다.
    그러나 우리가 보여줄 바와 같이, 이러한 고정된 teacher 방식은 효과적이지 않다.
    둘째, KD는 종종 model compression 이외의 다른 맥락에서 사용되며, 저자들은 이 경우 다르거나 심지어 상반된 design choices를 권장한다.(이는 Figure 2.에서 볼 수 있다.)
    셋째, KD는 SOTA에 달성하기 위해 매우 많은 epoch가 필요하며, 이는 일반적인 supervised training에서 사용되는 것보다 훨씬 많다.
    마지막으로, 짧은 학습에서 suboptimal로 보이는 선택이 긴 학습에서는 optimal이 되는 경우가 많으며, 그 반대도 마찬가지이다.

  • 우리의 실험 연구에서는 주로 [23]에서 ImageNet-21k dataset으로 pretrained된 large BiT-ResNet-152x2 model을 관련 dataset에 맞게 finetuning한 model을 압축하는 데 집중했다.
    우리는 standard ResNet-50 architecture로 distill했으며, BN을 Group Normalization으로 교체했다.
    또한 ImageNet dataset에서 매우 좋은 결과를 달성했다.
    총 9600 epoch의 distillation 과정을 통해, 우리는 ImageNet에서 새로운 ResNet-50의 SOTA를 기록하며 82.8%의 top-1 accuracy를 달성했다.


2. Experimental setup

  • 이 section에서는 논문 전체에서 사용하는 실험 설정과 benchmark를 소개한다.
    특정 작업에서 높은 accuracy를 가진 large scale vision model(teacher or T)이 주어졌을 때,
    우리는 성능을 저하시키지 않으면서 이 model을 훨씬 더 작은 model(student or S)로 압축하는 것을 목표로한다.
    우리의 compression recipe는 [13]에서 소개된 KD에 의존하며, 훈련 설정에서 몇 가지 주요 요소에 대한 신중한 조사를 포함한다.

Datasets, metrics and evaluation protocol.

  • 우리는 5가지 인기 있는 image classification dataset에서 실험을 수행했다 :
    flowers102 [31], pets [33], food101 [21], sun397 [48], 그리고 ILSVRC-2012 ("ImageNet") [36].
    이 dataset들은 다양한 image classification scenario를 다루며,
    특히 class 수가 37에서 1000개까지, training image의 총 수가 1020에서 1281167개까지 다양하다.
    이를 통해 우리는 다양한 practical setting에서 우리의 distillation recipe를 검증하고, 그 견고성을 보장할 수 있다.

  • metric으로는 항상 classificatino accuracy를 보고한다.
    모든 dataset에서 우리는 validation split을 사용하여 design choices and hyperparameters selection을 하고,
    test set에서 최종 결과를 보고한다. (이 splits은 appendix E에서 정의)

Teacher and student models.

  • 논문 전체에서 우리는 BiT [23]에서 제공하는 pretrained된 teacher model을 사용하기로 했다.
    이는 ILSVRC-2012 및 ImageNet-21k dataset에서 pretrained된 ResNet model의 크 collection을 제공하며, SOTA accuracy를 갖는다.
    BiT-ResNet과 standard ResNet의 주요 차이점은 group normalization layer와 weight standardization[34]를 사용하는 것이다.
    이는 batch normalization를 대신해서 사용된다.

  • 특히, 우리는 BiT-M-R152x2 architecture에 집중했다.
    이는 ImageNet-21k에서 pretrained된 BiT-ResNet-152x2(152 layers, 'x2' width multiplier를 나타냄)이다.
    student model로는 BiT-ResNet-50 variant를 사용하며, ResNet-50으로 간략히 표현한다.
    (??? student model로 standard ResNet-50 architecture but replace BN with GN을 썼다고 말했는데, 여기서는 BiT-ResNet-50 variant 사용했다 함...)
    (결국 BiT-ResNet-50 variant를 사용했다고 판단하면 될듯...)

Distillation loss.

  • 우리는 teacher model과 student model이 예측한 class probability vectors ptp_t, psp_s간의 KL-divergence를 distillation loss로 사용했다.
    이는 원래 [13]에서 소개된 방식이며, original dataset의 hard labels와 관련된 추가적인 loss term을 사용하지 않았다.

Training setup.

  • Adam optimizer with default hyperparameter exploration
  • cosine learning rate schedule without warm restarts
  • weight decay loss coefficient for all experiments
  • gradient clipping with a threshold of 1.0 on the global L2-norm of a gradient
  • batch sie 512 for all our experiments, except for models trained on ImageNet, where we train with batch size 4096.

  • 우리 recipe 중에 추가로 중요한 component는 mixup data augmentation strategy이다.
    구체적으로, 우리는 Section 3.1.1에 있는 mixup variant인 우리의 "function matching" strategy를 소개한다.
    "funcion matching"은 [0, 1] 범위에서 uniformly sampled된 "agressive" mixing coefficients를 사용한다.
    이는 원래 제안된 β\beta-distirbution에서 sampling하는 것의 극단적인 경우로 볼 수 있다.

  • 별도로 명시하지 않는 한, preprocessing에는 "inception-style" crop[39]을 사용한 후 image를 fixed square size로 resize한다.
    또한, 수천 개의 model을 훈련시키는 우리의 광범위한 분석을 computationally feasible하게 만들기 위해, 우리는 상대적으로 낮은 input resolution을 사용하고 input image를 128×128128 \times 128 size로 resize했다.
    다만, ImageNet 실험에서는 standard input 224×224224 \times 224를 사용
    했다.


3. Distillation for model compression

3.1. Investigating the "consistent and patient teacher" hypothesis

  • 이 section에서는, introduction에서 제기한 hypothesis를 실험적으로 검증했다.
    이는 Figure 2에 보여진 바와 같이, KD는 function matching으로 간주될 때 가장 잘 동작한다는 것이다.
    즉, student model과 teacher model이 (1)consistent views of the input images(일관된 입력 이미지 뷰)를 보고, (2)mixup을 통해 인위적으로 채워진 image를 사용하는 경우,
    그리고 (3)student model이 긴 training schedule(즉, "teacher"가 patient한 경우) 동안 훈련되는 경우가 효과적이라는 가설
    이다.

  • 우리의 결과가 견고함을 보장하기 위해,
    우리는 Flowers102(1020개 training images),
    Pets(3312개 training images),
    Food101(약 68k개 training images),
    SUN397(76k개 training images) 등
    4개의 small and medium scale에서 매우 철저한 분석을 수행했다.

  • 혼란을 야기할 수 있는 요인을 제거하기 위해,
    각 개별 distillation setting에서
    learning rate {0.0003, 0.001, 0.003, 0.01},
    weight decays {1·10⁻⁵, 3·10⁻⁵, 1·10⁻⁴, 3·10⁻⁴, 1·10⁻³},
    distillation temperatures {1, 2, 5, 10}의 모든 조합을 test했다.
    모든 보고된 그림에서는 각 실험을 a low opacity curve(저불투명도 곡선)으로 표시하며,
    final validation accuracy가 가장 높은 실험을 highlight했다.
    (해당 test accuracies는 Appendix A에 제공)

3.1.1 Importance of "consistent" teaching

  • 우선, 우리는 consistency criterion(즉, student와 teacher가 동일한 view를 보는 것)이 모든 dataset에서 student 성능의 최고치를 꾸준히 달성하는 유일한 방법임을 보여준다.
    이를 연구하기 위해 Figure 2.에 sketch된 4가지 옵션의 모든 instantiations을 나타내는 여러 distillation configurations을 정의하며, 동일한 color coding을 사용했다 :

    • Fixed teacher. :
      주어진 image에 대해 teacher의 prediction이 고정되는 몇 가지 옵션을 탐구했다.
      simplest(and worst) method는 fix/rs로, student와 teacher 모두에게 image를 2242224^2 pixel로 resize한다.
      fix/cc는 teacher에게는 fixed central crop을, student에게는 mild(간단한) random crop을 사용하는 일반적인 방법이다.
      fix/ic_ens는 우리가 teacher의 성능을 향상시키는 것으로 확인한
      1k개의 inception crops의 평균을 사용한 강력한 data augmentation 방법이다.
      student도 random inception crops을 사용한다.
      fix/cc와 fix/in_ens 두 setting은 "noisy student" paper[49]의 input noise strategy와 유사하다.
    • Independent noise :
      이 일반적인 전략을 두 가지 방식으로 구현한다.
      ind/rc는 teacher와 student 각각에 대해 independent mild random crops을 계산하며,
      ind/ic는 대신 더 강력한 inception crop을 사용한다.
      유사한 설정은 [41]에서 사용되었다.
    • Consistent teaching. :
      이 방법에서, 우리는 image를 한 번만 randomly crop한 후,
      mild random croping(same/rc) 또는 heavy inception crop(same/ic)를 사용하고,
      이 동일한 crop을 teacher와 student 모두에게 input으로 제공한다.
    • Function matching. :
      이 접근 방식은 consistent teaching을 확장하여,
      mixup을 통해 image의 input manifold를 확장하고,
      다시 teacher와 student 모두에게 consistent inputs을 제공한다.
      간결함을 위해, 우리는 이 접근 방식을 "FunMatch"라고 부르기도 한다.
      same/rc, mixsame/ic, mix
  • Figure 3은 Flowers102 dataset에서 10,000epoch 동안의 모든 설정에서의 training curves를 보여준다.
    이 결과는 "consistency"가 핵심임을 명확하게 보여준다.
    모든 "inconsistency" distillation settings은 더 낮은 점수에서 plateau(정체되는 구간)를 갖지만,
    consistency setting은 student performance를 크게 향상시키며, 특히 function matching 방식이 가장 효과적이다.
    또한 small dataset에서 fixed teacher를 사용하는 것은 strong overfitting을 초래한다는 것을 training loss가 보여준다.
    반면, function matching은 training set에서 과도한 loss에 도달하지 않으면서도 validation set에 훨씬 더 잘 generalizing된다.

3.1.2 Importance of "patient" teaching

  • distilation을 teacher model이 제공하는 label(잠재적으로 soft label)을 활용하는 supervised learning으로 해석할 수 있다.
    특히 teacher의 prediction이 single image view에 대해 (pre)computed된 경우에 해당한다.
    이 접근 방식은 standard supervised learning의 모든 문제를 그대로 갖고 있는데,
    예를 들어, aggressive augmentation은 실제 image label을 왜곡할 수 있고,
    less aggressive augmentation은 overfitting을 초래할 수 있다.

  • 하지만, distillation을 function matching으로 해석하고,
    결정적으로 student와 teacher에게 일관된 input을 제공한다면 상황을 달라진다.
    이 경우, image augmentation을 매우 공격적으로 수행할 수 있다.
    image view가 너무 왜곡되더라도, 여전히 이 input에서 관련된 함수들을 matching하는 방향으로 진행될 수 있다.
    따라서 우리는 augmentation에 대해 더욱 적극적일 수 있고, 적극적인 image augmentation을 수행하면서 overfitting을 피할 수 있으며,
    student의 함수가 teacher의 함수와 가까워질 때까지 매우 긴 시간 동안 optimize를 진행할 수 있다.

  • 우리는 Figure 4에서 이러한 직관을 실험적으로 확인했다.
    각 dataset에 대해, validation 기준으로 best function matching student의 학습 과정에서 test accuracy의 발전을 보여줬다.
    중요한 점은, 1M epoch 동안 optimize를 진행해도 overfitting이 발생하지 않는다는 것이다.

    또한, 두 가지 baseline model을 추가로 학습하고 조정했다.
    (1) 하나는 dataset의 원래 hard label을 사용해 처음부터 ResNet-50을 학습하는 것이고,
    (2) 다른 하나는 ImageNet-21k에서 pretrained된 ResNet-50을 transferring하는 것이다.
    두 baseline 모두에서 learning rate와 weight decay를 강력하게 tuning했다 (Section 3.1 참고)
    원래의 label을 사용해 scratch trained된 model은 우리의 student model보다 성능이 현저히 떨어졌다.
    transfer model은 훨씬 더 나은 성능을 보였지만, 결국에도 student model이 더 뛰어난 성능을 보였다.
    전반적으로, ResNet-50 student model은 ResNet-152x2 teacher를 꾸준히 그리고 성실하게 matching했다.

3.2. Scaling up to ImageNet

  • 앞서 말한 우리의 insights를 기반으로,
    우리는 more challenging ImageNet dataset에 적용했다.left에는 3가지 distillation settings : (1) fixed teacher, (2) consistent teaching and (3) function matching에 대한 student accuracy curves를 보고했다.
    (참고로, base teacher model은 83.0% top-1 accuracy를 달성)
    fixed teacher는 긴 training epoch에서 문제가 발생하였으며, 600 epoch 이후 overfitting이 시작되었다.
    반면, consistent teaching 방법은 training 기간이 길어질수록 성능이 계속 향상되었다.
    이를 통해 consistency가 ImageNet에서 distillation이 효과적으로 작동하는 데 중요한 요소임을 결론지을 수 있다.
    simple consistent teaching과 비교했을 때, function matching은 짧은 schedule에서는 성능이 약간 떨어지는데, 이는 underfitting때문일 가능성이 크다.
    하지만 training epoch을 늘리면 function matching의 개선이 명확해진다.
    예를 들어, 1200 epoch만으로도 function matching은 4800 epoch의 consistent teaching을 따라잡을 수 있으며, 이는 75%의 compute resource를 절약하는 효과를 가져왔다.
    function matching의 가장 긴 실험에서는, ResNet-50 student model이 ImageNet에서 82.31%의 Top-1 accuracy를 달성했다.

3.3. Distilling across different input resolutions

  • 지금까지 우리는 student와 teacher가 모두 동일한 standard input resolution인 224px를 받는다고 가정했다.
    하지만 여전히 일관성을 유지하면서 student와 teacher에게 다른 resolution의 image를 전달하는 것도 가능하다.
    original high-resolution image에서 crop을 수행한 후, student와 teacher에게 각각 다른 resolution으로 image를 resize하면 일관성을 유지하면서도 resolution은 다를 수 있다.
    이 insight는 더 나은, 더 높은 resolution의 teacher model로부터 학습하거나, 더 작고 더 빠른 student model을 훈련하는 데 활용될 수 있다.

  • 우리는 두 가지 방향을 모두 조사했다.

    1. [2]를 따라, input resolution이 160px인 ResNet-50 student를 훈련하면서 teacher resolution은 224px로 유지했다.
      이로 인해 model 속도는 두 배 빨라졌지만 80.49%의 top-1 accuracy를 달성했.
    2. [23]에 따라, resolution 384px에서 finetuning된 teacher를 distillation했고,
      이 teacher는 83.7%의 Top-1 accuracy를 달성했다.
      이번에는 student resolution은 변경하지 않고, 224px input image를 사용했다.
      base teacher model과 비교했을 때, 이는 모든 측면에서 꾸준히 소폭의 개선을 보였다.

3.4. Optimization: A second order preconditioner improves training efficiency

  • 우리는 "function matching" 관점에서 distillation recipe의 optimizatino difficulties가 긴 학습 일정으로 인해 computational bottleneck 현상을 일으킨다는 것을 관찰했다.
    직관적으로, 우리는 최적화의 어려움이 fixed image-level labels이 아니라
    multivariate(다변수) outputs이 있는 일반적인 함수를 맞추는 것이
    훨씬 더 어렵기 때문에 발생한다고 생각한다.

  • 이를 위해 기본 optimizer는 Adam에서 Shampoo로 변경했으며, Shampoo는 2차 preconditioner를 사용한다.Figure 5(middle)에서 Shampoo는 Adam이 4800 epoch에서 도달한 test accuracy를 단 1200 epoch만에 달성했으며,
    step time overhead도 최소화되었음을 확인했다.


내 생각

  • 아래 세가지 사항이 이 논문에서 주장하는 contribution이라고 하는데...
    1.은 당연한 이야기를 실험적으로 보여준거라 생각한다... (동일하게 preprocessing된 image를 forwarding해야 feature든 logit이든 distillation이 잘 될 것이다.)
    2.+3.는 KD에 걸맞는 data augmentation 방법을 소개함으로써
    overfitting 걱정 없는 오랜 시간 training이 결국 성능을 올렸다라는 이야기임.
    하지만 너무 많은 training epoch이 필요하기 때문에 결과가 좋더라도 현실적으로 training이 너무 힘듦.
    1. 결국은 teacher와 student를 같은 image로 훈련시켜야 한다는 사실을 실험적 증명.
    2. mixup을 발전시킨 새로운 data augmentation을 소개
    3. 오래 학습시켜야 한다.
profile
Efficient Deep Learning Model, Compression

0개의 댓글