[UROP #11] Tailoring Instructions to Student’s Learning Levels Boosts Knowledge Distillation (2)

윤하은·2023년 12월 26일
1

UROP

목록 보기
11/18

Github : https://github.com/twinkle0331/LGTM

Paper : https://arxiv.org/abs/2305.09651

이전글 👉🏻 [논문 읽기 #1] Tailoring Instructions to Student’s Learning Levels Boosts Knowledge Distillation (1)

💡 로 표시된 부분은 제가 이해한 내용을 적은 부분입니다. 오류가 있다면 댓글로 남겨주세요 🙏🏻



5. Experiments



5.1 Experimental Setup


Datasets

GLUE의 텍스트 분류 작업에서 제안된 접근 방식을 평가한다. MRPC, RTE, SST-2, MNLI, QNLI, QQP가 포함된다. MRPC 및 QQP의 경우 F1 및 accuracy를 모두 보고하고 다른 데이터셋에 대해서는 accuracy를 보고한다.


Baselines

  • KD(Hinton et al., 2015)

  • PKD(Sun et al., 2019)

  • SKD(Guo et al., 2022)

  • DIST(Huang et al., 2022)

  • TAKD(Mirzadeh et al., 2020)

  • RCO(Jin et al., 2019)

  • DML(Zhang et al., 2018)

  • ProKT(Shi et al., 2020)

  • PESF-KD(Rao et al., 2022)

  • Meta Distill(Zhou et al., 2022)



Training setup

이전 연구들을 따라가며,

BERTBase를 6-layer BERT 모델로 증류하여 훈련한다. 모든 two-stage baselines에 대해 각 작업에 모델을 fine-tune했다. 공정한 비교를 위해 Meta Distill과 LGTM은 증류 손실의 계산에 검증 세트로부터의 피드백을 활용했다. 자세한 훈련 하이퍼파라미터는 부록 D에서 확인할 수 있다.





5.2 Comparison with Meta Distillation


LGTM은 Meta Distillation과 밀접한 관련이 있기 때문에 먼저 LGTM과 Meta Distill간의 비교를 수행하여 distillation influence를 채택하는 것의 이점을 입증했다.


(a), (b)에서 볼 수 있듯이, 학생 모델의 검증 손실이 나중의 반복에서 점차 증가하는 반면 검증 정확도는 안정적인 평지에 도달할 때까지 지속적으로 향상되었다. 이것은 명백하게 학생 모델이 과적합되고 있다는 것을 나타낸다. 하나의 가능한 설명은 높은 손실을 생성하는 특정 훈련 샘플에 과도하게 중점을 두는 것, 즉 어려운 샘플이나 이상치일 수 있다. 이는 학생 모델의 일반화 능력에 부정적인 영향을 미치며, 과적합으로 이어진다.


  • LGTM : 샘플의 디스틸레이션 영향을 고려하여 일반화 성능에 부정적인 영향을 미치는 샘플을 식별하고 제거함으로써, 학생 모델이 더 나은 성능을 보일 수 있도록 도와준다.

  • Meta Distill : 배치 내의 모든 훈련 샘플을 동등하게 취급하므로 세밀한 조절이 이루어지지 않아, 학습 중에 모델이 부정확한 방향으로 학습될 가능성이 높아질 수 있다.
    

교사 모델은 학생에게 현재 지식을 전달하는 것뿐만 아니라 새로운 정보와 관점을 찾아내어 자신의 이해를 향상시켜야 한다. (c)에서 볼 수 있듯이 LGTM은 교사 보조 손실을 통합함으로써 지식을 효과적으로 전달한다. LGTM의 경우 교사 모델의 검증 정확도가 계속해서 향상되지만, Meta Distill의 경우 빠르게 하락한다.





5.3 Main Results


GLUE 벤치마크의 텍스트 분류 작업의 테스트 세트에서의 결과다. LGTM은 최근 강력한 KD 방법과 10개의 기준선을 모두 능가했다.


더 구체적으로, PKD, SKD, DIST와 같이 정교하게 설계된 훈련 파이프라인이나 손실 함수에 의존하는 모델들과 비교하여 최고 수준의 성능을 달성했다. PKD는 교사 모델의 여러 중간 레이어에서 학생이 증분적인 지식을 추출할 수 있도록 두 가지 디스틸레이션 체계를 제안한다. SKD와 DIST는 두 모델 간의 차이를 줄이기 위해 KL-divergence 손실의 형태를 수정한다. LGTM은 또한 TAKD와 RCO처럼 여러 교사 보조 모델의 시리즈를 필요로하지 않는다.


online distillation 방법과 비교하여, LGTM은 DML, ProKT, PESF-KD보다 우수한 성능을 보인다. 이는 훈련 과정 중에 학생의 피드백을 통합하는 중요성을 강조한다. 훈련 세트로부터의 지식 전달을 지나치게 강조하면 학생이 교사의 출력에 과적합되어 일반화 능력이 감소할 수 있다.



또한, meta distillation 방법과 달리, 개별 훈련 샘플의 증류 영향을 계산할 수 있어 학생의 일반화에 해를 끼칠 수 있는 샘플을 걸러낼 수 있다. 따라서 LGTM은 학생이 과적합 문제를 완화하면서 전반적인 작업에 대한 일반적인 이해를 발전시키도록 도와준다.





5.4 Analysis of Distillation Influence


실제 훈련 과정에서 샘플의 디스틸레이션 영향의 추이를 더 자세히 탐구했다. MRPC 데이터셋에서 실험을 진행했다. 이 작업은 문장 쌍에서 문장이 의미적으로 동등한지 예측하는 것이다.

위의 두 대표적인 샘플을 선택하여 디스틸레이션 영향의 추이와 교사 및 학생 예측 간의 관계를 시각화했다.


왼쪽에서, 훈련 초기 단계에서 교사와 학생이 모두 잘못된 예측을 했다는 것을 볼 수 있다. 이는 이 샘플이 모델 양쪽에 중요한 도전을 제기할 수 있다는 것을 나타낸다. 이 경우, 우리는 학생 모델이 이 샘플에 대한 교사 모델의 출력을 너무 많이 모방하지 않기를 원한다. 교사 모델도 이 샘플에 대해 잘못된 예측을 하고 있기 때문이다.


LGTM은 손실 가중치를 부정적으로 점차 조정할 수 있어, 일단은 이 오해를 일으키는 훈련 샘플을 걸러내고 양쪽 모델이 더 빨리 학습하도록 할 수 있다. 결과적으로, 학생 모델은 먼저 이 궁지에서 벗어날 것이다. 그런 다음 검증 세트에서 학생의 피드백을 통해 교사 모델도 올바른 예측을 하도록 배우게 된다. 마지막으로 훈련이 진행됨에 따라 학생과 교사가 이 샘플을 올바르게 분류할 수 있는 것으로 관찰되며, 디스틸레이션 영향은 거의 제로에 안정화된다.

왜 학생 모델이 먼저 벗어나지.. 🤔


오른쪽에서는 학생과 교사가 특정 샘플을 정확하게 예측할 수 있는 다른 예시를 제시한다. 이는 이 샘플이 교사와 학생 모두에게 너무 쉬운 것일 수 있다는 것을 나타낸다. 이 경우에는 이 샘플에 높은 양의 긍정적 가중치를 주어 학생 친화적인 결정 경계를 형성하고자 한다. 이는 curriculum learning에서 쉬운 샘플부터 어려운 샘플까지 학습하기 위한 것과 유사하다.



또한 MRPC에서 무작위로 선택한 64개의 샘플을 기반으로 한 디스틸레이션 영향의 평균 추이를 시각화했다. distillation influence는 훈련의 초기와 끝에서는 일반적으로 미미하며, 중간에서는 변동이 있다. 이는 LGTM이 훈련 중에 각 샘플에 다양한 가중치를 할당하여 어려운 샘플을 걸러내고 일반화에 더 적합한 샘플에 중점을 두기 위해 적용되기 때문에 합리적인 현상으로 볼 수 있다.





5.5 Ablation Study


Finite difference approximation

4장에서 각 샘플의 증류 영향을 추정하기 위해 유한 차분 근사(FDA)를 소개했다. 이것은 각 샘플의 그라디언트를 계산하는 느린 속도를 해결하기 위해 설계되었다.

MRPC 데이터셋에서 FDA의 유용성을 평가하기 위한 연쇄 실험을 수행했다. FDA를 사용하면 훈련을 완료하는 데 11분만 소요되며, FDA 없이 단순한 훈련은 117분이 소요된다.

훈련 시간에서의 이러한 큰 감소(즉, 10배 이상의 가속)는 제안된 FDA 기술의 계산 효율성을 강조한다. 또한 MRPC 데이터셋의 검증 세트에서의 성능을 평가하고 FDA로 훈련할 경우 F1 스코어가 90.4로, FDA 없이 훈련할 경우 90.7로 나타난다. 근사치로 인한 성능 하락은 매우 작다.



Distillation loss

지식 증류의 맥락에서는 다른 증류 손실이 있다. 위 표는 LGTM이 이러한 목표에 적응할 수 있는지를 평가한 것이다. DIST에서 사용된 수정된 KL-divergence 손실과 일반적인 평균 제곱 오차(MSE)를 고려한다. LGTM은 이러한 증류 목표를 활용하는 원래 방법들을 일관되게 능가하며, 다양한 증류 목표에 대한 호환성을 확인했다.



Student model size

교사 및 학생 모델 간의 용량 차이가 더 큰 시나리오에서 성능을 평가하기 위해 실험을 수행했다. 구체적으로 BERT-Base 모델에서 4-layer BERT 모델로의 지식 증류를 수행했다.


표 4에서 볼 수 있듯이, LGTM은 대부분의 작업에서 다른 기준선을 일관되게 능가했다. SST-2에서 경쟁적인 결과를 제외하고 대체로 우수한 성능을 보인다. 이는 LGTM의 견고성을 나타내며 다양한 지식 증류 설정에서의 폭넓은 사용 가능성을 시사한다.





6. Related Work


지식 증류의 핵심은 지식을 교사에서 학생으로 어떻게 정의하고 전달하는지에 달려 있다. 일반적으로 세 가지 주요 측면이 고려된다:


  1. 지식 전이의 출발점이 되는 교사 모델(학습 대상)

  2. 모델이 훈련되는 데이터(학습 자료)

  3. 학습 목표를 정의하는 목적 함수


지식 디스틸레이션을 더 학생 친화적으로 만들기 위해 이러한 측면의 어려움을 줄이는 데 관한 노력들이 있었다.




학습 대상 측면에서,

  • 교사와 학생 모델 간의 격차를 줄이기 위해 중간 타임스텝이나 훈련 시간 타임스텝의 교사 보조 모델을 도입한다.

  • 교사와 학생을 함께 업데이트하여 교사가 학생의 상태를 인식하도록 제안한다.

  • 더 많은 타임스텝 동안 훈련하여 교사의 분포를 부드럽게 만들어 전송을 쉽게 만든다.



학습 자료 측면에서,

  • 훈련 데이터를 다양하게 만들기 위해 데이터를 보강한다.

  • 교사에게는 쉬운 샘플이지만 학생에게는 어려운 샘플로 학생을 훈련시킨다.



학습 목표 측면에서,

  • 가장 일반적인 접근 방식은 교사 및 학생 모델의 확률적 예측 점수를 KL-divergence를 사용하여 일치시키는 것이다. 그러나 이는 훈련 중에 문제를 일으켜 성능이 떨어지게 할 수 있다.

  • 보다 허용 범위가 넓은 손실을 사용하여 제약을 완화한다.

  • 교사 모델의 최적화 목표로 학생의 성능을 사용하여 교사가 학생의 피드백을 기반으로 지식 전이를 최적화할 수 있도록한다.

  • 적절한 지식을 선택하여 학생의 최적화를 안내한다.





7. Conclusion


전반적으로, LGTM은 교사 모델이 학생 모델의 능력에 적응하고 더 개인화된 지도를 제공하면서 학생 모델의 일반화 능력을 향상시킬 수 있었다.


지식 증류에서 여러 가지 학습 방법론을 재검토한 후, 각 훈련 샘플로부터 지식을 증류하는 것이 학생의 일반화 능력에 어떤 영향을 미치는지 결정하기 위해 distillation influence를 제안했다. 증류 영향을 사용한 간단한 가중치 조정이 학생 훈련에 도움이 되는 것을 확인할 수 있다.


증류 영향을 기반으로 하여, learning to teach 프레임워크인 LGTM을 제안했다. LGTM은 GLUE 벤치마크의 텍스트 분류 작업에서 기존의 지식 증류 방법보다 일관되게 우수한 성능을 보인다.





7.1 Limitations


  • LGTM은 테스크별 지식 증류에서 우수한 성능을 보여주었지만, LGTM을 pre-training KD와 결합하는 잠재적 이점을 조사하는 것이 가치가 있다.

  • 사전 훈련된 언어 모델에 대해 상대적으로 간단한 텍스트 분류 작업으로 제한되었지만, 향후 연구에서는 LGTM을 더 복잡한 텍스트 생성 작업에 적용하는 것을 탐구해 볼 수 있다.





0개의 댓글