Transfer learning & Hyper Parameter Tuning 실습

이상민·2023년 3월 17일

이번 실습에서는 이전에 배운 transfer learninig과 hyperparameter Tuning을 학습해본다.

Transfer Learning

모델생성

imageNet에서 학습된 resnet18을 전이하여 mnist데이터를 분류하는 모델을 만들기

import torchvision
import torch
import numpy

imagenet_resnet18 = torchvision.models.resnet18(pretrained = True)#imageNet에서 학습된 ResNet18 모델을 pretrained = True로 설정하여 불러옴

Mnist Dataset불러오기

mnist_train = torchvision.dataset.MNIST(root = "./mnist",train = True, download = True)
mnist_test = torchvision.dataset.MNIST(root = "./mnist",train = Flase, download = True) #test데이터를 불러올 시 train = Flase로 설정

Mnist를 학습할 CNN모델 생성하기

mnist_resnet18 = torchvision.models.resnet18(pretrained = False)

#mnistdata set의 경우 흑백이미지로 채널이 1차원이다.
# 반면 resnet은 컬러를 분류하는 모델이므로 입력 채널의 크기가 3차원
print(np.array(mnist_train[0][0]).shape) #	(28,28)
print(imagenet_resnet18.conv1) #Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

데이터 전처리

mnist dataset을 resnet의 입력 차원과 맞춰 주기 위해 Grayscale을 사용하여 채널을 맞춰준다.

common_transform = torchvision.transforms.Compose(
  [
    torchvision.transforms.Grayscale(num_output_channels=3), # grayscale의 1채널 영상을 3채널로 동일한 값으로 확장함
    torchvision.transforms.ToTensor() # PIL Image를 Tensor type로 변경함
  ]
)

mnist_train_transformed = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=common_transform)
mnist_test_transformed = torchvision.datasets.MNIST(root='./mnist', train=False, download=True, transform=common_transform)
#	변환결과
print(mnist_train_transformed[0][0].shape) # (3,28,28)

Fine tuning

데이터의 입력 차원을 맞춰 줬다면. 이제 resnet18 모델의 출력 차원을 minst data의 class(10개)로 맞추어줘야한다.

MNIST_CLASS_NUM = 10 #mnistdata는 0부터 9까지 총 10개의 class를 갖는다
mnist_resnet18.fc = torch.nn.Linear(in_features=512, out_features=MNIST_CLASS_NUM, bias=True)#ResNet18의 마지막 fullconnectedlayer의 출력을 (512,1000)에서 (512,10)으로 변경

torch.nn.init.xavier_uniform_(mnist_resnet18.fc.weight)#변경된 가중치를 xavier_uniform분포를 갖도롤 초기화
stdv = 1. / math.sqrt(mnist_resnet18.fc.weight.size(1)) #편향 초기화
mnist_resnet18.fc.bias.data.uniform_(-stdv, stdv)

학습

앞선 과정이 끝났다면 이제 학습하는 일만 남았다.

### 학습 코드 시작
best_test_accuracy = 0.
best_test_loss = 9999.

for epoch in range(NUM_EPOCH):
  for phase in ["train", "test"]:
    running_loss = 0.
    running_acc = 0.
    if phase == "train":
      mnist_resnet18.train() # 네트워크 모델을 train 모드로 두어 gradient을 계산하고, 여러 sub module (배치 정규화, 드롭아웃 등)이 train mode로 작동할 수 있도록 함
    elif phase == "test":
      mnist_resnet18.eval() # 네트워크 모델을 eval 모드 두어 여러 sub module들이 eval mode로 작동할 수 있게 함

    for ind, (images, labels) in enumerate(tqdm(dataloaders[phase])):
      # (참고.해보기) 현재 tqdm으로 출력되는 것이 단순히 진행 상황 뿐인데 현재 epoch, running_loss와 running_acc을 출력하려면 어떻게 할 수 있는지 tqdm 문서를 보고 해봅시다!
      # hint - with, pbar
      images = images.to(device)
      labels = labels.to(device)

      optimizer.zero_grad() # parameter gradient를 업데이트 전 초기화함

      with torch.set_grad_enabled(phase == "train"): # train 모드일 시에는 gradient를 계산하고, 아닐 때는 gradient를 계산하지 않아 연산량 최소화
        logits = mnist_resnet18(images)
        _, preds = torch.max(logits, 1) # 모델에서 linear 값으로 나오는 예측 값 ([0.9,1.2, 3.2,0.1,-0.1,...])을 최대 output index를 찾아 예측 레이블([2])로 변경함  
        loss = loss_fn(logits, labels)

        if phase == "train":
          loss.backward() # 모델의 예측 값과 실제 값의 CrossEntropy 차이를 통해 gradient 계산
          optimizer.step() # 계산된 gradient를 가지고 모델 업데이트

      running_loss += loss.item() * images.size(0) # 한 Batch에서의 loss 값 저장
      running_acc += torch.sum(preds == labels.data) # 한 Batch에서의 Accuracy 값 저장

    # 한 epoch이 모두 종료되었을 때,
    epoch_loss = running_loss / len(dataloaders[phase].dataset)
    epoch_acc = running_acc / len(dataloaders[phase].dataset)

    print(f"현재 epoch-{epoch}{phase}-데이터 셋에서 평균 Loss : {epoch_loss:.3f}, 평균 Accuracy : {epoch_acc:.3f}")
    if phase == "test" and best_test_accuracy < epoch_acc: # phase가 test일 때, best accuracy 계산
      best_test_accuracy = epoch_acc
    if phase == "test" and best_test_loss > epoch_loss: # phase가 test일 때, best loss 계산
      best_test_loss = epoch_loss
print("학습 종료!")
print(f"최고 accuracy : {best_test_accuracy}, 최고 낮은 loss : {best_test_loss}")
profile
잘하자

0개의 댓글