공식문서 링크 : https://tutorials.pytorch.kr/beginner/saving_loading_models.html
torch.save
직렬화된 객체를 디스크에 저장합니다. 이 함수는 Python의 pickle 을 사용하여 직렬화합니다. 이 함수를 사용하여 모든 종류의 객체의 모델, Tensor 및 사전을 저장할 수 있습니다.
torch.load
pickle을 사용하여 저장된 객체 파일들을 역직렬화하여 메모리에 올립니다. 이 함수는 데이터를 장치에 불러올 때에도 사용됩니다. (장치 간 모델 저장하기 & 불러오기 참고)
torch.nn.Module.load_state_dict
역직렬화된 state_dict 를 사용하여 모델의 매개변수들을 불러옵니다.
state_dict 객체는 Python 사전이기 때문에 쉽게 저장하거나 갱신하거나 바꾸거나 되살릴 수 있으며, PyTorch 모델과 옵티마이저에 엄청난 모듈성(modularity)을 제공합니다.
참고한 블로그 링크 : https://justkode.kr/deep-learning/pytorch-save
import torch
import torch.nn as nn
# 모델 생성
x_data = torch.Tensor([
[0, 0],
[1, 0],
[1, 1],
[0, 0],
[0, 0],
[0, 1]
])
y_data = torch.LongTensor([
0, # etc
1, # mammal
2, # birds
0,
0,
2
])
class DNN(nn.Module):
def __init__(self):
super(DNN, self).__init__()
self.w1 = nn.Linear(2, 10)
self.bias1 = torch.zeros([10])
self.w2 = nn.Linear(10, 3)
self.bias2 = torch.zeros([3])
self.relu = nn.ReLU()
self.softmax = nn.Softmax(dim=0)
def forward(self, x):
y = self.w1(x) + self.bias1
y = self.relu(y)
y = self.w2(y) + self.bias2
return y
model = DNN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# epoch을 돌 때마다 전체모델, 모델state_dict, 여러정보를 각각 저장한다.
for epoch in range(10):
output = model(x_data)
loss = criterion(output, y_data)
optimizer.zero_grad()
loss.backward()
optimizer.step()
PATH = '저장할 경로 알아서 설정'
torch.save(model, PATH + f'{epoch}th_model.pt') # 전체 모델 저장
torch.save(model.state_dict(), PATH + f'{epoch}th_model_state_dict.pt') # 모델 객체의 state_dict 저장
torch.save({
'model': model.state_dict(),
'optimizer': optimizer.state_dict()
}, PATH + f'{epoch}th_all.tar')
print("progress:", epoch, "loss=", loss.item())
model = torch.load(PATH + 'model.pt') # 전체 모델을 통째로 불러옴, 클래스 선언 필수
model.load_state_dict(torch.load(PATH + 'model_state_dict.pt')) # state_dict를 불러 온 후, 모델에 저장
checkpoint = torch.load(PATH + 'all.tar') # dict 불러오기
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
torch.save(modelA.state_dict(), PATH) # 저장하기
modelB = TheModelBClass(*args, **kwargs) # 불러오기
modelB.load_state_dict(torch.load(PATH), strict=False)
torch.save(modelA.state_dict(), PATH) # 저장하기
modelB = TheModelBClass(*args, **kwargs) # 불러오기
modelB.load_state_dict(torch.load(PATH), strict=False)