Overcoming Catastrophic Forgetting beyond Continual Learning: Balanced Training for NMT

jihyelee·2023년 2월 6일
0

Overcoming Catastrophic Forgetting beyond Continual Learning: Balanced Training for NMT
ACL 2022

분야 및 배경지식

Neural Machine Translaton (NMT), Catastrophic Forgetting, Knowledge Distillation

  • Catastrophic Forgetting
    • 흔히 continual learning(연속학습)에서 많이 발생하는 문제로, 새로 지식을 배우면서 이전에 배운 지식을 까먹는 현상
  • Knowledge Distillation
    • 사전학습된 선생(teacher) 네트워크로부터 학생(student) 네트워크로 지식을 전이하는 방법
    • 실제 레이블(ground-truth)과 모델의 결과확률값 사이의 cross-entropy loss를 최소화하는 대신에, 선생 네트워크의 예측값을 soft target으로 삼아 loss를 최소화
    • 이를 통해 학생 네트워크가 선생 네트워크의 예측을 모방

문제점

  • Catastrophic Forgetting이 전통적인 기계번역 학습에서도 발생
  • Imbalanced Training
    • 학습 샘플에 불균등한 attention을 부여 (가장 최근에 학습한 데이터에 더 많은 attention)
    • 특히 low-resource NMT에서 심각하게 발생 (태스크 복잡도와 긴밀한 연관성, low-resource dataset일 경우 더욱 심각)
    • 불균형 학습의 원인은 mini-batch gradient descent

해결책

Complementary Online Knowledge Distillation (COKD)

  • checkpoint averaging technique의 경우 checkpoint의 구간을 얼마로 설정해야 할지가 어려움 (경험적 선택)
  • 보완(complementary)을 위한 선생(teacher) 네트워크들을 구축
    • 학습 데이터셋을 상호 배타적인 서브셋으로 구분
    • 학생 네트워크가 학습하지 않은 데이터셋들이 각기 다른 데이터셋 순서를 갖도록 선생 네트워크들을 학습
    • 이를 통해 학생(student) 네트워크가 잊은 지식을 다시 제공
  • 비용함수
    • Knowledge Distillation Loss(word-level)와 Cross-Entropy Loss를 동시에 학습
    • 알파값으로 비율 조정
  • two-way knowledge transfer
    • complementary teachers가 upperbound가 되지 않도록, 선생 모델을 학생 모델의 파라미터로 재초기화하는 방식으로 선생-학습 사이의 지식전이가 지속적으로 발생하도록 함

평가

  • 데이터셋
    • WMT14 En-De(high-resource), WMT17 En-TR, IWSLT15 En-Vi, TED bilingual (low-resource)
  • 평가기준
    • case-sensitive Sacre-BLEU(WMT17 En-TR, IWSLT15 En-Vi), tokenized BLEU(TED bilingual), tokenized BLUE with compound split(WMT14 En-De)
  • 모델
    • joint BPE model
  • 결과
    • low-resource일 경우 COKD에서 3 BLEU score 상승, high-resource일 경우 COKD의 개선 정도가 상대적으로 작음

한계

  • 선생 네트워크의 개수가 많다고 해서 성능이 향상되는 것이 아님 (main improvement is not due to the ensemble of multiple teachers)
    • 해당 논문에서는 위와 같은 연유로 main experiment를 1개의 선생 네트워크를 이용해 진행하였으나, 선생 네트워크 개수와 성능 사이의 상관관계에 대해 충분히 설명하지 않음
  • 비용함수의 알파값을 얼마로 지정하는 것이 가장 적절한지 실험을 통해 heuristic하게 제시하였으나, 해당 데이터셋에 국한되지는 않은지 확인 필요(data-agnostic한지)

의의

  • checkpoint averaging technique의 정확도 개선 이유를 설명
    • chckpoint averaging technique이란 마지막 몇 체크포인트의 평균을 최종 모델로써 활용하는 방식
    • 해당 방식이 잘 작동하는 internal mechanism이 이전에는 충분히 설명되지 않았으나, 이 논문에서는 불균형 학습(imbalanced training)의 경감 덕분일 수 있음을 밝힘 (random noise의 variance를 줄여줌)
  • loss가 지속적으로 감소하다가 마지막에 갑자기 증가하는 일반적이지 않은 현상을 설명
    • adam optimizer는 momentum의 형태로 gradient을 갖고 있고, 해당 momentum은 이후 몇 단계의 gradient update에 영향을 미침. 하지만 마지막 몇 학습 단계에서는 momentum이 완전히 사용되지 않기 때문에 loss가 마지막에 상승한다고 설명
    • 해당 논문에서 이 부분이 중요한 것은 아니나, 자신들이 한 실험이나 결과에 대해 충분한 설명을 제공하는 방식이 유의미하다고 밝힘
  • Continual Learning이 아닌 다른 분야에 Catastrophic forgetting 개념을 적용
  • Imbalaced traning에 대해 knowledge distillation을 사용하여 간단하지만 적절한 방식의 해결책 제시
profile
Graduate student at Seoul National University, majoring in Artificial Intelligence (NLP), AI Researcher at LG CNS AI Lab

0개의 댓글