PyTorch는 3가지 방법으로 모델 저장을 지원한다.
state_dict
저장하기torch.save(model.state_dict(), 'model_state_dict.pt')
loaded_model = Model()
loaded_model.load_state_dict(torch.load('model_state_dict.pt'))
nn.Module
을 상속하는 클래스와 optimizer
는 state_dict()
함수를 통해 state_dict
를 얻을 수 있다. state_dict
는 해당 모델에 존재하는 모든 학습가능한 파라미터를 값으로 가지는 OrderedDict
이다. state_dict
값만 있으면 모델이 학습된 상태를 그대로 복원할 수 있고 torch.save(state_dict, path)
를 통해 저장할 수 있다. torch.save
는 내부적으로 pickle
을 사용하여 모든 파이썬 객체를 직렬화한다. torch.load(path)
를 통해 state_dict
를 불러와서 모델의 load_state_dict(state_dict)
를 통해 파라미터를 저장되었던 값으로 복원할 수 있다.
torch.save(model, 'model_state_dict.pt')
loaded_model = torch.load('model_state_dict.pt')
torch.save(model, path)
를 통해 모델 자체를 넘겨 저장할 수 있다. 불러올 땐 torch.load(path)
를 통해 바로 파라미터가 저장된 모델을 얻을 수 있다. 1번 방법에 비해 코드가 아주 간단해지지만 지양하자. 모델 클래스의 구조가 저장되지는 않고 모델 클래스의 경로가 저장된다. 모델을 불러올 때마다 저장된 모델 클래스의 경로를 참조하여 모델 객체를 생성하여 파라미터를 채워 반환하는 방식이다. 따라서 나중에 프로젝트 파일구조가 수정되면 모델을 불러올 수 없게 된다.
TorchScript
로 저장하기model_scripted = torch.jit.script(model)
model_scripted.save('model_state_dict.pt')
loaded_model = torch.jit.load('model_state_dict.pt')
torch.jit.script(model)
를 통해 TorchScript
로 만들 수 있고 이것을 save(path)
함수를 통해 저장할 수 있다. 불러올 땐 torch.jit.load(path)
로 불러올 수 있다. 모델을 통째로 저장할 수 있고 모델 클래스의 구조를 저장하기 때문에 2번 방법의 단점을 보완한 방법이다.
2번과 3번은 모델과 묶여 저장되기 때문에 저장된 파라미터만을 가지고 다른 일을 할 수가 없다. 반면 1번 방법은 모델과 상관없이 파라미터를 저장하기 때문에 파라미터 자체에 조작을 가할 수 있고 다른 모델으로 불러와 전이학습을 할 수도 있다. 따라서 1번 방법을 앞으로 사용할 것 같다.
참고자료
https://pytorch.org/tutorials/beginner/saving_loading_models.html