Beyer, Lucas, et al. "Knowledge distillation: A good teacher is patient and consistent." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
computer vision 분야에서는 SOTA를 달성하는 large-scale model과 practical applications에서 사용가능한 model 간의 괴리가 점점 커지고 있다.
이 논문에서는 이러한 문제를 해결하고 두 유형의 모델 간의 격차를 크게 줄이고자 한다.
새로운 방법을 제안하려는 것이 아니라,
large scale models을 실제로 사용할 수 있을 정도로 robust and effective recipe를 확인하는 것을 목표로한다.
우리의 실험적 조사에서 KD가 올바르게 수행될 경우 large scale model의 크기를 줄이면서도 성능을 유지할 수 있는 강력한 도구가 될 수 있음을 입증한다.
특히, 우리는 KD 효과에 크게 영향을 미칠 수 있는 몇 가지 암묵적인 design choices가 있다는 것을 밝혀냈다.
우리의 주요한 contribution은 기존 문헌에서 명확하게 설명되지 않았던 이러한 design choices를 명시적으로 식별한 것이다.
우리는 종합적인 실험 연구를 통해 우리의 결과를 뒷받침하고,
다양한 vision dataset에서 강력한 결과를 입증하여,
특히 ImageNet에서 ResNet-50 model로 SOTA를 달성하여 82.8%의 top-1 accuracy를 기록했다.
practitioners들은 ResNet-50 또는 MobileNet과 같이, 일반적으로 much smaller models를 사용한다.
smallest ResNet-50은 larger ResNet-50보다 download횟수가 훨씬 많다.
결과적으로, vision에서 많은 최신의 연구 및 개발들은 real-world applications에서 그대로 이행되지 않는다.
이 문제를 해결하기 위해, 우리는 다음의 task에 집중하였다 :
특정 application과 그 app에서 잘 수행되고 있는 large model이 주어졌을 때,
성능을 저하시키지 않으면서 model을 더 작고 효율적인 구조로 압축하는 것을 목표로 한다.
이 작업을 목표로 하는 두 가지 널리 사용되는 paradigm은 Model Pruning과 Knowledge Distillation이다.
대신, 우리는 이러한 단점이 없는 KD 방법에 집중한다.
KD는 큰 model(혹은 여러 model로 구성된 ensemble model)을 작고 효율적인 student model로 "distill"하는 개념이다.
이는 student model의 prediction(또는 internal activations)을 teacher model과 일치시키도록 강제하는 방식으로, 자연스럽게 model compression 과정에서 model family를 변경할 수 있게 해준다.
우리는 [13]에서 제안된 original KD setup에 따르며, 이것이 올바르게 수행될 경우 매우 효과적이라는 것을 발견했다.
우리는 KD를 teacher model과 student model이 구현하는 함수를 일치시키는 작업으로 해석하며, 이는 Figure 2.에서 설명된다.
이러한 해석을 통해, 우리는 model compression을 위한 KD의 두 가지 원칙을 발견했다.
same crop and augmentations을 처리해야 한다
.mixup의 강력한 variant를 사용
하여 원래 image manifold 밖에서 support points를 생성할 수 있다.consistent image views
, aggressive augmentations
and very long training schedule
이 KD를 통한 model compression을 효과적으로 만드는 핵심이라는 것을 입증한다.우리의 발견이 겉보기에 단순해 보일지라도, 연구자(및 실무자)가 우리가 제안하는 design choices를 따르지 못하게 하는 여러 이유가 있을 수 있다.
첫째, 특히 매우 큰 teacher의 경우, 계산을 절약하기 위해 image에 대한 teacher model의 activation 값을 offline으로 한 번 미리 계산해두고 싶어지는 유혹이 있다.
그러나 우리가 보여줄 바와 같이, 이러한 고정된 teacher 방식은 효과적이지 않다.
둘째, KD는 종종 model compression 이외의 다른 맥락에서 사용되며, 저자들은 이 경우 다르거나 심지어 상반된 design choices를 권장한다.(이는 Figure 2.에서 볼 수 있다.)
셋째, KD는 SOTA에 달성하기 위해 매우 많은 epoch가 필요하며, 이는 일반적인 supervised training에서 사용되는 것보다 훨씬 많다.
마지막으로, 짧은 학습에서 suboptimal로 보이는 선택이 긴 학습에서는 optimal이 되는 경우가 많으며, 그 반대도 마찬가지이다.
우리의 실험 연구에서는 주로 [23]에서 ImageNet-21k dataset으로 pretrained된 large BiT-ResNet-152x2 model을 관련 dataset에 맞게 finetuning한 model을 압축하는 데 집중했다.
우리는 standard ResNet-50 architecture로 distill했으며, BN을 Group Normalization으로 교체했다.
또한 ImageNet dataset에서 매우 좋은 결과를 달성했다.
총 9600 epoch의 distillation 과정을 통해, 우리는 ImageNet에서 새로운 ResNet-50의 SOTA를 기록하며 82.8%의 top-1 accuracy를 달성했다.
우리는 5가지 인기 있는 image classification dataset에서 실험을 수행했다 :
flowers102 [31], pets [33], food101 [21], sun397 [48], 그리고 ILSVRC-2012 ("ImageNet") [36].
이 dataset들은 다양한 image classification scenario를 다루며,
특히 class 수가 37에서 1000개까지, training image의 총 수가 1020에서 1281167개까지 다양하다.
이를 통해 우리는 다양한 practical setting에서 우리의 distillation recipe를 검증하고, 그 견고성을 보장할 수 있다.
metric으로는 항상 classificatino accuracy를 보고한다.
모든 dataset에서 우리는 validation split을 사용하여 design choices and hyperparameters selection을 하고,
test set에서 최종 결과를 보고한다. (이 splits은 appendix E에서 정의)
논문 전체에서 우리는 BiT [23]에서 제공하는 pretrained된 teacher model을 사용하기로 했다.
이는 ILSVRC-2012 및 ImageNet-21k dataset에서 pretrained된 ResNet model의 크 collection을 제공하며, SOTA accuracy를 갖는다.
BiT-ResNet과 standard ResNet의 주요 차이점은 group normalization layer와 weight standardization[34]를 사용하는 것이다.
이는 batch normalization를 대신해서 사용된다.
특히, 우리는 BiT-M-R152x2 architecture에 집중했다.
이는 ImageNet-21k에서 pretrained된 BiT-ResNet-152x2(152 layers, 'x2' width multiplier를 나타냄)이다.
student model로는 BiT-ResNet-50 variant를 사용하며, ResNet-50으로 간략히 표현한다.
(??? student model로 standard ResNet-50 architecture but replace BN with GN을 썼다고 말했는데, 여기서는 BiT-ResNet-50 variant 사용했다 함...)
(결국 BiT-ResNet-50 variant를 사용했다고 판단하면 될듯...)
우리 recipe 중에 추가로 중요한 component는 mixup data augmentation strategy이다.
구체적으로, 우리는 Section 3.1.1에 있는 mixup variant인 우리의 "function matching" strategy를 소개한다.
"funcion matching"은 [0, 1] 범위에서 uniformly sampled된 "agressive" mixing coefficients를 사용한다.
이는 원래 제안된 -distirbution에서 sampling하는 것의 극단적인 경우로 볼 수 있다.
별도로 명시하지 않는 한, preprocessing에는 "inception-style" crop[39]을 사용한 후 image를 fixed square size로 resize한다.
또한, 수천 개의 model을 훈련시키는 우리의 광범위한 분석을 computationally feasible하게 만들기 위해, 우리는 상대적으로 낮은 input resolution을 사용하고 input image를 size로 resize했다.
다만, ImageNet 실험에서는 standard input 를 사용했다.
이 section에서는, introduction에서 제기한 hypothesis를 실험적으로 검증했다.
이는 Figure 2에 보여진 바와 같이, KD는 function matching으로 간주될 때 가장 잘 동작한다는 것이다.
즉, student model과 teacher model이 (1)consistent views of the input images(일관된 입력 이미지 뷰)를 보고, (2)mixup을 통해 인위적으로 채워진 image를 사용하는 경우,
그리고 (3)student model이 긴 training schedule(즉, "teacher"가 patient한 경우) 동안 훈련되는 경우가 효과적이라는 가설이다.
우리의 결과가 견고함을 보장하기 위해,
우리는 Flowers102(1020개 training images),
Pets(3312개 training images),
Food101(약 68k개 training images),
SUN397(76k개 training images) 등
4개의 small and medium scale에서 매우 철저한 분석을 수행했다.
혼란을 야기할 수 있는 요인을 제거하기 위해,
각 개별 distillation setting에서
learning rate {0.0003, 0.001, 0.003, 0.01},
weight decays {1·10⁻⁵, 3·10⁻⁵, 1·10⁻⁴, 3·10⁻⁴, 1·10⁻³},
distillation temperatures {1, 2, 5, 10}의 모든 조합을 test했다.
모든 보고된 그림에서는 각 실험을 a low opacity curve(저불투명도 곡선)으로 표시하며,
final validation accuracy가 가장 높은 실험을 highlight했다.
(해당 test accuracies는 Appendix A에 제공)
우선, 우리는 consistency criterion(즉, student와 teacher가 동일한 view를 보는 것)이 모든 dataset에서 student 성능의 최고치를 꾸준히 달성하는 유일한 방법임을 보여준다.
이를 연구하기 위해 Figure 2.에 sketch된 4가지 옵션의 모든 instantiations을 나타내는 여러 distillation configurations을 정의하며, 동일한 color coding을 사용했다 :
Fixed teacher.
:Independent noise
:Consistent teaching.
:Function matching.
:Figure 3은 Flowers102 dataset에서 10,000epoch 동안의 모든 설정에서의 training curves를 보여준다.
이 결과는 "consistency"가 핵심임을 명확하게 보여준다.
모든 "inconsistency" distillation settings은 더 낮은 점수에서 plateau(정체되는 구간)를 갖지만,
consistency setting은 student performance를 크게 향상시키며, 특히 function matching 방식이 가장 효과적이다.
또한 small dataset에서 fixed teacher를 사용하는 것은 strong overfitting을 초래한다는 것을 training loss가 보여준다.
반면, function matching은 training set에서 과도한 loss에 도달하지 않으면서도 validation set에 훨씬 더 잘 generalizing된다.
distilation을 teacher model이 제공하는 label(잠재적으로 soft label)을 활용하는 supervised learning으로 해석할 수 있다.
특히 teacher의 prediction이 single image view에 대해 (pre)computed된 경우에 해당한다.
이 접근 방식은 standard supervised learning의 모든 문제를 그대로 갖고 있는데,
예를 들어, aggressive augmentation은 실제 image label을 왜곡할 수 있고,
less aggressive augmentation은 overfitting을 초래할 수 있다.
하지만, distillation을 function matching으로 해석하고,
결정적으로 student와 teacher에게 일관된 input을 제공한다면 상황을 달라진다.
이 경우, image augmentation을 매우 공격적으로 수행할 수 있다.
image view가 너무 왜곡되더라도, 여전히 이 input에서 관련된 함수들을 matching하는 방향으로 진행될 수 있다.
따라서 우리는 augmentation에 대해 더욱 적극적일 수 있고, 적극적인 image augmentation을 수행하면서 overfitting을 피할 수 있으며,
student의 함수가 teacher의 함수와 가까워질 때까지 매우 긴 시간 동안 optimize를 진행할 수 있다.
우리는 Figure 4에서 이러한 직관을 실험적으로 확인했다.
각 dataset에 대해, validation 기준으로 best function matching student의 학습 과정에서 test accuracy의 발전을 보여줬다.
중요한 점은, 1M epoch 동안 optimize를 진행해도 overfitting이 발생하지 않는다는 것이다.
또한, 두 가지 baseline model을 추가로 학습하고 조정했다.
(1) 하나는 dataset의 원래 hard label을 사용해 처음부터 ResNet-50을 학습하는 것이고,
(2) 다른 하나는 ImageNet-21k에서 pretrained된 ResNet-50을 transferring하는 것이다.
두 baseline 모두에서 learning rate와 weight decay를 강력하게 tuning했다 (Section 3.1 참고)
원래의 label을 사용해 scratch trained된 model은 우리의 student model보다 성능이 현저히 떨어졌다.
transfer model은 훨씬 더 나은 성능을 보였지만, 결국에도 student model이 더 뛰어난 성능을 보였다.
전반적으로, ResNet-50 student model은 ResNet-152x2 teacher를 꾸준히 그리고 성실하게 matching했다.
지금까지 우리는 student와 teacher가 모두 동일한 standard input resolution인 224px를 받는다고 가정했다.
하지만 여전히 일관성을 유지하면서 student와 teacher에게 다른 resolution의 image를 전달하는 것도 가능하다.
original high-resolution image에서 crop을 수행한 후, student와 teacher에게 각각 다른 resolution으로 image를 resize하면 일관성을 유지하면서도 resolution은 다를 수 있다.
이 insight는 더 나은, 더 높은 resolution의 teacher model로부터 학습하거나, 더 작고 더 빠른 student model을 훈련하는 데 활용될 수 있다.
우리는 두 가지 방향을 모두 조사했다.
우리는 "function matching" 관점에서 distillation recipe의 optimizatino difficulties가 긴 학습 일정으로 인해 computational bottleneck 현상을 일으킨다는 것을 관찰했다.
직관적으로, 우리는 최적화의 어려움이 fixed image-level labels이 아니라
multivariate(다변수) outputs이 있는 일반적인 함수를 맞추는 것이
훨씬 더 어렵기 때문에 발생한다고 생각한다.
이를 위해 기본 optimizer는 Adam에서 Shampoo로 변경했으며, Shampoo는 2차 preconditioner를 사용한다.Figure 5(middle)에서 Shampoo는 Adam이 4800 epoch에서 도달한 test accuracy를 단 1200 epoch만에 달성했으며,
step time overhead도 최소화되었음을 확인했다.