체크포인트를 저장할 때 epoch, state_dict를 저장할 수 있다. 필요에 따라 step도 저장하면 된다.
if i % 100 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_loader)}], Loss: {running_loss / i:.4f}")
# 100 스텝마다 체크포인트 저장
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': running_loss,
'step': i,
}, f".../checkpoint/epoch{epoch}_step{i}.pth")
print(f"Checkpoint saved at epoch {epoch+1}, step {i}.")
하나의 epoch를 다 돌기에 step 수가 너무 많아서 일부러 중간에 저장하게끔 코드를 짰다. 이를 활용하면 다시 학습하게 될 때 step 수를 불러와서 중간부터 시작할 수 있다... 라고 생각을 했는데,
뭔가 잘못된 판단이었나보다.
다시 학습을 시키기 위해서 다음과 같은 코드를 사용해서 배치를 넘기고 중간부터 시작할 수 있게끔 코드를 짰다.
for _ in tqdm(range(start_step)):
next(train_loader_iterator, None)
(물론 GPT의 도움을 받았다.)
그런데 남은 실행시간을 보니;
Checkpoint loaded. Resuming training from epoch 1, step 5200.
7%|▋ | 366/5200 [33:05<6:10:57, 4.60s/it]
이렇게나 오래걸린다는 것이다.
별로 의미 없는 전략이었을 수도 있겠다는 생각이 들었다. 😅