Save and Load the Model

안소희·2024년 10월 11일

PyTorch

목록 보기
8/8

PyTorch에서는 학습된 모델을 저장하고 불러오는 기능을 제공한다. 이는 추후 모델 재학습 없이 예측에 활용할 수 있도록 함.

1. 모델 가중치 저장하기

모델의 학습된 매개변수는 state_dict에 저장된다. 이를 torch.save로 파일에 저장할 수 있다.

import torch
import torchvision.models as models

# VGG16 모델 가중치 저장
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

2. 모델 가중치 불러오기

저장한 가중치를 불러오려면 동일한 모델을 생성한 후 load_state_dict()로 가중치를 불러온다.

# 동일한 VGG16 모델 생성
model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()  # 모델을 평가 모드로 설정

3. 모델 전체 저장하기

모델의 구조와 가중치를 함께 저장하려면 torch.save로 모델 자체를 저장한다.

# 모델과 구조를 함께 저장
torch.save(model, 'model.pth')

# 모델 불러오기
model = torch.load('model.pth')

참고사항

  • model.eval(): 추론(inference) 전에 호출하여 드롭아웃과 배치 정규화가 올바르게 동작하도록 설정.
  • 모델 저장 방식: 모델 구조와 함께 저장하면 클래스 정의가 필요하며, pickle을 사용.
profile
인공지능.관심 있습니다.

0개의 댓글