1. state_dict
model과 optimizer의 학습 가능한 파라미터·상태 정보를 사전(dict) 형태로 담아 저장
import torch


class SomeModel(torch.nn.Module):
    """Minimal example model used to demonstrate state_dict()."""

    def __init__(self):
        super().__init__()
        # At least one parameterized layer so state_dict() is non-empty.
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


model = SomeModel()
# lr must be a concrete value — the original snippet passed an undefined `lr`.
optim = torch.optim.SGD(model.parameters(), lr=0.01)

# state_dict(): a dict mapping each layer to its parameter tensors;
# for the optimizer, its hyperparameters and per-parameter state.
model.state_dict()
optim.state_dict()
2. state_dict 저장(save) / 불러오기(load)
# Save only the model's learned parameters (the recommended way to persist a model).
# NOTE(review): PATH is a placeholder (e.g. 'model.pt') — define it before running.
torch.save(model.state_dict(), PATH)
# Recreate the model instance first, then load the saved parameters into it.
model.load_state_dict(torch.load(PATH))
3. 체크포인트(checkpoint) 설정
학습 시 model, optimizer, loss, epoch에 대한 정보를 사전(dict) 형식으로 함께 저장하면 학습을 중단 지점부터 재개할 수 있다
# Save a full training checkpoint as a dict so training can be resumed later.
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ...add any other state needed to resume training (bare `...` inside a
    # dict literal is a syntax error, so extras belong here as real entries).
}, PATH)

# Restore: rebuild the model/optimizer objects first (placeholder classes
# from the PyTorch tutorial), then load the saved states into them.
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# TensorBoard logging: SummaryWriter must be imported (it is not a builtin).
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('logs')
# Log a batch of images at the given global step.
# NOTE(review): the tag is '' here — presumably intentional; confirm.
writer.add_images('', image, step)  # <- step ended (moved into a comment: it broke the syntax)