긴 시간동안 학습을 해야하는 경우 모델을 학습 도중에 저장해야하는 상황이나 학습이 끝난 모델을 저장해야하는 경우에 torch가 지원하는 함수를 통해서 model save 기능을 사용할 수 있다.
save()
- 학습 결과를 저장해 주는 함수
- 모텔 architecture와 parameter를 저장
- early stopping 구현 가능
torch.save(model.state_dict(), "user_path/model.pt") # 모델의 parameter 정보를 담은 state_dict()과, 저장 위치
# architecture 함께 저장하려면
# torch.save(model, "user_path/model.pt")
model_load = ModelCalss() # 같은 모델 형태에서만 load 가능
model_load.load_state_dict(torch.load("user_path/model.pt"))
checkpoints
- 학습 중간 결과를 저장해서 최선의 결과를 선택할 수 있다.
- loss, metric 값을 지속적으로 저장
- epoch, loss, metric 함께 저장
torch.save({
'epoch':epoch,
'model_state_dict':model.state_dict(),
'optimizer_state_dict':optimizer.state_dict(),
'loss':epoch_loss,
}, "path/model.pt")
resnet = models.resnet18(pretrained=True).to(device) # models가 제공하는 ResNet-18 모델 사용
resnet.fc = nn.Linear(in_features=512, out_features=class_num) # resnet의 마지막 FullyConnected layer를 나의 데이터셋에 맞게 수정(out_features)
#resnet.fc.weight = nn.init.xavier 같은 메서드로 초기화 가능
#resnet.fc.bias = 1.0/np.sqrt(class_num)
학습시간이 길어질수록 기록이 필요하다. 학습과정을 기록할 수 있는 도구들을 사용한다.
from torch.utils.tensorboard import SummaryWriter
weriter = SummaryWriter(path)
for n_iter in range(100):
wrtier.add_scalar('Loss/train', value, n_iter)
...
weriter.flush()