[Lucid] Optimizer와 LR Scheduler 시스템

안암동컴맹·2025년 12월 11일

Lucid Development

목록 보기
10/20
post-thumbnail

⚙️ Optimizer와 LR Scheduler 시스템

Lucid의 Optimizer와 LR Scheduler는 PyTorch를 벤치마킹해 동일한 흐름을 목표로 설계됐다. 핵심은 Optimizernn.Parameter에만 작동하고, LRScheduler가 Optimizer의 param_groups를 동적으로 갱신하는 구조다. 이 글에서는 각 베이스 클래스의 시그니처와 내부 흐름, nn.Module/nn.Parameter와의 상호작용, 그리고 실제 사용 예제를 코드 스니펫과 함께 상세히 풀어본다.


🧭 설계 목표와 역할 분리

  • Optimizer: 파라미터 집합(nn.Parameter)에 대해 step/zero_grad/state_dict/param_groups 관리. 파라미터 이외의 Tensor는 거부한다.
  • LRScheduler: Optimizer의 param_groups에 저장된 학습률을 시점(epoch/step)에 따라 갱신. Optimizer와 독립적으로 저장/로드 가능.
  • 호환성: PyTorch의 API(closure, param_groups, state_dict, verbose 등)와 사용 흐름을 최대한 모사.

Lucid의 Optimizer는 파라미터/버퍼 시스템 위에 얹혀 있으며, nn.Module.parameters()가 반환하는 Parameter만 받아들인다. LRScheduler는 Optimizer를 입력받아 학습률을 조정하되, Optimizer 내부 상태(state_dict)와 별도로 직렬화된다.

🧱 Optimizer 베이스 클래스

class Optimizer(ABC):
    def __init__(self, params: Iterable[nn.Parameter], defaults: dict[str, Any]) -> None: ...

    @abstractmethod
    def step(self, closure: _OptimClosure | None = None) -> Any | None: ...

    def zero_grad(self) -> None: ...

    def param_groups_setup(self, params: list[nn.Parameter], defaults: dict[str, Any]) -> list[dict[str, Any]]: ...

    def add_param_group(self, param_group: dict[str, Any]) -> None: ...

    def state_dict(self) -> dict: ...

    def load_state_dict(self, state_dict: dict) -> None: ...

초기화와 검증

  • params반드시 nn.Parameter 반복자여야 하며, 타입 검증 후 리스트로 보관.
  • defaults는 학습률, weight decay 등 하위 optimizer가 공통으로 쓰는 하이퍼파라미터를 담는다.
  • param_groups: param_groups_setup으로 그룹화(기본은 단일 그룹). 그룹마다 {"params": [...], **defaults} 형태.
  • state: defaultdict(dict)로 파라미터별 상태 저장(모멘텀 버퍼 등), 직렬화 시 인덱스로 매핑.

step/zero_grad

  • step(closure=None): 하위 클래스가 구현. 클로저는 재계산이 필요한 optimizer(SGD with line search 등) 호환용.
  • zero_grad(): 모든 param_group의 Parameter.grad를 0으로 설정. nn.Parameter.zero_grad를 호출해 grad 누적을 초기화.

param_group 관리

  • add_param_group: 새로운 파라미터 세트를 추가. 중복 파라미터가 존재하면 예외를 던져 버그를 방지.
  • 그룹별 하이퍼파라미터를 덮어쓰되, params 키만 별도로 취급해 리스트를 유지한다.

state_dict 직렬화

def state_dict(self) -> dict:
    param_to_idx = {p: i for i, p in enumerate(self._flat_params())}
    packed_state = {
        param_to_idx[p]: copy.deepcopy(st) 
        for p, st in self.state.items() if p in param_to_idx
    }
    packed_groups = [...]
    return {"state": packed_state, "param_groups": packed_groups}
  • 파라미터 객체 참조를 인덱스로 치환해 직렬화. 로드시 현재 파라미터 순서와 매핑해 상태를 복원.
  • param_group도 파라미터 인덱스로 저장해 재구성 가능. PyTorch state_dict 포맷과 유사.

🔗 Optimizer ↔ nn.Module/nn.Parameter

  • Optimizer는 Parameter만 받는다. 버퍼나 일반 Tensor는 타입 에러를 일으킨다.
  • 일반적 흐름:
    1) model = MyModule()
    2) opt = OptimizerSubclass(model.parameters(), lr=...)
    3) loss = criterion(model(x), y)
    4) opt.zero_grad()loss.backward()opt.step()
  • zero_grad()nn.Parameter.zero_grad()를 호출해 grad 필드를 초기화한다.
  • param_groups는 모듈 트리와 무관하게 optimizer 내부에서 관리되므로, 특정 레이어에 다른 하이퍼파라미터를 적용하려면 그룹을 분리해 전달한다.

예시: 커스텀 그룹

params = [
    {"params": model.backbone.parameters(), "lr": 1e-3},
    {"params": model.head.parameters(), "lr": 1e-2, "weight_decay": 1e-4},
]
opt = MyOptimizer(params, defaults={"lr": 1e-3, "weight_decay": 0.0})

defaults는 공통값, 그룹 딕셔너리는 특정 하이퍼파라미터를 덮어쓴다.

🧭 LRScheduler 베이스 클래스

class LRScheduler(ABC):
    def __init__(self, optimizer: Optimizer, last_epoch: int = -1, verbose: bool = False) -> None: ...

    @abstractmethod
    def get_lr(self) -> list[float]: ...

    def step(self, epoch: int | None = None) -> None: ...

    def state_dict(self) -> dict[str, Any]: ...

    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...

    @property
    def last_lr(self) -> list[float]: ...

초기화와 베이스 속성

  • Optimizer 유효성 검증: param_groups를 가져올 수 있어야 한다.
  • base_lrs: 초기 각 param_group의 학습률. 스케줄 계산의 기준.
  • last_epoch: 현재까지의 step/epoch 카운터. 초기값 -1은 step() 호출 시 0부터 시작하도록 함.
  • _last_lr: 직전 step 이후 설정된 학습률 기록.

step과 get_lr

def step(self, epoch=None):
    if epoch is None: 
        self.last_epoch += 1
    else: 
        self.last_epoch = int(epoch)

    self._step_count += 1
    new_lrs = self.get_lr()
    ...
    for group, lr in zip(self.optimizer.param_groups, new_lrs):
        group["lr"] = float(lr)
    self._last_lr = [float(g["lr"]) for g in self.optimizer.param_groups]

    if self.verbose:
        print(f"Epoch {self.last_epoch}: setting learning rates to {self._last_lr}.")
  • 하위 클래스는 get_lr만 구현하면 된다. 반환 리스트 길이는 param_groups와 같아야 한다.
  • epoch 인자를 직접 주면 스케줄을 외부 카운터와 동기화할 수 있다.

state_dict

  • last_epoch, base_lrs, _step_count, _last_lr, _group_count를 저장.
  • 로드시 그룹 수가 다르면 에러를 던져 호환성을 명확히 한다.

🔄 Scheduler ↔ Optimizer 상호작용

  • Scheduler는 Optimizer 내부의 param_groups를 직접 수정해 학습률을 바꾼다. Optimizer는 이를 참조해 step 수행 시 사용.
  • optimizer state_dict와 별개로 스케줄러 state_dict를 저장/로드해야 동일한 학습 곡선을 재현할 수 있다.
  • 일반적 패턴:
    opt = MyOptimizer(model.parameters(), defaults={"lr": 1e-3})
    sched = MyScheduler(opt, ...)
    for epoch in range(num_epochs):
        for batch in data:
            ...
            loss.backward()
            opt.step()
            opt.zero_grad()
        sched.step()  # 또는 sched.step(epoch)
  • 스케줄러를 먼저(step) 적용 후 optimizer step을 호출하는 패턴, 그 반대 패턴 등은 스케줄러 구현에 따라 달라질 수 있지만, Lucid 베이스는 PyTorch와 동일한 인터페이스를 따른다(사용자가 순서를 정한다).

🛠 Optimizer/Scheduler state_dict 사례

  • Optimizer 직렬화: 파라미터를 인덱스로 매핑해 state를 저장. 모멘텀/적응적 통계 등 하위 클래스 상태를 그대로 포함.
  • Scheduler 직렬화: last_epoch와 최근 학습률, 스텝 카운트를 저장. param_group 수가 다르면 로딩 시 에러.
  • 로드 순서: 일반적으로 Optimizer를 먼저 로드하고, 동일 Optimizer 인스턴스를 가진 Scheduler를 로드한다. 예:
    opt.load_state_dict(opt_state)
    sched = MyScheduler(opt, ...)
    sched.load_state_dict(sched_state)
  • 주의: 모델 파라미터 순서가 바뀌면 Optimizer/Scheduler state 복원이 실패하거나 잘못된 파라미터에 상태가 매핑될 수 있다. model.state_dict와 함께 저장/로드 순서를 유지해야 한다.

📚 사용 예제 (베이스 설계 활용)

단일 그룹 Optimizer + 스케줄러

model = MyModel()
opt = MyOptimizer(model.parameters(), defaults={"lr": 1e-3})
sched = MyScheduler(opt, last_epoch=-1)

for epoch in range(10):
    for x, y in loader:
        loss = criterion(model(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    sched.step()  # epoch 단위 스텝

다중 그룹 + 다른 lr/weight_decay

backbone = {"params": model.backbone.parameters(), "lr": 1e-4}
head = {"params": model.head.parameters(), "lr": 1e-3, "weight_decay": 1e-4}
opt = MyOptimizer([backbone, head], defaults={"lr": 1e-3, "weight_decay": 0.0})
sched = MyScheduler(opt, ...)

state_dict 저장/로드

opt_state = opt.state_dict()
sched_state = sched.state_dict()

# ... save to disk ...

opt.load_state_dict(opt_state)
sched.load_state_dict(sched_state)

🔍 설계 노트와 어려움

  1. 파라미터 타입 검증: optimizer가 Tensor를 받으면 grad 추적은 되지만 상태 저장/업데이트가 어긋난다. → 생성 시 nn.Parameter 타입을 강제 검증.
  2. param_group 중복: 같은 파라미터가 여러 그룹에 들어가면 이중 업데이트가 발생. → add_param_group에서 중복 검사 후 예외.
  3. state_dict 매핑: 파라미터 객체 참조는 직렬화 불가. → 인덱스로 매핑하고 로드시 현재 파라미터 순서로 역매핑.
  4. 스케줄러-Optimizer 동기화: param_group 수가 다르면 학습률 리스트 길이가 맞지 않아 런타임 에러. → state_dict 로드 시 그룹 수 검증.
  5. step 순서: 사용자가 스케줄러 step과 optimizer step 순서를 혼동할 수 있음. → PyTorch와 동일한 인터페이스를 유지하고, 문서에서 사용 패턴을 안내.
  6. verbose/logging: 학습률 변경을 추적하려면 로그가 필요. → verbose 플래그로 각 스텝의 lr을 출력하는 옵션 제공.

✅ 정리

Lucid의 Optimizer/LRScheduler 베이스는 PyTorch 호환성을 목표로, Parameter만 다루는 optimizerparam_group 학습률을 갱신하는 스케줄러라는 단순한 원리를 따른다. 파라미터/버퍼 시스템, 모듈 트리, state_dict 직렬화 흐름과 결합해, 학습 루프에서 opt.zero_grad → backward → opt.step → sched.step 이라는 익숙한 패턴을 그대로 사용할 수 있다. 구체적 알고리즘(SGD, Adam, CosineLR 등)은 이 베이스 위에 얹기만 하면 되며, 상태 저장/로드, 그룹 관리, lr 갱신은 베이스가 일관되게 처리한다.

profile
Korea Univ. Computer Science & Engineering

0개의 댓글