8. 파이토치(PyTorch) 튜토리얼 - 모델 저장하고 불러오기

Yeonghyeon·2022년 7월 30일
0
post-custom-banner

본 포스팅은 파이토치(PYTORCH) 한국어 튜토리얼을 참고하여 공부하고 정리한 글임을 밝힙니다.


모델 저장하고 불러오기

모델 저장하기나 불러오기를 통해 모델의 상태를 유지하고 모델의 예측을 실행해보자

import torch
import torchvision.models as models

모델 가중치 저장하고 불러오기

  • PyTorch 모델은 학습한 매게변수를 state_dict라고 불리는 내부 상태 사전에 저장
  • 이 상태 값들은 torch.save 메소드를 사용하여 저장 가능
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
  • 모델 가중치 불러오기
    • 동일한 모델의 인스턴스 생성 후 load_state_dict() 메소드 사용
model = model.vgg16() # 기본 가중치를 불러오지 않으므로 pretrained=True 지정 x
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

추론을 하기 전에 model.eval() 메소드 호출하여 드롭아웃과 배치 정규화를 평가 모드로 설정 ➡️ 그렇지 않으면 일관성 없는 추론 결과가 생성됨

모델의 형태를 포함하여 저장하고 불러오기

  • 모델 가중치 불러올 때, 신경망 구조를 정의하기 위해 모델 클래스를 먼저 생성해야 했음
  • 이 클래스의 구조를 모델과 함께 저장 ➡️ model.state_dict()가 아닌 model 자체를 저장 함수에 전달
torch.save(model, 'model.pth')

모델 불러오기

model = torch.load('model.pth')

이 접근 방식은 Python pickle 모듈을 사용하여 모델을 직렬화하므로, 모델 불러올 때 실제 클래스 정의를 적용해야 함

post-custom-banner

0개의 댓글