모델을 학습한 다음에는 이를 바로 활용하기보다는, 학습된 모델 자체나 가중치를 저장한 다음에 실질적으로 테스트할 때 가져와서 활용하는 과정을 거칩니다. 이 포스트에서는 이 과정을 수행하기 위한 학습된 모델을 저장하고 불러오는 방법에 대해 알아보겠습니다.
학습한 모델을 저장하고, 불러오는 가장 간단한 방법은 인스턴스 자체를 저장하고 불러오는 방법이라고 할 수 있습니다. 이 방식을 사용해서 모델을 저장하고 불러오겠습니다.
# save model
torch.save(model, "./model.pth")
# load model
model = torch.load("./model.pth")
위 코드와 같이 save
와 load
함수를 사용해서 직렬화/역직렬화를 통해 바로 저장과 로딩을 수행할 수 있습니다. 참고로 모델의 확장자는 .pt
나 .pth
를 일반적으로 사용합니다.
state_dict
는 레이어들을 딕셔너리로 나타내, 해당 레이어의 가중치를 텐서로 가지고 있습니다. 모델을 기준으로 간단하게 말하면, 각 레이어에 해당하는 가중치 정보라고 보면 됩니다. 이 state_dict
를 저장하고, 추론할 때 모델을 정의한 뒤 가중치만 state_dict
에서 가져오는 방식이라고 할 수 있습니다. 이 방식이 모델 전체를 저장하고 불러오는 방식보다 자주 사용됩니다.
그러면 state_dict
를 저장하고, 불러와보겠습니다.
# save model state_dict
torch.save(model.state_dict(), "./model_state_dict.pth")
# load model state_dict
model = MyModel()
model.load_state_dict(torch.load("./model_state_dict.pth"))
모델 전체를 불러오는 코드와는 다르게, state_dict
를 불러오는 경우에는 load
를 사용한 뒤에 추가적으로 load_state_dict
를 사용해야 합니다.
몇 에포크를 학습했는지 알고 싶거나, 재학습을 하고 싶는 상황에서는 모델의 가중치뿐만 아니라 다양한 값들을 저장할 필요가 있습니다. 이런 모델 가중치뿐만 아니라 다양한 정보를 통틀어서 체크포인트라고 합니다. 이는 딕셔너리를 활용하는 방식을 사용하는데, 이 방식을 활용해 체크포인트를 저장하고 불러오는 코드를 나타내보겠습니다.
# save checkpoint
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}, "./checkpoint.tar")
# load checkpoint
checkpoint = torch.load("./checkpoint.tar")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
save
함수를 통해 체크포인트를 저장할 때 딕셔너리에 여러 state_dict
와 학습한 epoch를 넣어서 저장했습니다. load
함수로 딕셔너리를 불러온 다음에, 해당하는 값들을 매핑해주는 방식으로 수행하게 됩니다. 참고로 체크포인트의 확장자는 보통 .tar
를 사용합니다.
학습은 GPU로 하고, 추론은 CPU로 하는 등 모델을 불러와서 사용하는 장치가 바뀔 수 있습니다. 이 때 별도로 사용할 위치를 지정해야 합니다. 이 때 load
함수에서 사용하는 map_location
을 사용합니다. 여기에 cpu
나 cuda
등을 넣으면 되는데, 환경에 따라 자동으로 배치하는 코드를 구현해보겠습니다.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("./model_state_dict.pth", map_location=device)
이 방법은 device
를 미리 정의해서 load
의 인수로 넣어주는 방식입니다. torch.cuda.is_available()
함수를 통해 사용할 수 있는 GPU가 있으면 cuda
, 아니면 cpu
를 사용하도록 설정할 수 있습니다.
전이 학습에서 이미 정형화된 모델의 구조를 커스텀하게 뜯어고친 상황에서, 바로 가중치를 불러오면 저장된 state_dict
와 모델의 구조가 맞지 않기 때문에 에러가 발생합니다. 이를 해결하기 위해 다음과 같은 방법을 사용합니다.
model.load_state_dict(torch.load("./model_state_dict.pth"), strict=False)
load_state_dict
의 strict
인수에 False
를 명시하면 됩니다. 이 상황에서 키와 맞지 않는 가중치들은 무시하고 사용할 수 있는 것들만 불러오는 효과가 나타납니다.