https://tutorials.pytorch.kr/beginner/saving_loading_models.html
https://tutorials.pytorch.kr/beginner/transfer_learning_tutorial.html
https://tutorials.pytorch.kr/recipes/recipes/saving_and_loading_a_general_checkpoint.html
https://justkode.kr/deep-learning/pytorch-save
https://gaussian37.github.io/dl-pytorch-snippets/
patameter 확인하는 방법은 이전 post에 있다.
기본적으로 pytorch에서의 저장은 torch.save(무엇을, '어떤경로와 이름으로')
이다.
마찬가지로 불러오기는 torch.load('path')
이다.
이제 여기서 무엇을
에 다양하게 넣으면 된다.
무엇을
에는 크게 두가지가 들어갈 수 있다.
첫번째는 model
이고 두번째는 model.state_dict()
이다.
model
을 저장하면 구조와 parameter 모두 저장된다. 이때 parameter가 중요한데(아래에서 확인할 예정) 이는 현자 시점의 parameter가 저장된다. 이는 다음과 같이 생각하면 편하다.
model을 만들때 nn.Module
을 상속 받는다. 또한 super().__init__()
으로 모두 상속 받는다. 그리고 구성에는 nn.Conv2d
와 같은 layer들이 있고 이 layer 안에는 parameter가 있으며 학습중에 opt.step()
에 의해 update 된다.
즉, 실질적인parameter의 위치는 model인 것이다. 따라서 model은 구조와 parameter가 있는 것이고 model.state_dict()은 model에서 parameter를 분리한 것이다.
torch.save(model, 'path')
model을 저장 하는 방법이다.model_cpu = test_model() model_gpu = model_cpu.to('cuda:1') torch.save(model_cpu, '/home/mskang/hyeokjong/model_cpu.pt') torch.save(model_gpu, '/home/mskang/hyeokjong/model_gpu.pt')
둘다 된다.
이제 불러와 보자.model_load = torch.load('/home/mskang/hyeokjong/model_gpu.pt')
일단
model_cpu
를 찍어 보면 다음과 같이 모델의 정보가 쭉 나온다.test_model( (conv1): Conv2d(3, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (bn1): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv2_): Conv2d(50, 10, kernel_size=(1, 1), stride=(1, 1), bias=False) (conv2): Conv2d(10, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (bn2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv3_): Conv2d(50, 10, kernel_size=(1, 1), stride=(1, 1), bias=False) (conv3): Conv2d(10, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (bn3): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv4_): Conv2d(50, 20, kernel_size=(1, 1), stride=(1, 1), bias=False) (conv4): Conv2d(20, 100, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (maxpool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (bn4): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (conv5_): Conv2d(100, 10, kernel_size=(1, 1), stride=(1, 1), bias=False) (conv5): Conv2d(10, 100, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) (maxpool5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (bn5): BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (fc): Sequential( (0): Flatten() (1): Linear(in_features=4900, out_features=100, bias=True) (2): ReLU() (3): Linear(in_features=100, out_features=100, bias=True) ) )
이때 load한
model_load
를 찍어 봐도 같은 결과가 나온다.
다만model_load == model_cpu -------------------------------------------------- False
이건 왜 다른지 모르겠다.
아무튼 같은거 같다. 심지어는 학습을 시켰봤는데도 결과가 비슷하게 나온다.
이에 대해서는 https://stackoverflow.com/questions/69787273/pytorch-save-and-load-model 여기에 올려두었다.
추가로model_load.state_dict()
을 해주면 parameter를 볼 수 있다.
model을 저장하여 읽고
state_dict()
으로 parameter를 읽는 방법이 좋은 방법은 아니다.
(정보 손실에 문제가 있다고 한다. torch 공식문서에서도 이렇게 한다. stack overflow에 자세한 이유가 있다.)
따라서 아예model.state_dict()
를 저장하고 이를 읽는데 바로 읽을 수는 없고 구조가 같은 model에 parameter에 덮어쓰면 된다. 이제 몇가지를 실험해보고 최적의 방법을 고안해 보자.
Original model
model_before = test_model() torch.save(model_before, '/home/mskang/hyeokjong/model_before.pt') model_before_load = torch.load('/home/mskang/hyeokjong/model_before.pt') torch.save (model_before.state_dict(),'/home/mskang/hyeokjong/model_before_parameters.pt') model_parameter_load = test_model() model_parameter_load.load_state_dict(torch.load('/home/mskang/hyeokjong/model_before_parameters.pt')) ------------------------------------------------------------------------------------------ model_before.state_dict()['conv1.weight'] == model_before_load.state_dict()['conv1.weight'] model_before_load.state_dict()['conv1.weight'] == model_parameter_load.state_dict()['conv1.weight']
두 결과 모두
True
이다. 당연하긴 하다.
model
은 각 layer 들의 초기값 때문에 선언 할때마다 parameter들이 달라 지는데 위에 처럼 model을 저장하면 해당 시점의 parameter들이 저장 됨을 알 수 있다.
Trained model
model_trained = model torch.save(model_trained, '/home/mskang/hyeokjong/model_trained.pt') model_trained_load = torch.load('/home/mskang/hyeokjong/model_trained.pt') torch.save(model_trained.state_dict(),'/home/mskang/hyeokjong/model_trained_parameters.pt') model_parameter_trained_load = test_model().to('cuda:1') model_parameter_trained_load.load_state_dict(torch.load('/home/mskang/hyeokjong/model_trained_parameters.pt')) ------------------------------------------------------------------------------------------ model_trained.state_dict()['conv1.weight'] == model_trained_load.state_dict()['conv1.weight'] model_trained_load.state_dict()['conv1.weight'] == model_parameter_trained_load.state_dict()['conv1.weight']
마찬 가지로 두 결과 모두
True
이다.
Best model
이 경우에는 2-epoch만 해서 Best == trained랑 같음을 확인한다.model_parameter_best_load = test_model().to('cuda:1') model_parameter_best_load.load_state_dict(torch.load('/home/mskang/hyeokjong/birds/best_model.pt')) ------------------------------------------------------------------------------------------ model_parameter_best_load.state_dict()['conv1.weight'] == model_parameter_trained_load.state_dict()['conv1.weight']
True
이다.
state_dict()
만 저장하자.앞의 방법은 transfer하는 데에는 딱히 지장이 없을 듯 한다. (autograd만 추가한다면 -> 이전post)
여기서 하려는 것은 학습을 이어하려는 것이다.
그래서 추가적인 정보가 필요한데 optimizer
에 대한 정보이다.
즉, 이전에 몇 epoch 까지 학습하였고 lr은 얼마나 줄어들었는지 등의 추가적인 정보가 필요하다.
공식 문서에서는 loss 까지 저장 한다. 저장하는 것이 많으면 좋긴한다.
save optimizer and load
parameter를 load할때 model을 먼저 선언 해놨던 것 처럼 optimizer를 load 할때도 먼저 선언 해 둬야 한다.
optimizer의 구성을 보자.opt ----------------------------------- Adam ( Parameter Group 0 amsgrad: False betas: (0.9, 0.999) eps: 1e-08 lr: 0.001 weight_decay: 0 )
간단해 보이지만
opt.state_dict()
에는 상당히 많은 숫자가 있다. 이는Adam이여서 그렇다.
당장 필요하지는 않으므로 check point는 나중에 하기로 한다.