pytorch model save, load

이승화·2022년 2월 25일
0

1.state_dict

model, optimizer의 정보를 저장

import torch

class SomeModel
model = SomeModel()
optim = torch.optim.SGD(model.parameters(),lr)

model.state_dict()
optim.state_dict()

2.save, load state_dict

torch.save(model.state_dict(), PATH)
model.load_state_dict(torch.load(PATH))

3.체크포인트 설정
학습시 model, optimizer, loss, epoch에 대한 정보를 사전형식으로 저장

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
            
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
  1. tensorboard 지속하기
writer = SummaryWriter('logs') 
writer.add_images('', image, step) <-step ended

0개의 댓글