• 필자는 학습 시 보통 Adam을 사용하였는데, 최근 많은 SOTA 모델에서 최적화(Optimizer) 알고리즘으로 AdamW을 사용한다.

  • 그래서 오늘은 Adam과 AdamW의 차이에 대해 알아보고자 한다!


AdamW

AdamW(Adam Weight decay)란?

  • Adaptive Moment Estimation인 Adam을 개선한 옵티마이저로, 가중치 감소(Weight decay)를 처리하는 방식에서 차이

가중치 감소

  • 직관적으로 파라미터 업데이트 과정에서 현재 가중치에 비례하여 빠지는 항을 의미
    Wnew=Wold>(ηλ>Wold)감소(Decay) 항W_{\text{new}} = W_{\text{old}} - >\dots - \underbrace{(\eta \cdot \lambda >\cdot W_{\text{old}})}_{\text{감소(Decay) 항}}
  • 학습률(η)과 강도(λ)는 항상 양수이기 때문에 W_old가 양수던 음수던 이전보다 0에 가깝게 업데이트 하여 과적합을 방지
  • Adam과의 차이점은 L2 정규화 처리 방식에 있다.

    • Adam

      gtAdam(L2)=Loriginal(Wt1)+λWt1\mathbf{g}_t^{\text{Adam}(\mathbf{L}_2)} = \nabla L_{\text{original}}(W_{t-1}) + \lambda W_{t-1}

      • Adam은 L2 정규화 항(λW)이 backpropagation 시 기울기 계산에 포함되어 갱신할 때 원치 않게 스케일링 되는 문제 발생

    • AdamW

      gtAdamW=Loriginal(Wt1)\mathbf{g}_t^{\text{AdamW}} = \nabla L_{\text{original}}(W_{t-1})

      • 가중치 감소 항이 backpropagation 시 기울기 계산에서 분리(decoupled)

  • 그렇다면, Adam과 같이 기울기 갱신에 가중치 감소 항이 포함되면 왜 안 좋을까?

    • scaling factor에 포함된 v가 매 업데이트마다 달라지고, 가중치 감소 항도 영향을 받기 때문에
      • v가 커지면 -> 가중치 감소 효과 약해짐
      • v가 작아지면 -> 가중치 감소 효과 커짐

    • 이는 어떤 파라미터는 강한 제약 또 다른 파라미터는 약한 제약을 받게되어 불균형 상태가 발생하기 때문에 예측 능력(일반화 성능)이 저하되는 문제를 초래



추가정보

  • 위에서 언급한 v는 scailing factor를 이루는 주요 값으로, 2차 모멘트라고 불린다.
    • V가 왜 2차 모멘트?
      • 모멘트는 확률 변수의 모양과 특성을 나타내는 척도로 사용
      • Adam, AdamW는 기울기(gradient)를 확률 변수로 간주하여 과거 기울기 제곱의 기댓값을 추정하기 때문에 2차 모멘트라고도 불림(엄밀히 말하면 $\mu$가 0이기 때문에 2차 원점 모멘트)

2차 모멘트(v)

  • V는 이전 업데이트에 사용한 gradient를 누적 제곱합으로 계산된다.

  • 이 개념이 처음 소개되었던 Adagrad에서는 과거 모든 기울기를 단순 누적 제곱합으로 사용하였으나,

  • Adagrad, RMSprop, Adam, AdamW로 새로운 옵티마이저가 등장할 때마다 이를 계산하는 방법에 변화가 생김

    • 여기서 등장한 계산 방법이 지수이동평균(Exponentially Weighted Moving Average)

지수이동평균(EWMA)

  • Adam, AdamW에서 2차 모멘트인 V를 계산하는 데 사용하는 방법

  • 이동 평균이란,

    • 시계열 데이터에서 추세를 확인하기 위해 사용하며, 핵심은 시간이 지남에 따라 가중치를 어떻게 부여?에 따라 달라짐
      • 단순 이동 평균 (모두 동일한 가중치)
      • 지수 이동 평균 (과거일수록 적은 가중치)
  • EWMA는 β2\beta_2라는 감쇠율을 사용하여 비율 반영 누적을 수행한다.

  • EWMA 수식

    vt=β2vt1+(1β2)gt2v_t = \beta_2\cdot v_{t-1} + (1-\beta_2) * g_t^2
    • 감쇠율은 0.999(default)이며, 튜닝이 가능하다.

    • 업데이트 반복할수록, 현재 step(=배치) 학습률 결정에 영향을 미치는 이전 기울기 정보는 거듭제곱만큼 망각. 즉, 지수적으로(Exponentially) 빠르게 작아진다.

    • 감쇠율이 아무리 1에 가까워도 1보다 작다면 망각 속도 차이는 있으나 망각은 유지됨

      vt(1β2)gt2가장 큰 가중치+(1β2)β2gt12두 번째 큰 가중치++(1β2)β2t1g12가장 작은 가중치v_{t} \approx \underbrace{(1-\beta_2)g_t^2}_{\text{가장 큰 가중치}} + \underbrace{(1-\beta_2)\beta_2 g_{t-1}^2}_{\text{두 번째 큰 가중치}} + \dots + \underbrace{(1-\beta_2)\beta_2^{t-1} g_1^2}_{\text{가장 작은 가중치}}

      • 위 수식처럼 chain 반응이 일어나, 현재(t) 2차 모멘트에 반영되고 이는 하이퍼파라미터인 학습률에 반영되어 학습 속도를 결정한다.

        η다음 step 학습률=η학습률(초기 고정 값)×1vi+ϵ스케일링 요소\underbrace{\eta}_{\text{다음 step 학습률}} = \underbrace{\eta}_{\text{학습률(초기 고정 값)}} \times \underbrace{\frac{1}{\sqrt{v_i} + \epsilon}}_{\text{스케일링 요소}}
        • v가 크다면 -> 학습 속도 느려짐
        • v가 작다면 -> 학습 속도 빨라짐

최종 파라미터 갱신 방법

  • Adam

    θt+1=θt(ηV^t+ϵ)다음 step 학습률M^t방향 (Momentum)\theta_{t+1} = \theta_t - \underbrace{\left(\frac{\eta}{\sqrt{\hat{V}_t} + \epsilon}\right)}_{\text{다음 step 학습률}} \cdot \underbrace{\hat{M}_t}_{\text{방향 (Momentum)}}

  • AdamW

    θt+1=θtληθt가중치 감소항(ηV^t+ϵ)M^tAdam 업데이트 항\theta_{t+1} = \theta_t - \underbrace{\lambda \eta \theta_t}_{\text{가중치 감소항}} - \underbrace{\left(\frac{\eta}{\sqrt{\hat{V}_t} + \epsilon}\right) \cdot \hat{M}_t}_{\text{Adam 업데이트 항}}

    • 가중치 감소항 분리됨(decoupled)

  • M은 gradient의 1차 원점 모멘트로, 관성을 더해 진동 방지, Local Minima 탈출에 기여

  • 다시 강조하자면 M, V 계산에 gradient 포함되어 Adam은 가중치 감소 항 스케일링 영향이 발생, AdamW는 스케일링 영향에 자유롭다.

  • 학습 시 weight_decay를 부여하지 않으면(= 0) Adam 업데이트 결과 = AdamW 업데이트 결과


profile
Data Scientist & Data Analyst

0개의 댓글