특히 큰 모델을 학습할 때일수록 Learning rate가 중요해진다. Learning rate를 키웠다가 (Warmup) 다시 줄이는 (Decay) 하는 방식은 모델이 Local minima 에 빠지지 않도록 하는 데 굉장한 도움이 된다.
이에 Pytorch 코드를 첨부하여 Learning rate warmup & Decay를 한번에 관장할 수 있는 프레임워크를 정리하고자 한다.
class WarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer: torch.optim, warmup_steps: int, base_lr: float, last_epoch: int = -1):
self.warmup_steps = warmup_steps
self.base_lr = base_lr
super(WarmupScheduler, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch < self.warmup_steps:
# Linearly increase the learning rate
return [self.base_lr * (self.last_epoch + 1) / self.warmup_steps for _ in self.optimizer.param_groups]
else:
# Use the base learning rate after warmup
return [self.base_lr for _ in self.optimizer.param_groups]
scheduler_g_warmup = WarmupScheduler(optim_g, warmup_steps=5, base_lr=cfg.train["lr"])
이렇게 하면 정해진 epoch까지, 그리고 initial learning rate까지 0에서부터 천천히 Linear 하게 learning rate를 증가시킬 수 있는 스케줄러를 만들 수 있다.
optim_g = torch.optim.AdamW(generator.parameters(),
cfg.train["lr"], betas=[cfg.train["beta1"], cfg.train["beta2"]])
scheduler_g_decay = torch.optim.lr_scheduler.ExponentialLR(optim_g,
gamma=cfg.train["weight_decay"], last_epoch=last_epoch)
원하는 Optimizer를 설정하고, Decay scheduler 도 설정한다. 예시에서는 decay scheduler가 exponential decay이다.
if epoch < warmup_epoch:
scheduler_g_warmup.step()
else:
scheduler_g_decay.step()
마지막으로 epoch 별로 scheduler를 업데이트 할 때 미리 정의한 warmup epoch 과 비교하여 어떤 스케줄러로 learning rate를 설정할 것인지 정의하면 끝난다.
print("Epoch: {} Learning Rate: {}".format(epoch + 1, optim_g.param_groups[0]["lr"]))
이런식으로 각 epoch 마다 learning rate가 얼마인지 확인해 볼 수도 있다.