PyTorch에서는 학습된 모델을 저장하고 불러오는 기능을 제공한다. 이는 추후 모델 재학습 없이 예측에 활용할 수 있도록 함.
모델의 학습된 매개변수는 state_dict에 저장된다. 이를 torch.save로 파일에 저장할 수 있다.
import torch
import torchvision.models as models
# VGG16 모델 가중치 저장
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
저장한 가중치를 불러오려면 동일한 모델을 생성한 후 load_state_dict()로 가중치를 불러온다.
# 동일한 VGG16 모델 생성
model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 모델을 평가 모드로 설정
모델의 구조와 가중치를 함께 저장하려면 torch.save로 모델 자체를 저장한다.
# 모델과 구조를 함께 저장
torch.save(model, 'model.pth')
# 모델 불러오기
model = torch.load('model.pth')
model.eval(): 추론(inference) 전에 호출하여 드롭아웃과 배치 정규화가 올바르게 동작하도록 설정.pickle을 사용.