Pytorch best model 저장하기

highway92·2022년 2월 3일
1

머신러닝

목록 보기
3/5

딥러닝을 하다보니 이런일이 생기게 되었다.

이렇게 train 중에 점수가 좋게 나온녀석이 있다면 저장해두고 싶어진 것이다.

여러가지 방법이 있지만 내가 사용한 방법과 pytorch 공식문서의 권장 사항을 설명하고자 한다.

1. 내가 사용한 방법

필자는 코랩을 사용하였기 때문에 model을 파일로 따로 저장한다거나 하는 귀찮은 짓을 하고 싶지 않았다.

참고로 model을 저장한다는 것은 파라미터들을 저장한다는 것이다.
아래에서 설명하겠지만 torch.save를 하게 되면 ordered dict 형태로 파라미터들이 저장된다.

그래서 다음과 같은 방법을 사용했다.

이런식으로 best_acc와 f1 변수를 만들고 epoch을 돌면서 이보다 좋은 점수가 나오면 최신화하는 식이다.

여기서 주의해야 할 점은 best_f1_model = model.state_dict()와 같이 할당 하는 것이 아니라.
deepcopy를 사용하여야 한다는 것이다. 이에 대한 내용은 공식문서를 참고하자.
https://tutorials.pytorch.kr/beginner/saving_loading_models.html

전체코드는 다음과 같다.

from copy import deepcopy
loss_list = []
acc_list = []
best_acc = 0
best_f1 = 0
best_acc_model = None 
best_f1_model = None
from sklearn.metrics import f1_score


for epoch in range(EPOCHS):
    for i, (X_batch, y_batch) in enumerate(trainloader):
        #Forward 
        X_batch = X_batch.to(DEVICE)
        y_batch = y_batch.to(DEVICE)
        y_output = model(X_batch)

        loss = criterion(y_output, y_batch) 

        #Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        #misc (acc 계산, etc) 
        y_pred = torch.max(y_output, 1)[1]
        acc = accuracy_score(y_pred.data.cpu(), y_batch.data.cpu())
        f_score = f1_score(y_pred.data.cpu(), y_batch.data.cpu(), average='macro')
        if acc > best_acc:
          best_acc = acc
          best_acc_model = deepcopy(model.state_dict())

        if f_score > best_f1:
          best_f1 = f_score
          best_f1_model = deepcopy(model.state_dict())

        loss_list.append(loss.item())
        acc_list.append(acc)

    if (epoch+1) % 10 == 0:
        print('Epoch [{}/{}] Step [{}/{}] Loss: [{:.4f}] Train ACC [{:.2f}%] F1 Score [{:.2f}%]'.format(epoch+1, EPOCHS, \
                                                                                   i+1, len(trainloader), loss.item(), acc*100, f_score * 100))
        

2. pytorch 권장방법

pytorch에서는 파일로 저장하고 불러오는 것을 권장한다고 한다. 위의 과정에서 torch.save만 추가해주면 된다.

저장하기

model = SomeModel()
torch.save(model.state_dict(), PATH)

불러오기

model = SomeModel()
model.load_state_dict(torch.load(PATH))
model.eval()

전체코드 : https://github.com/highway92/machine_learning/blob/main/year_dream/do/MLP%EB%A5%BC_%EC%9D%B4%EC%9A%A9%ED%95%9C_%EA%B8%88%EC%9C%B5%EB%8D%B0%EC%9D%B4%ED%84%B0_%EB%B6%84%EC%84%9D(best_model).ipynb

profile
웹 개발자로 활동하고 있습니다.

0개의 댓글