[torch] checkpoint 저장 및 재학습

영이·2024년 8월 28일

pytorch

목록 보기
3/3

체크포인트를 저장할 때 epoch, state_dict를 저장할 수 있다. 필요에 따라 step도 저장하면 된다.

if i % 100 == 0:
	print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_loader)}], Loss: {running_loss / i:.4f}")

	# 100 스텝마다 체크포인트 저장
	torch.save({
				'epoch': epoch,
        		'model_state_dict': model.state_dict(),
        		'optimizer_state_dict': optimizer.state_dict(),
        		'loss': running_loss,
        		'step': i,
        		}, f".../checkpoint/epoch{epoch}_step{i}.pth")
	print(f"Checkpoint saved at epoch {epoch+1}, step {i}.")

하나의 epoch를 다 돌기에 step 수가 너무 많아서 일부러 중간에 저장하게끔 코드를 짰다. 이를 활용하면 다시 학습하게 될 때 step 수를 불러와서 중간부터 시작할 수 있다... 라고 생각을 했는데,
뭔가 잘못된 판단이었나보다.

다시 학습을 시키기 위해서 다음과 같은 코드를 사용해서 배치를 넘기고 중간부터 시작할 수 있게끔 코드를 짰다.

for _ in tqdm(range(start_step)):
	next(train_loader_iterator, None)

(물론 GPT의 도움을 받았다.)
그런데 남은 실행시간을 보니;

Checkpoint loaded. Resuming training from epoch 1, step 5200.
  7%|▋         | 366/5200 [33:05<6:10:57,  4.60s/it]

이렇게나 오래걸린다는 것이다.
별로 의미 없는 전략이었을 수도 있겠다는 생각이 들었다. 😅

profile
연구가 싫었는데 어쩌다보니 대학원생이 되어버린 몸

0개의 댓글