PyTorch 구조 - 2

c0natus·2022년 2월 6일
0

PyTorch

목록 보기
3/4

1. 모델 저장 및 불러오기


  • 딥러닝을 학습하는데 오랜 시간이 걸리고, 예기치 못하게 학습이 중단될 수 있다.

  • 이를 보완하기 위해, 학습 결과를 저장해야할 필요가 있다.

1.1. torch.save()

  • PyTorch에서는 모델을 저장하기 위해, torch.save() 함수를 제공한다.

  • torch.save()를 통해, 모델의 구조와 모델의 parameter를 저장할 수 있다.

  • 학습 중간 과정을 저장하여 최선의 결과 모델을 선택하는 Early Stop을 할 수 있고, 만들어진 모델을 타인(연구자 등)과 공유하여 학습 재연성을 향상할 수 있다.

1.2. mymodel.state_dict()

  • torch.nn.Module은 state_dict() method를 통해 모델의 parameter를 collections.OrderedDict 타입으로 정리해서 반환한다.

1.3. torchsummary

  • 해당 패키지는 model의 구조와 parameter를 keras 처럼 모델을 출력해준다.

1.4. torch.load()

  • 저장된 모델을 불러와서 변수에 할당하기 위해 torch.load() 함수를 사용해야 한다.

1.5. model.load_state_dict()

  • torch.nn.Module에는 load_state_dict() method를 통해 저장된 parameter를 불러 모델의 parameter로 지정할 수 있다.

  • 주의할 점은 해당 새로운 model은 parameter를 저장할 때와 같은 model 구조를 가져야 한다.

1.6. 예제: model, parameter 저장 및 불러오기


class MyClass(torch.nn.Module):
    ...

model = MyClass()

MODEL_PATH ="saved"

# 파라미터 저장
if not os.path.exists(MODEL_PATH):
    os.makedirs(MODEL_PATH)
torch.save(model.state_dict(), 
           os.path.join(MODEL_PATH, "model.pt"))

# parameter를 load할 것이기 때문에,
# new model은 꼭 동일한 모델이어야 한다.
new_model = MyClass()
new_model.load_state_dict(torch.load(os.path.join(
    MODEL_PATH, "model.pt")))


# 모델 자체(구조와 파라미터)를 저장후 load
torch.save(model, os.path.join(MODEL_PATH, "model_pickle.pt"))
model = torch.load(os.path.join(MODEL_PATH, "model_pickle.pt"))
from torchsummary import summary
summary(model, (3, 224, 224))

1.7. checkpoint

  • 학습의 중간 결과물을 저장해서 최선의 결과를 선택하는 earlystopping 기법을 사용할 수 있다.

  • 일반적으로 epoch 그리고 train data와 validation data의 loss, metric 값을 지속적으로 저장 및 확인하여 최선의 결과를 선택한다.

1.8. 예제: checkpoint

class MyClass(torch.nn.Module):
    ...

model = MyClass()

torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
            },
            f"saved/checkpoint_{e}_{epoch_loss / len(dataloader)}_{epoch_acc / len(dataloader)}" # len(dataloader) == batch_size?
        )

new_model = MyClass()

checkpoint = torch.load(PATH)
new_model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

2. Trnasfer Learning


  • Tranfer Learning은 다른 사람(data set)이 만든 모델(pretrained model)에 현재 데이터를 적용하는 것이다.

  • 현재의 Deep Learning에서는 가장 일반적인 학습 기법이다.

  • backbone architecture가 잘 학습된 모델에서 일부분만 변경하여 학습을 수행한다.

  • TorchVision 모듈에서 다양한 기본 모델을 제공한다.

  • NLP는 HuggingFace가 사실장 표준이다.

2.1. Freezing

  • pretrained model을 활용시 모델의 일부분을 frozen 시킨다. 참고

2.2. 예제: vgg16

  • pretrained model
import torch
from torch import nn
from torchvision import models


class MyNewNet(nn.Module):   
    def __init__(self):
        super(MyNewNet, self).__init__()
        self.vgg19 = models.vgg19(pretrained=True)
        
        # 모델의 마지막에 linear layer 추가
        self.linear_layers = nn.Linear(1000, 1)


    # Defining the forward pass    
    def forward(self, x):
        x = self.vgg19(x)        
        return self.linear_layers(x)

# 마지막 layer를 제외하고 frozen
for param in my_model.parameters():
	# frozen
    param.requires_grad = False
    
for param in my_model.linear_layers.parameters():
    param.requires_grad = True

3. Monitoring tools for PyTorch


  • 학습이 길어지면 오래 기다려야 한는데, 이때 중간 과정들을 기록해주는 도구들을 소개한다.

3.1. Tensorboard

  • TensorFlow의 프로젝트로 만들어진 시각화 도구로, 학습 그래프/metric/학습 결과의 시각화를 지원한다.

  • PyTorch도 연결 가능하다.

3.2. weight & biases

  • 머신러닝/딥러닝 실험을 원활히 지원하기 위한 상용도구이다.

  • 협업, code versioning, 실험 결과 기록 등을 제공한다.

profile
Done is Better Than Perfect

0개의 댓글