PyTorch 기초 - 모델을 저장하고 불러오기

sp·2022년 3월 1일
1

PyTorch 기초

목록 보기
6/7
post-thumbnail
post-custom-banner

모델을 학습한 다음에는 이를 바로 활용하기보다는, 학습된 모델 자체나 가중치를 저장한 다음에 실질적으로 테스트할 때 가져와서 활용하는 과정을 거칩니다. 이 포스트에서는 이 과정을 수행하기 위한 학습된 모델을 저장하고 불러오는 방법에 대해 알아보겠습니다.

모델 전체를 저장하고 불러오기

학습한 모델을 저장하고, 불러오는 가장 간단한 방법은 인스턴스 자체를 저장하고 불러오는 방법이라고 할 수 있습니다. 이 방식을 사용해서 모델을 저장하고 불러오겠습니다.

# save model
torch.save(model, "./model.pth")

# load model
model = torch.load("./model.pth")

위 코드와 같이 saveload 함수를 사용해서 직렬화/역직렬화를 통해 바로 저장과 로딩을 수행할 수 있습니다. 참고로 모델의 확장자는 .pt.pth를 일반적으로 사용합니다.

state_dict를 저장하고 불러오기

state_dict는 레이어들을 딕셔너리로 나타내, 해당 레이어의 가중치를 텐서로 가지고 있습니다. 모델을 기준으로 간단하게 말하면, 각 레이어에 해당하는 가중치 정보라고 보면 됩니다. 이 state_dict를 저장하고, 추론할 때 모델을 정의한 뒤 가중치만 state_dict에서 가져오는 방식이라고 할 수 있습니다. 이 방식이 모델 전체를 저장하고 불러오는 방식보다 자주 사용됩니다.

그러면 state_dict를 저장하고, 불러와보겠습니다.

# save model state_dict
torch.save(model.state_dict(), "./model_state_dict.pth")

# load model state_dict
model = MyModel()
model.load_state_dict(torch.load("./model_state_dict.pth"))

모델 전체를 불러오는 코드와는 다르게, state_dict를 불러오는 경우에는 load를 사용한 뒤에 추가적으로 load_state_dict를 사용해야 합니다.

체크포인트로 저장하고 불러오기

몇 에포크를 학습했는지 알고 싶거나, 재학습을 하고 싶는 상황에서는 모델의 가중치뿐만 아니라 다양한 값들을 저장할 필요가 있습니다. 이런 모델 가중치뿐만 아니라 다양한 정보를 통틀어서 체크포인트라고 합니다. 이는 딕셔너리를 활용하는 방식을 사용하는데, 이 방식을 활용해 체크포인트를 저장하고 불러오는 코드를 나타내보겠습니다.

# save checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch
}, "./checkpoint.tar")

# load checkpoint
checkpoint = torch.load("./checkpoint.tar")

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']

save 함수를 통해 체크포인트를 저장할 때 딕셔너리에 여러 state_dict와 학습한 epoch를 넣어서 저장했습니다. load 함수로 딕셔너리를 불러온 다음에, 해당하는 값들을 매핑해주는 방식으로 수행하게 됩니다. 참고로 체크포인트의 확장자는 보통 .tar를 사용합니다.

장치 간 모델 불러오기

학습은 GPU로 하고, 추론은 CPU로 하는 등 모델을 불러와서 사용하는 장치가 바뀔 수 있습니다. 이 때 별도로 사용할 위치를 지정해야 합니다. 이 때 load 함수에서 사용하는 map_location을 사용합니다. 여기에 cpucuda 등을 넣으면 되는데, 환경에 따라 자동으로 배치하는 코드를 구현해보겠습니다.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("./model_state_dict.pth", map_location=device)

이 방법은 device를 미리 정의해서 load의 인수로 넣어주는 방식입니다. torch.cuda.is_available() 함수를 통해 사용할 수 있는 GPU가 있으면 cuda, 아니면 cpu를 사용하도록 설정할 수 있습니다.

모델의 구조가 다를 때 가중치 불러오기

전이 학습에서 이미 정형화된 모델의 구조를 커스텀하게 뜯어고친 상황에서, 바로 가중치를 불러오면 저장된 state_dict와 모델의 구조가 맞지 않기 때문에 에러가 발생합니다. 이를 해결하기 위해 다음과 같은 방법을 사용합니다.

model.load_state_dict(torch.load("./model_state_dict.pth"), strict=False)

load_state_dictstrict 인수에 False를 명시하면 됩니다. 이 상황에서 키와 맞지 않는 가중치들은 무시하고 사용할 수 있는 것들만 불러오는 효과가 나타납니다.

post-custom-banner

0개의 댓글