[Optimizer] AdamW Optimization

안암동컴맹 · April 12, 2024

AdamW Optimization

Introduction

AdamW is an extension of the Adam optimization algorithm, specifically designed to better handle weight decay in the training of deep learning models. Introduced by Ilya Loshchilov and Frank Hutter in their paper "Decoupled Weight Decay Regularization," it modifies the traditional Adam optimizer by decoupling the weight decay from the optimization steps. This change addresses issues with the original Adam's approach to L2 regularization, which can lead to suboptimal application of weight decay, especially under adaptive learning rate scenarios.

Background and Theory

Adam Basics

The Adam optimizer is a method that computes adaptive learning rates for each parameter. It does so by estimating the first (mean) and second (uncentered variance) moments of the gradients. The original Adam update rule is given by:

$$\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$

where:

  • $\theta_t$ is the parameter vector at time step $t$,
  • $\eta$ is the learning rate,
  • $\hat{m}_t$ and $\hat{v}_t$ are bias-corrected estimates of the first and second moments of the gradients,
  • $\epsilon$ is a small constant added for numerical stability.
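
As a side illustration (not part of the original post), the update rule above can be written as a single NumPy function, assuming the bias-corrected moments $\hat{m}_t$ and $\hat{v}_t$ have already been computed:

```python
import numpy as np

def adam_update(theta, m_hat, v_hat, lr=1e-3, eps=1e-8):
    """One plain-Adam parameter update, given bias-corrected moment estimates."""
    # theta_{t+1} = theta_t - lr / (sqrt(v_hat) + eps) * m_hat
    return theta - lr * m_hat / (np.sqrt(v_hat) + eps)
```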

Decoupled Weight Decay

In traditional L2 regularization, the weight decay is directly incorporated into the update rule as a part of the gradient:

$$g_t \leftarrow g_t + \lambda \theta_t$$

where $\lambda$ is the weight decay coefficient. Adam then uses this adjusted gradient to compute its moment estimates, which mixes the effects of weight decay with the adaptive learning rates.

AdamW modifies this by separating the weight decay:

$$\theta_{t+1} = (\theta_t - \eta \lambda \theta_t) - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$

This ensures that the weight decay is applied directly to the weights, and not through the gradient, allowing the optimizer to preserve the benefits of adaptive learning rates while effectively implementing weight decay.
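
To make the contrast concrete, here is a minimal sketch (my own, not code from the post) of a single update step in NumPy. The function name and the `decoupled` flag are illustrative choices: `decoupled=False` folds the decay into the gradient as in L2-regularized Adam, while `decoupled=True` applies it directly to the weights as AdamW does.

```python
import numpy as np

def adam_style_step(theta, grad, m, v, t, lr=1e-3, lam=1e-2,
                    beta1=0.9, beta2=0.999, eps=1e-8, decoupled=True):
    """One update step: decoupled=False mimics Adam with L2 folded into the
    gradient, decoupled=True mimics the AdamW rule described above."""
    if not decoupled:
        grad = grad + lam * theta                 # L2: decay enters the gradient
    m = beta1 * m + (1 - beta1) * grad            # biased first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2       # biased second-moment estimate
    m_hat = m / (1 - beta1 ** t)                  # bias correction
    v_hat = v / (1 - beta2 ** t)
    if decoupled:
        theta = theta - lr * lam * theta          # AdamW: decay applied to the weights
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v
```

In the coupled case the decay term passes through $\hat{m}_t$ and $\hat{v}_t$ and is therefore rescaled by the adaptive factor $\frac{1}{\sqrt{\hat{v}_t} + \epsilon}$; the decoupled case removes exactly that rescaling.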

Procedural Steps

  1. Initialize Parameters: Set the initial parameters $\theta_0$, learning rate $\eta$, weight decay $\lambda$, $\beta_1$, $\beta_2$ (exponential decay rates for the moment estimates), and $\epsilon$.
  2. Compute Gradients: Calculate the gradient $g_t$ of the loss function with respect to the parameters $\theta_t$ at each iteration $t$.
  3. Update Moment Estimates:
    • Update the biased first moment estimate: $m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$
    • Update the biased second raw moment estimate: $v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$
  4. Correct Bias in Moments:
    • Correct bias in the first moment: $\hat{m}_t = \frac{m_t}{1-\beta_1^t}$
    • Correct bias in the second moment: $\hat{v}_t = \frac{v_t}{1-\beta_2^t}$
  5. Apply Weight Decay: Modify the parameters by applying weight decay directly: $\theta_t \leftarrow \theta_t - \eta \lambda \theta_t$.
  6. Update Parameters:
    • Apply the adaptive learning rate to update the parameters: $\theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$
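
Putting these six steps together, a self-contained sketch (again my own illustration, with an arbitrary toy objective) might look as follows:

```python
import numpy as np

def adamw(grad_fn, theta0, steps=300, lr=0.1, lam=1e-2,
          beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimal AdamW loop following steps 1-6 above."""
    theta = np.asarray(theta0, dtype=float)            # 1. initialize parameters
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    for t in range(1, steps + 1):
        g = grad_fn(theta)                             # 2. compute gradients
        m = beta1 * m + (1 - beta1) * g                # 3. first moment
        v = beta2 * v + (1 - beta2) * g ** 2           #    second moment
        m_hat = m / (1 - beta1 ** t)                   # 4. bias correction
        v_hat = v / (1 - beta2 ** t)
        theta = theta - lr * lam * theta               # 5. decoupled weight decay
        theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)  # 6. Adam update
    return theta

# Toy usage: minimize f(theta) = ||theta - 3||^2, whose gradient is 2 * (theta - 3).
print(adamw(lambda th: 2.0 * (th - 3.0), theta0=np.zeros(4)))
# -> values close to 3, pulled slightly toward 0 by the weight decay
```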

Applications

AdamW is widely used for training deep neural networks where precise control over regularization is crucial, such as convolutional neural networks (CNNs) and recurrent neural networks (RNNs). It is particularly beneficial in tasks where avoiding overfitting while maintaining fast convergence is important.
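
In practice these steps are rarely written by hand; common frameworks ship decoupled weight decay out of the box. Assuming PyTorch is the framework in use, `torch.optim.AdamW` takes the decay coefficient as a direct argument:

```python
import torch

model = torch.nn.Linear(128, 10)                       # any parameterized model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3,
                              betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)

x = torch.randn(32, 128)
y = torch.randint(0, 10, (32,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()        # performs the decoupled weight decay internally
optimizer.zero_grad()
```

Note that passing `weight_decay` to `torch.optim.Adam` instead would apply it as L2 regularization through the gradient, which is precisely the coupled behavior discussed above.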

Strengths and Limitations

Strengths:

  • Improves generalization by preventing overfitting through more effective application of weight decay.
  • Maintains adaptive learning rate benefits, helping in faster convergence particularly in complex networks.

Limitations:

  • Requires careful tuning of hyperparameters such as the learning rate and weight decay coefficient.
  • May still be susceptible to issues common with adaptive learning rate methods, such as poor convergence behavior on some problems.

Advanced Topics

Further exploration into AdamW can involve hybrid approaches that combine it with other regularization techniques, or adapting it to specific types of neural network architectures to optimize performance.

References

  1. Loshchilov, Ilya, and Frank Hutter. "Decoupled Weight Decay Regularization." arXiv preprint arXiv:1711.05101 (2017).