PyTorch 모델 저장하기

seokj·2023년 1월 10일
0

PyTorch는 3가지 방법으로 모델 저장을 지원한다.

1. state_dict 저장하기

torch.save(model.state_dict(), 'model_state_dict.pt')
loaded_model = Model()
loaded_model.load_state_dict(torch.load('model_state_dict.pt'))

nn.Module을 상속하는 클래스와 optimizerstate_dict()함수를 통해 state_dict를 얻을 수 있다. state_dict는 해당 모델에 존재하는 모든 학습가능한 파라미터를 값으로 가지는 OrderedDict이다. state_dict값만 있으면 모델이 학습된 상태를 그대로 복원할 수 있고 torch.save(state_dict, path)를 통해 저장할 수 있다. torch.save는 내부적으로 pickle을 사용하여 모든 파이썬 객체를 직렬화한다. torch.load(path)를 통해 state_dict를 불러와서 모델의 load_state_dict(state_dict)를 통해 파라미터를 저장되었던 값으로 복원할 수 있다.

2. 모델을 통째로 저장하기

torch.save(model, 'model_state_dict.pt')
loaded_model = torch.load('model_state_dict.pt')

torch.save(model, path)를 통해 모델 자체를 넘겨 저장할 수 있다. 불러올 땐 torch.load(path)를 통해 바로 파라미터가 저장된 모델을 얻을 수 있다. 1번 방법에 비해 코드가 아주 간단해지지만 지양하자. 모델 클래스의 구조가 저장되지는 않고 모델 클래스의 경로가 저장된다. 모델을 불러올 때마다 저장된 모델 클래스의 경로를 참조하여 모델 객체를 생성하여 파라미터를 채워 반환하는 방식이다. 따라서 나중에 프로젝트 파일구조가 수정되면 모델을 불러올 수 없게 된다.

3. TorchScript로 저장하기

model_scripted = torch.jit.script(model)
model_scripted.save('model_state_dict.pt')
loaded_model = torch.jit.load('model_state_dict.pt')

torch.jit.script(model)를 통해 TorchScript로 만들 수 있고 이것을 save(path)함수를 통해 저장할 수 있다. 불러올 땐 torch.jit.load(path)로 불러올 수 있다. 모델을 통째로 저장할 수 있고 모델 클래스의 구조를 저장하기 때문에 2번 방법의 단점을 보완한 방법이다.


무엇을 사용해야할까?

2번과 3번은 모델과 묶여 저장되기 때문에 저장된 파라미터만을 가지고 다른 일을 할 수가 없다. 반면 1번 방법은 모델과 상관없이 파라미터를 저장하기 때문에 파라미터 자체에 조작을 가할 수 있고 다른 모델으로 불러와 전이학습을 할 수도 있다. 따라서 1번 방법을 앞으로 사용할 것 같다.


참고자료
https://pytorch.org/tutorials/beginner/saving_loading_models.html

profile
안녕하세요

0개의 댓글