[Lucid] SGD와 Adam 옵티마이저의 구현

안암동컴맹·2025년 12월 12일

Lucid Development

목록 보기
14/20
post-thumbnail

💫 SGD와 Adam 옵티마이저 구현

이번 개발 일지는 Lucid가 제공하는 SGD 계열Adam 계열 옵티마이저를 대표로 삼아, 왜 이 두 축이 실전 학습에서 기본 선택이 되는지, 그리고 수식이 코드에 어떻게 녹아들어 있는지를 집중적으로 다룬다. 둘 다 PyTorch와 호환되는 사용성을 유지하면서, Lucid 특유의 간결한 파라미터 그룹/상태 직렬화/디바이스 가드를 보여준다.


❇️ SGD와 Adam을 선택한 이유

SGD는 가장 단순한 1차 방법이면서도 모멘텀·가중 감쇠를 더해도 구현이 투명해 “베이스라인”을 제시한다. Adam은 1차/2차 모멘텀을 함께 추적해 학습률을 자동 스케일링하는 현대적 기본값이다. 두 방법의 대비(단순 선형 업데이트 vs. 적응형 모멘텀)는 Lucid 옵티마이저 설계가 얼마나 PyTorch에 가깝고 동시에 간결한지를 설명하기에 적절하다.

🧩 Optimizer 베이스 – 파라미터 그룹과 state_dict

모든 옵티마이저는 Optimizer 베이스를 따른다(lucid/optim/_base.py). 파라미터 그룹은 dict 리스트로 관리되고, state_dict는 파라미터 인덱스 매핑으로 상태를 직렬화한다. 파라미터 집합 {pi}i=1N\{p_i\}_{i=1}^N에 대해 상태 SiS_i를 인덱스 ii에 매핑해 저장한다고 보면 된다:

state_dict={iSipiparams}.\text{state\_dict} = \left\{\, i \mapsto S_i \mid p_i \in \text{params} \,\right\}.

코드에서는 _flat_params()로 일렬화한 뒤 param_to_idx[p] = i로 매핑을 만든다. 이 방식 덕분에 파라미터 순서가 바뀌어도 상태를 안정적으로 복원할 수 있고, 여러 그룹의 하이퍼파라미터가 뒤섞이지 않는다.

⚡ SGD – 모멘텀·가중 감쇠 수식과 구현

고전 모멘텀 SGD는 다음으로 요약된다:

vt=μvt1+L(wt),wt+1=wtη(vt+λwt),\begin{aligned} v_t &= \mu\, v_{t-1} + \nabla L(w_t),\\ w_{t+1} &= w_t - \eta\,(v_t + \lambda w_t), \end{aligned}

여기서 μ\mu는 모멘텀, η\eta는 학습률, λ\lambda는 weight decay(L2 penalty)다. Lucid 구현(lucid/optim/sgd.py)은 이 수식을 그대로 옮긴다.

  • grad = Tensor.copy_grad(param.grad)L(wt)\nabla L(w_t) 복사.
  • grad = grad + weight_decay * param.dataλwt\lambda w_t 항.
  • 모멘텀 버퍼 state["momentum_buffer"]vtv_t에 해당하며, 마지막에 param.data = param.data - lr * gradwt+1w_{t+1}을 만든다.
for param in group["params"]:
    grad = Tensor.copy_grad(param.grad)
    if weight_decay != 0:
        grad = grad + weight_decay * param.data

    if momentum != 0:
        buf = state.get("momentum_buffer", Tensor.copy_grad(param.grad))
        buf = momentum * buf + grad
        state["momentum_buffer"] = buf
        grad = buf

    param.data = param.data - lr * grad

weight decay는 decoupled 방식이 아니라 L2 penalty로 grad에 더하는 전통적 형태다. 모멘텀 버퍼는 파라미터별 state에 저장되어 학습 중간 저장/재개 시 함께 직렬화된다.

🔬 Adam – 적응형 모멘텀, bias correction, AMSGrad

Adam의 핵심 수식은 다음과 같다:

mt=β1mt1+(1β1)gt,vt=β2vt1+(1β2)gt2,m^t=mt1β1t,v^t=vt1β2t,wt+1=wtηm^tv^t+ϵ.\begin{aligned} m_t &= \beta_1 m_{t-1} + (1-\beta_1) g_t,\\ v_t &= \beta_2 v_{t-1} + (1-\beta_2) g_t^2,\\ \hat{m}_t &= \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t},\\ w_{t+1} &= w_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}. \end{aligned}

Lucid 구현(lucid/optim/adam.py)은 이 수식을 그대로 옮긴다. 주요 코드와 항 대응은 아래와 같다.

exp_avg = state["exp_avg"]          # m_t
exp_avg_sq = state["exp_avg_sq"]    # v_t

exp_avg[:] = beta1 * exp_avg + (1 - beta1) * grad          # m_t 업데이트
exp_avg_sq[:] = beta2 * exp_avg_sq + (1 - beta2) * (grad**2)  # v_t 업데이트

bias_correct1 = 1 - beta1 ** state["step"]   # 1 - beta1^t
bias_correct2 = 1 - beta2 ** state["step"]   # 1 - beta2^t
step_size = lr * (bias_correct2**0.5) / bias_correct1      # η 보정

denom = lucid.sqrt(exp_avg_sq) + eps         # √(v̂_t) + ε
param.data -= step_size * (exp_avg / denom.data)  # w_{t+1} 업데이트
  • exp_avg/exp_avg_sq는 각각 mtm_t, vtv_t에 대응하며, Tensor.copy_grad로 가져온 grad를 사용해 autograd 그래프에서 분리한다.
  • bias_correct1/2m^t,v^t\hat{m}_t, \hat{v}_t의 분모 역할을 학습률 쪽에 녹여 수치 안정성을 높인다.
  • denom.data로 분모를 계산하는 이유는 역그래프 오염을 피하기 위해서다.
  • 최종 항이 바로 wt+1=wtηm^t/(v^t+ϵ)w_{t+1} = w_t - \eta \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)을 구현한다. 수식→코드 매핑이 눈에 보이도록 구성되어 디버깅이 쉽다.

AMSGrad 변형

amsgrad=True면 분모를 v^t\hat{v}_t 대신 maxitvi\max_{i \le t} v_i로 바꿔 단조 증가하는 분산 추정치를 사용한다. 코드에서는 max_exp_avg_sq = lucid.maximum(max_exp_avg_sq, exp_avg_sq)lucid.sqrt(max_exp_avg_sq) + eps를 분모로 둔다.

🌿 Adam 파생형 – AdamW, NAdam, RAdam

AdamW

파일: lucid/optim/adam.py — weight decay를 grad에 더하지 않고 파라미터에 직접 적용한다. 수식은

wt+1=(1ηλ)wtηm^tv^t+ϵ,w_{t+1} = (1 - \eta \lambda)\, w_t - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon},

이며 코드에서는 먼저 decoupled decay를 적용한 뒤 Adam 업데이트를 동일하게 수행한다.

if weight_decay != 0.0:
    param.data -= lr * weight_decay * param.data      # (1 - ηλ) w_t

exp_avg[:] = beta1 * exp_avg + (1 - beta1) * grad     # m_t
exp_avg_sq[:] = beta2 * exp_avg_sq + (1 - beta2) * (grad**2)  # v_t

denom = lucid.sqrt(exp_avg_sq) + eps                  # √(v̂_t)+ε
step_size = lr * (bias_correct2**0.5) / bias_correct1

param.data -= step_size * (exp_avg / denom.data)      # 남은 Adam 업데이트

decoupled decay가 grad에 섞이지 않아 모멘텀 추정치가 깨지지 않고, L2 규제가 더 직관적으로 적용된다.

NAdam

파일: lucid/optim/adam.py — Nesterov lookahead를 추가한다. 업데이트는

wt+1=wtη(β1m^t+(1β1)gt1β1t)/(v^t+ϵ),w_{t+1} = w_t - \eta \left(\beta_1 \hat{m}_t + (1-\beta_1) \frac{g_t}{1-\beta_1^t}\right) / (\sqrt{\hat{v}_t} + \epsilon),

이고 코드에서는 lookahead 항을 명시적으로 더한다.

exp_avg[:] = beta1 * exp_avg + (1 - beta1) * grad      # m_t
exp_avg_sq[:] = beta2 * exp_avg_sq + (1 - beta2) * (grad**2)  # v_t
lookahead_term = (1 - beta1) / (1 - beta1**step) * grad       # (1-β1)/(1-β1^t) g_t

bias_correct1 = 1 - beta1**step
bias_correct2 = 1 - beta2**step

m_t_hat = exp_avg / bias_correct1
v_t_hat = exp_avg_sq / bias_correct2

step_size = lr * (bias_correct2**0.5) / bias_correct1
param.data -= step_size * (m_t_hat * beta1 + lookahead_term) / (lucid.sqrt(v_t_hat) + eps).data

m_t_hat * beta1 + lookahead_term가 수식의 두 항을 그대로 반영한다. Nesterov 효과로 실제 방향을 약간 앞서 잡아준다.

RAdam

파일: lucid/optim/adam.py — 초반 분산 추정이 불안정할 때 학습률을 축소한다. ρt\rho_trtr_t를 계산해

step_size=ηrt/(v^t+ϵ),step\_size = \eta \cdot r_t / (\sqrt{\hat{v}_t} + \epsilon),

을 사용하며, ρt4\rho_t \le 4이면 모멘텀 보정 없이 η\eta만 적용한다. 이는 작은 배치나 워밍업 초기의 폭주를 줄이는 목적이다.

exp_avg[:] = beta1 * exp_avg + (1 - beta1) * grad
exp_avg_sq[:] = beta2 * exp_avg_sq + (1 - beta2) * (grad**2)

rho_inf = 2 / (1 - beta2) - 1
rho_t = rho_inf - 2 * step * beta2**step / (1 - beta2**step)

if rho_t > 4:
    r_t = (((rho_t - 4) * (rho_t - 2) * rho_inf) / ((rho_inf - 4) * (rho_inf - 2))) ** 0.5
    v_t_hat = exp_avg_sq / bias_correct2
    step_size = lr * r_t / (lucid.sqrt(v_t_hat) + eps).data

else:
    step_size = lr

param.data -= step_size * (exp_avg / bias_correct1)

여기서 rho_t > 4 조건이 충분한 분산 추정치가 모였는지를 검사하고, r_t가 학습률을 축소/확대한다. 초기 단계에서는 plain Adam보다 더 조심스럽게 이동한다.

📦 상태 직렬화와 재현성

SGD/Adam 모두 state_dict를 통해 모멘텀 버퍼와 1·2차 모멘텀 텐서를 인덱스 기반으로 저장한다(lucid/optim/_base.py). 저장/복원 시 파라미터 순서가 바뀌어도 인덱스 재매핑 덕분에 안전하다. 재현성을 위해서는 모델을 .to(device)로 먼저 이동한 뒤 옵티마이저를 생성하고, {"model": model.state_dict(), "optim": optim.state_dict()} 형태로 함께 저장/로드하는 관례를 지키면 된다.

🛰️ Gradient 흐름과 디바이스

옵티마이저는 param.grad를 읽어 업데이트한다. Tensor.backward는 grad를 Tensor.grad에 기록하고, Tensor.copy_grad로 복사해 autograd 그래프를 오염시키지 않는다. 모델·데이터·옵티마이저를 같은 디바이스에 두지 않으면 디바이스 가드에서 실패하므로, GPU 학습 시 모델을 먼저 GPU로 이동한 뒤 옵티마이저를 생성하는 순서를 권장한다. dtype은 Numeric 체계를 따르며, weight decay나 모멘트 업데이트도 Tensor 연산을 통해 장치별 경로가 자동 선택된다.

🧾 짧은 사용 예 – SGD와 Adam 대비

model = MyNet().to("gpu")
opt_sgd = lucid.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=5e-4)
opt_adam = lucid.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4, amsgrad=True)

for inputs, targets in loader:
    inputs, targets = inputs.to("gpu"), targets.to("gpu")
    opt_sgd.zero_grad()

    out = model(inputs)
    loss = F.cross_entropy(out, targets)
    loss.eval()

    loss.backward()
    opt_sgd.step()

Adam도 같은 패턴으로 호출하며, amsgradbetas 조정으로 수렴 특성을 바꿀 수 있다. 파라미터 그룹을 사용하면 부분 모듈에 다른 학습률/감쇠를 부여하는 것도 동일하게 지원한다.

핵심은:

  1. 모델을 먼저 원하는 디바이스로 이동,
  2. 옵티마이저 생성,
  3. loss.eval()(MLX lazy) 후 backward → step 순서를 지키는 것.

🧠 정리

  • SGD: 모멘텀과 L2 penalty를 grad에 더하는 전통적 형태, 단순하면서도 강력한 베이스라인. 모멘텀 버퍼는 state에 저장되어 직렬화된다.
  • Adam: 1차/2차 모멘텀, bias correction, eps 안정화까지 수식 그대로 구현. exp_avg/exp_avg_sq가 핵심 상태.
  • AdamW/NAdam/RAdam: weight decay decoupling, Nesterov lookahead, 분산 스케일링 등 실전에서 자주 쓰이는 변형을 코드로 명료하게 분리. 각 변형이 어떤 수식 항을 추가/수정했는지 코드 스니펫으로 확인 가능.
  • state_dict: 파라미터 인덱스로 상태를 직렬화해 저장/로드 시 순서 변경에도 안전. 모델을 먼저 .to(device)로 옮긴 뒤 옵티마이저를 생성하는 관례가 필수.
  • 디바이스/grad 가드: Tensor.copy_grad로 autograd 그래프 분리, 디바이스 불일치 시 초기 에러, MLX 경로에서는 loss.eval()로 lazy 실행을 수동 materialize.

이 흐름만 지키면 Lucid에서도 PyTorch와 동일한 학습 루프를 최소 수정으로 구현할 수 있다. 원하는 변형을 선택해도 공통의 파라미터 그룹/직렬화/디바이스 가드는 그대로 재사용된다.

profile
Korea Univ. Computer Science & Engineering

0개의 댓글