[논문리뷰]Overcoming catastrophic forgetting in neural networks (EWC)([PNAS 2017])

이은비·2024년 8월 20일

정리)
old task인 A, new task인 B에 대해서 no penalty인 경우에는 파란선처럼 보이는 것과 같이 B에대해서는 잘 학습하지만 A에 대해서는 잘 학습하지 못하는 것을 볼 수 있고, L2인 경우에는 초록선처럼 보이는 것과 같이 A,B모두 잘 학습하지 못하는 것을 볼 수 있고,EWC의 경우가 A,B에 대해서 잘 학습하는 것을 확인 할 수 있었습니다.이것이 가능한 이유가 loss값에 regularization term을 붙여서 이전에 학습한 작업에서 중요한 가중치들이 새로운 작업을 학습할 때 급격히 변하지 않도록 제약을 두는 방식으로 동작합니다. Loss function의 F,즉 Fisher information matrix는 weight가 얼마나 중요한지를 나타냅니다. Score function의 covariance matrix과 같은 개념으로 피셔 정보 값이 크다는 것은 매개변수
𝜃에 대해 많은 정보를 가지고 있다는 것을 의미합니다. 그림에서 보이는 것과 같이 매개변수의 작은 변화가 데이터에 큰 영향을 끼칠 것을 예측할 수 있고 즉, 관측된 데이터가 매개변수의 변화를 잘 반영하고 있어 해당 매개변수를 더 정확하게 추정할 수 있다는 뜻입니다.그리고 이러한 F값이 크다면 전체 loss function이 최소가 되기 위해서 뒤에 나오는 현재 모델의 파라미터와 이전 작업을 학습한 후 최적화된 파라미터 값의 차의 제곱 term이 최소가 되어야 합니다. 그리고 이와 같은 방법으로 continual learning을 할때 발생할 수 있는 catastropic forgetting을 극복하고자 하였습니다.

EWC (Elastic Weight Consolidation)는 인공 신경망이 연속적인 학습 환경에서 새로운 작업을 학습하면서 이전 작업에 대해 학습한 내용을 잃지 않도록 하는 일종의 지속 학습(Continual Learning) 기법입니다. 이 방법은 특히 "망각 문제(catastrophic forgetting)"를 해결하기 위해 제안되었습니다.

EWC의 기본 개념
EWC는 이전에 학습한 작업에서 중요한 가중치들이 새로운 작업을 학습할 때 급격히 변하지 않도록 제약을 두는 방식으로 동작합니다. 이는 특정 가중치에 대해 이전 작업에서 얼마나 중요한지를 나타내는 정보를 활용하여, 이 정보에 기반해 가중치의 변화를 억제합니다.

EWC의 손실 함수 (Loss Function)
EWC에서 사용되는 손실 함수는 다음과 같은 형태로 표현됩니다:

EWC의 효과
EWC는 새로운 작업을 학습하면서도 이전 작업의 성능을 유지하도록 도와줍니다. 피셔 정보를 활용하여 이전 작업에서 중요했던 가중치를 유지하려는 특성 덕분에, 새로운 정보를 학습할 때 기존 지식을 잃는 문제를 줄일 수 있습니다.

이 기법은 연속 학습 환경에서 인공 신경망의 성능을 안정적으로 유지하는 데 중요한 역할을 합니다. 추가적인 질문이 있다면 언제든지 물어보세요! 여기에서 더 많은 정보를 확인할 수 있습니다.

  1. No Penalty
    문제점: No penalty 접근법은 새로운 task를 학습할 때 이전 task에 대해 학습한 가중치를 보호하지 않습니다. 즉, 네트워크가 새로운 데이터를 학습하면서 이전에 학습한 task의 정보를 잃어버리게 됩니다. 이는 네트워크가 새로운 task에 과적합(overfitting)되거나, 기존 task의 성능이 급격히 떨어지는 catastrophic forgetting 문제를 유발합니다.

결과: 결국, no penalty 방식은 연속적으로 여러 task를 학습할 때 이전에 학습한 task의 성능을 유지하지 못하기 때문에 효과적인 학습이 어렵습니다.

  1. L2 Regularization
    문제점: L2 regularization(릿지 회귀)은 가중치의 크기를 최소화하는 방식으로 네트워크의 복잡도를 줄이고, 일반화 성능을 향상시키기 위한 방법입니다. 그러나 이 방법은 각 task에 대해 동일한 방식으로 가중치를 규제하기 때문에, 특정 task의 성능을 보호하기 위해 가중치를 조정하는 데 한계가 있습니다. 즉, 특정 task에 특화된 중요한 가중치까지도 줄어들 수 있어, 이전 task에서 학습한 중요한 정보를 잃어버릴 수 있습니다.

결과: L2 regularization은 새로운 task를 학습하면서 가중치를 과도하게 제약하여, 네트워크가 기존의 task를 잘 유지하지 못하게 만듭니다. 이는 결국 모든 task에 대해 적절한 성능을 발휘하는 데 실패할 수 있습니다.

EWC:Nuts and Bolts-> L2 regularization ~~Stringent regularization.
EWC ~~>Flexible regularization
출처)
https://ffighting.net/deep-learning-paper-review/incremental-learning/ewc/
https://winnerus.medium.com/ai-%EB%85%BC%EB%AC%B8%EB%A6%AC%EB%B7%B0-continual-learning-on-deep-learning-16969792acc7

2–1. Overcoming catastrophic forgetting in neural networks (Elastic Weight Consolidation)

첫번째 논문은 Deep Mind에서 나온 Overcoming catastrophic forgetting in neural networks 입니다. 주로 이 논문에서 사용한 알고리즘의 이름인 Elastic Weight Consolidation(EWC)라고도 불리기도 하는듯 합니다.

이 알고리즘을 간단히 설명하면, 이전 데이터(A)를 우선 학습합니다. 그리고 딥뉴럴넷에서 이 데이터를 학습+분류하는데 중요한 Weight(뉴런?) 와 중요하지 않은 Weight를 계산하여, 이 중요한 Weight들은 이후 추가되는 새로운 데이터 학습 시에 최대한 변화되지 않도록 하고, 중요하지 않았던 Weight들 위주로 학습하도록 합니다.

Figure2. Loss function from the ‘EWC’ paper [3]
논문에서는 이 방법을 고려한 Loss function을 제안합니다. 이전데이터(A)와 새로운 학습데이터(B)를 학습한다고 할때, F_i는 기존 학습데이터(A)에 대한 Weight들의 중요도를 나타내는 정보를 담고있으며, 세타는 현재 Weight값, 세타_A는 이전 데이터를 학습한 후의 Weight값들 입니다. 이 차이가 커질수록 기존 모델로부터 변화함을 뜻하며, 이 중 중요한 Weight들의 변화가 커질수록 Loss값이 커지도록 되어있습니다. 람다값이 커질수록 이 중요한 Weight값들의 변화를 억제하는 방향으로 학습을 하는 것으로 보입니다.

profile
cs/ce 전공 재학생입니다.

0개의 댓글