Decoupled Weight Decay

iissaacc·2022년 3월 4일
0

paper reading

목록 보기
17/20
post-custom-banner

Prologue

L2_2 regularization과 weight decay는 같다.

많은 자료에서 이렇게 소개한다. tensorflow와 pytorch에서도 두 가지를 혼용해서 구현하고 있다. 여기에 그렇지 않다고 주장한 연구자들이 나왔다.

What did the authors try to accomplish?

연구에서는 3가지를 주장하면서 앞서 이야기했던 오해를 바로 잡고 L2_2 regularizaer를 더 잘 쓸 수 있도록 Adam의 알고리즘을 개선했다.

  1. L2_2 regularization과 weight decay는 같지 않다.
  2. L2_2 regularization은 Adam에서 제대로 작동하지 않는다.
  3. Weight decay는 SGD, Adam에서 모두 효과적이다.

먼저 앞으로 나아가기 전에 L2_2 regularization과 weight decay를 간락하게 알아보자.

L2_2 Regularization

Ltreg=f(θt)+λ2θt22gt=Ltreg=f(θt)+λθtOt=optimizer(gt)θt+1=θtαOt\begin{array}{l} L_t^{reg}=f(\theta_t)+\frac{\lambda}{2}||\theta_t||_2^2 \\ g_t=\nabla L_t^{reg}=\nabla f(\theta_t)+\lambda\theta_t\\ O_t=\text{optimizer}(g_t) \\ \theta_{t+1}=\theta_t-\alpha O_t \end{array}

Weight decay

Lt=f(θt)gt=LtOt=optimizer(gt)  θt+1=θtαOtλθt\begin{array}{l} L_t=f(\theta_t) \\ g_t=\nabla L_t \\ O_t=\text{optimizer}(g_t)\ \ \\ \theta_{t+1}=\theta_t-\alpha O_t-\lambda'\theta_t \end{array}

익숙하게 알고 있다시피 L2_2 regularization과 weight decay는 각각 loss function과 optimizer를 조작해서 θ\theta가 0에 가까워지도록 한다. SGDAdam에 이 두 가지를 적용한 식을 보면 문제가 좀더 명확히 드러난다.

What were the key elements of the approach?

SGD

Lt=f(θt)gt=Lt Ot=1gtθt+1=θtαOt \begin{array}{l} L_t=f(\theta_t) \\ g_t=\nabla L_t\quad \ \\ O_t = 1\cdot g_t\\ \theta_{t+1}=\theta_t-\alpha O_t \ \end{array}

SGD with L2_2 regularization

Ltreg=f(θt)+λ2θt22gt=f(θt)+λθt θt+1=θtαf(θt)αλθt \begin{array}{l} L_t^{reg}=f(\theta_t)+\frac{\lambda}{2}||\theta_t||_2^2 \\ g_t=\nabla f(\theta_t)+\lambda\theta_t \ \\ \theta_{t+1}=\theta_t-\alpha\nabla f(\theta_t)-\alpha\lambda\theta_t \ \end{array}

SGD with Weight decay

Lt=f(θt) gt=Ltθt+1=θtαf(θt)λθt\begin{array}{l} L_t=f(\theta_t) \ \\ g_t=\nabla L_t \qquad\\ \theta_{t+1}=\theta_t-\alpha\nabla f(\theta_t)-\lambda'\theta_t \end{array}

L2_2 regularization과 weight decay가 같다고 할 때 αλθt=λθt\alpha\lambda\theta_t=\lambda'\theta_t가 성립한다. λ, λ,α\lambda,\ \lambda', \alpha는 모두 scaler이므로 λ=λα\lambda=\frac{\lambda'}{\alpha}일 때 SGD는 L2_2 regularization과 weight decay가 완전히 같은 모양이 된다. SGD에 한정해서는 참인 명제다. 이런 이유로 pytorch에서는 weight decay를 L2_2 regularization으로, tensorflow에서는 weight decay만 구현하고 있다.

사족을 달자면 L2_2 regularization의 λ\lambdaα\alpha의 영향에서 자유로울 수 없다는 점을 주목해야 한다. 이 말은 L2_2 regularization을 사용하는 SGD에 LR scheduler를 함께 쓴다면 최적의 λ\lambda를 찾았다고 해도 α\alpha가 바뀌면서 λ\lambda는 힘을 잃는다는 말과 같다. adaptive gradient method에 L2_2 regularization과 weight decay를 적용해보면서 연구자의 주장한 3가지의 근거를 알아볼 수 있다.

Adaptive gradient method

Adam을 대표로 하는 adaptive gradient method는 점화식이 복잡해서 α\alpha를 줄여주는 부분은 MtM_t로 단순하게 해서 볼 거다.

Lt=f(θt)gt=Lt Ot=Mtgtθt+1=θtαOt\begin{array}{l} L_t=f(\theta_t) \\ g_t=\nabla L_t\quad \ \\ O_t=M_tg_t\\ \theta_{t+1}=\theta_t-\alpha O_t \end{array}

Adaptive gradient method with L2_2 regularization

Ltreg=f(θt)+λ2θt22gt=f(θt)+λθt Ot=Mt(f(θt)+λθt)θt+1=θtαMtf(θt)αMtλθt\begin{array}{l} L_t^{reg}=f(\theta_t)+\frac{\lambda}{2}||\theta_t||_2^2\\ g_t=\nabla f(\theta_t)+\lambda\theta_t \ \\ O_t=M_t(\nabla f(\theta_t)+\lambda\theta_t) \\ \theta_{t+1}=\theta_t-\alpha M_t\nabla f(\theta_t)-\alpha M_t\lambda\theta_t \end{array}

Adaptive gradient method with Weight decay

Lt=f(θt)gt=LtOt=Mtgtθt+1=θtαMtf(θt)λθt\begin{array}{l} L_t=f(\theta_t) \\ g_t=\nabla L_t \\ O_t=M_tg_t\\ \theta_{t+1}=\theta_t-\alpha M_t\nabla f(\theta_t)-\lambda'\theta_t \end{array}

마찬가지로 L2_2 regularization과 weight decay가 같을 때 Mt, θtM_t,\ \theta_t는 matrix, λ, λ,α\lambda,\ \lambda', \alpha는 모두 scaler이므로 모든 θt\theta_t에 대해 λθt=αMtλθt\lambda'\theta_t=\alpha M_t\lambda\theta_t가 성립해야 한다. 그러면 반드시 Mt=kIM_t=k\text{I}이어야 한다. 그렇지만 optimizer의 정의에 의해 λθt=αMtλθt\lambda'\theta_t=\alpha M_t\lambda\theta_t는 성립할 수 없다. αMtλ\alpha M_t\lambdatt가 변할때마다 바뀌는 반면 λ\lambda'는 고정값이기 때문이다. 이제 연구자들의 주장에 대한 근거가 드러났다.

  1. L2_2 regularization과 weight decay는 같지 않다.
    위와 같은 이유로 대부분의 경우 adaptive gradient method에서는 L2_2 regularization과 weight decay가 반드시 달라야 하므로 MtkIM_t\neq k\text{I}가 성립하고 L2_2 regularization과 weight decay가 같다는 주장은 거짓이다.

  2. L2_2 regularization은 Adam에서 제대로 작동하지 않는다.
    adaptive gradient method는 L2_2 regularization에서 SGD보다 MtM_t만큼 가중치를 준다는 점을 알 수 있는데 이것은 L2_2 regularization를 쓰면서 기대하는 행동이 아니다.

  3. Weight decay는 SGD, Adam에서 모두 효과적이다.
    두 가지 optimizer에서 weight decay λ\lambda'α\alphaMtM_t 두 가지로부터 자유롭다.

그럼에도 불구하고 L2_2 regurlarization과 weight decay가 같은 특수한 경우를 생각해보자. 그러면 L2_2 regularization은 θ\thetas\sqrt{s}로 scale하고 Mt=1sIM_t=\frac{1}{s}\text{I}로 고정해야 한다.

Adaptive gradient method with scaled adjusted L2_2 Regularization

Ltsreg=f(θt)+λ2θs22gt=f(θt)+λθtsθt+1=θtα1sf(θt)α1sλθs=θtα1sf(θt)αλθ\begin{array}{lcl} L_t^{sreg}=f(\theta_t)+\frac{\lambda}{2}||\theta\odot\sqrt s||_2^2\\ g_t=\nabla f(\theta_t)+\lambda\theta_t\odot s\\ \theta_{t+1}=\theta_t-\alpha\frac{1}{s}\nabla f(\theta_t)-\alpha\frac{1}{s}\lambda\theta\odot s\\ \qquad=\theta_t-\alpha\frac{1}{s}\nabla f(\theta_t)-\alpha\lambda\theta \end{array}

Adaptive gradient method with weight decay

Lt=f(θt)gt=Ltθt+1=θtα1sf(θt)λθ\begin{array}{lcl} L_t=f(\theta_t)\\ g_t=\nabla L_t\\ \theta_{t+1}=\theta_t-\alpha\frac{1}{s}\nabla f(\theta_t)-\lambda'\theta \end{array}

이렇게 해야 그나마 SGD처럼 λ=λα\lambda=\frac{\lambda'}{\alpha}가 성립한다. 그렇지만 adaptive gradient method라는 이름에 어울리는 optimizer가 되려면 tt를 업데이트 할 때마다 ss도 함께 바꿔줘야 한다. 여기에서는 θ\thetas\sqrt{s}를 곱하고 있는데 이것은 L2_2 regularization를 사용한 Adam보다 강하게 regularize하는 효과가 있다. 그렇지만 이렇게 해도 SGD가 가진 문제가 여전히 남아 있다.

Decoupling

지금까지의 증명으로 기존의 optimizer에서 구분없이 사용하고 있던 L2_2 regularizaer와 weight decay를 따로 떼놔야 한다고 주장하고 있다.

좀 복잡하니까 weight update부분만 따로 떼서 보면 SGDW에서 L2_2 regularization 대신 weight decay를 사용해서 α\alpha의 영향에서 벗어나서 좀더 안정적으로 generalization할 수 있게 했다.

θtSGD=θt1αft(θt1)αλθt1  =(1αλ)θt1αft(θt1)θtSGDW=θt1ηtαft(θt1)ηtλθt1 =(1ηtλ)θt1ηtαft(θt1)\begin{array}{l} \theta_t^{SGD}=\theta_{t-1}-\alpha\nabla f_t(\theta_{t-1})-\alpha\lambda\theta_{t-1}\\ \ \ \quad\quad=(1-\alpha\lambda)\theta_{t-1}-\alpha\nabla f_t(\theta_{t-1})\\ \\ \theta_t^{SGDW}=\theta_{t-1}-\eta_t\alpha\nabla f_t(\theta_{t-1})-\eta_t\lambda\theta_{t-1}\\ \ \quad\qquad=(1-\eta_t\lambda)\theta_{t-1}-\eta_t\alpha\nabla f_t(\theta_{t-1}) \end{array}

특히 AdamW는 보다 안정적인 generalization을 할 수 있을 뿐만 아니라 더 강하게 regularization하는 효과를 볼 수 있다.

θtAdam=θt1αMtft(θt1)αMtλθt1=(1αMtλ)θt1αMtft(θt1)θtAdamW=θt1ηtαMtft(θt1)ηtλθt1  =(1ηtλ)θt1ηtαMtft(θt1)\begin{array}{l} \theta_t^{Adam}=\theta_{t-1}-\alpha M_t\nabla f_t(\theta_{t-1})- \alpha M_t\lambda\theta_{t-1}\\ \quad\quad\quad=(1- \alpha M_t\lambda)\theta_{t-1}-\alpha M_t\nabla f_t(\theta_{t-1})\\ \\ \theta_t^{AdamW}=\theta_{t-1}-\eta_t\alpha M_t\nabla f_t(\theta_{t-1})-\eta_t\lambda\theta_{t-1}\\ \ \ \quad\qquad=(1-\eta_t\lambda)\theta_{t-1}-\eta_t\alpha M_t\nabla f_t(\theta_{t-1}) \end{array}

Experiment 1

CIFAR-10을 각각의 optimzier로 학습한 해서 test error를 측정한 결과물을 봐도 decouple한 optimzier의 gerneralization성능이 뛰어나다고 알 수 있다.

Experiment 2

첫번째 줄은 L2_2 Regularized Adam, 두번째 줄은 AdamW로 CIFAR-10을 학습한 모델이다. 여기에 첫 행부터 fixed, step, cosine annealing을 LR policy로 했다. 처음에 Adam을 접했을 때 optimizer가 자체적으로 LR을 decay해줘서 LR policy가 오히려 학습을 방해할거라고 생각했다. 연구에서는 오히려 LR policy를 사용하면 넓은 search space와 성능도 함께 가져갈 수 있다는 점이 장점으로 작용한다고 밝혔다.

Epilogue

  1. 나는 tensorflow를 주로 써서 이걸 읽는 내내 weight decay가 optimizer에서 잘 분리돼있는데 왜 자꾸 따로 떼어놓자는 건지 와닿지 않았다. pytorch에서는 weight decay라고 적고 L2_2 regularization으로 구현하고 있어서 이렇게 주장할 수도 있겠다 싶었다.

  2. pytorch를 보면서 문서화의 중요성을 다시 한 번 느꼈다.

Reference

  1. Decoupled Weight Decay Regularization
  2. PyTorch Docs - SGD
post-custom-banner

0개의 댓글