본 논문은 Test-time domain adaptation 연구에서 Target domain이 Static한 상황만을 고려하는 기존 연구의 한계점을 극복하기 위해 Continual learning 방식을 접목한 CoTTA를 제안한 논문이다. CoTTA는 Continual한 학습을 진행하는 과정에서 오류 누적과 Catastrophic forgetting 문제를 완화하기 위해서 여러 방법을 사용한다.
- 지속적으로 변화하는 시나리오에서는 분포 shift로 인해 pseudo label에 영향을 주는 노이즈가 더 심해지고 잘못 보정될 수 있음
- 모델이 지속적으로 새로운 분포에 Adaptation됨에 따라 source domain의 정보를 보존하기 어려워저 Catastrophic forgetting으로 이어질 수 있음
Labeling된 Source domain과 label이 없는 target domain 간의 shift가 발생한 상황에서 adaptation을 수행하는 것이 Unsupervised Domain Adaptation(UDA)라고 함. 최근 연구에서는 UDA를 위해 target pseudo label을 반복적으로 사용하여 네트워크를 학습시키는 self-training도 유망한 결과를 보여줌
Test-time adaptation (TTA)는 Source-free domain adaptation으로도 불리는데, TTA는 adaptation을 위해 source domain data에 access할 필요가 없음. source가 없는 상황에서 adaptation을 수행하기 위해 생성 모델을 활용하여 feature alignment를 수행함
TTA를 위한 또 다른 접근법은 source model을 fine-tuning하는 것임 (e.g. TENT, SHOT).
TENT (Test entropy minimization): Pre-trained model을 사용하여 Entropy minimization을 통해 test data에 adaption함
SHOT (Source hypothesis transfer): 적응을 위해 Entropy minization 뿐만 아니라 Density regularizer를 활용함
또한, Pseudo protytypes를 사용하거나 Bayes 관점으로 분석하는 연구도 존재함
Self-training 접근 방식에서 일반적인 TTA의 목표는 예측값과 pseudo label 간의 Cross-entropy 일관성을 최소화하는 것임. 이러한 접근 방식은 stationary target domain에서는 효과적일 수 있지만, 분포 변화가 발생하는 target data에서는 pseudo label의 품질이 크게 저하될 수 있음
따라서, weight-averaged model이 훈련단계에서 최종 모델보다 종종 더 정확한 성능을 도출한다는 것에서 영감을 받아, 본 논문에서는 weight-averaged tearcher model을 활용하여 pseudo label을 생성함. Time step 에서 pseudo label은 tearch model에 의해 생성되며, 다음과 같음
이후, student model은 student의 예측값과 teacher의 예측값 간의 cross-entropy loss를 통해 업데이트되며 이 업데이트는 student가 teacher의 안정된 pseudo label에 기반하여 학습하도록 유도함. 이를 통해 분포 변화가 있는 환경에서도 모델의 adaptation 능력을 유지하고 Error accumulation을 줄이는 데 도움을 줌
Student model가 업데이트 된 이후, teacher model의 가중치는 student model의 가중치를 기반으로 지수 이동 평균을 통해 업데이트 되며 관련 수식은 아래와 같음
더욱 정확한 pseudo label은 error accumulation을 완화할 수 있지만, self-training을 통한 장기적인 Continuous domain adaptation은 필연적으로 error와 forgetting을 동반함. 더 나아가, 분포 변화가 강한 sample을 처리한 후에도 continuous adaptation의 영향으로 인해 잘못된 예측을 하는 방향으로 모델이 학습될 가능성이 존재함
이러한 문제를 해결하기 위해 CoTTA는 Stochastic restoration 방법을 제안하며, source model로부터 지식을 명시적으로 복원함
이러한 stochastic restoration은 Dropout의 특별한 형태로 간주될 수 잇으며, 훈련 가능한 가중치의 일부 Tensor 요소를 초기 가중치로 확률적으로 복원함으로써, 네트워크는 초기 source model에서 지나치게 달라지는 것을 방지하고, 결과적으로 catastrophic forgetting을 완화할 수 있음
또한, source model의 정보를 보존함으로써 네트워크는 모든 훈련 가능한 파라미터를 학습할 수 있고, 이는 Model collapse를 방지함. 이를 통해 adaptation을 위한 더 큰 model capacity를 제공함
- CIFAR10-to-CIFAR10C (standard)
- CIFAR10-to-CIFAR10C (gradual)
- CIFAR100-to-CIFAR100C
- ImageNet-to-ImageNet-C
- Cityscapses-to-ACDC