state_dict는 간단히 말해 각 계층 매개변수 텐서로 매핑되는 파이썬 사전 객체다. 모델의 학습 가능한 매개변수들은 모델의 매개변수에 포함되어 있다. state_dict는 간단히 말해 각 계층을 매개변수 텐서로 매핑되는 dict 객체이다. 이 때, 학습 가능한 매개변수를 갖는 계층 및 등록된 버퍼들만이 모델의 state_dict 항목을 가진다.
torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
모델의 state_dict를 저장하는 것이 나주에 모델을 사용할 때 가장 유연하게 사용할 수 있는 모델 저장 시 권장하는 방법이다.
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
일반 체크포인트를 저장할 때는 반드시 모델의 state_dict보다 많은 것을 저장해야 한다. state_dict만을 저장하는 것보다 파일의 크기 자체는 2~3배 정도 커진다.