Transfer learning & Hyper Parameter Tuning 실습

이상민·2023년 3월 17일
0
post-custom-banner

이번 실습에서는 이전에 배운 transfer learninig과 hyperparameter Tuning을 학습해본다.

Transfer Learning

모델생성

imageNet에서 학습된 resnet18을 전이하여 mnist데이터를 분류하는 모델을 만들기

import torchvision
import torch
import numpy

imagenet_resnet18 = torchvision.models.resnet18(pretrained = True)#imageNet에서 학습된 ResNet18 모델을 pretrained = True로 설정하여 불러옴

Mnist Dataset불러오기

mnist_train = torchvision.dataset.MNIST(root = "./mnist",train = True, download = True)
mnist_test = torchvision.dataset.MNIST(root = "./mnist",train = Flase, download = True) #test데이터를 불러올 시 train = Flase로 설정

Mnist를 학습할 CNN모델 생성하기

mnist_resnet18 = torchvision.models.resnet18(pretrained = False)

#mnistdata set의 경우 흑백이미지로 채널이 1차원이다.
# 반면 resnet은 컬러를 분류하는 모델이므로 입력 채널의 크기가 3차원
print(np.array(mnist_train[0][0]).shape) #	(28,28)
print(imagenet_resnet18.conv1) #Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

데이터 전처리

mnist dataset을 resnet의 입력 차원과 맞춰 주기 위해 Grayscale을 사용하여 채널을 맞춰준다.

common_transform = torchvision.transforms.Compose(
  [
    torchvision.transforms.Grayscale(num_output_channels=3), # grayscale의 1채널 영상을 3채널로 동일한 값으로 확장함
    torchvision.transforms.ToTensor() # PIL Image를 Tensor type로 변경함
  ]
)

mnist_train_transformed = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=common_transform)
mnist_test_transformed = torchvision.datasets.MNIST(root='./mnist', train=False, download=True, transform=common_transform)
#	변환결과
print(mnist_train_transformed[0][0].shape) # (3,28,28)

Fine tuning

데이터의 입력 차원을 맞춰 줬다면. 이제 resnet18 모델의 출력 차원을 minst data의 class(10개)로 맞추어줘야한다.

MNIST_CLASS_NUM = 10 #mnistdata는 0부터 9까지 총 10개의 class를 갖는다
mnist_resnet18.fc = torch.nn.Linear(in_features=512, out_features=MNIST_CLASS_NUM, bias=True)#ResNet18의 마지막 fullconnectedlayer의 출력을 (512,1000)에서 (512,10)으로 변경

torch.nn.init.xavier_uniform_(mnist_resnet18.fc.weight)#변경된 가중치를 xavier_uniform분포를 갖도롤 초기화
stdv = 1. / math.sqrt(mnist_resnet18.fc.weight.size(1)) #편향 초기화
mnist_resnet18.fc.bias.data.uniform_(-stdv, stdv)

학습

앞선 과정이 끝났다면 이제 학습하는 일만 남았다.

### 학습 코드 시작
best_test_accuracy = 0.
best_test_loss = 9999.

for epoch in range(NUM_EPOCH):
  for phase in ["train", "test"]:
    running_loss = 0.
    running_acc = 0.
    if phase == "train":
      mnist_resnet18.train() # 네트워크 모델을 train 모드로 두어 gradient을 계산하고, 여러 sub module (배치 정규화, 드롭아웃 등)이 train mode로 작동할 수 있도록 함
    elif phase == "test":
      mnist_resnet18.eval() # 네트워크 모델을 eval 모드 두어 여러 sub module들이 eval mode로 작동할 수 있게 함

    for ind, (images, labels) in enumerate(tqdm(dataloaders[phase])):
      # (참고.해보기) 현재 tqdm으로 출력되는 것이 단순히 진행 상황 뿐인데 현재 epoch, running_loss와 running_acc을 출력하려면 어떻게 할 수 있는지 tqdm 문서를 보고 해봅시다!
      # hint - with, pbar
      images = images.to(device)
      labels = labels.to(device)

      optimizer.zero_grad() # parameter gradient를 업데이트 전 초기화함

      with torch.set_grad_enabled(phase == "train"): # train 모드일 시에는 gradient를 계산하고, 아닐 때는 gradient를 계산하지 않아 연산량 최소화
        logits = mnist_resnet18(images)
        _, preds = torch.max(logits, 1) # 모델에서 linear 값으로 나오는 예측 값 ([0.9,1.2, 3.2,0.1,-0.1,...])을 최대 output index를 찾아 예측 레이블([2])로 변경함  
        loss = loss_fn(logits, labels)

        if phase == "train":
          loss.backward() # 모델의 예측 값과 실제 값의 CrossEntropy 차이를 통해 gradient 계산
          optimizer.step() # 계산된 gradient를 가지고 모델 업데이트

      running_loss += loss.item() * images.size(0) # 한 Batch에서의 loss 값 저장
      running_acc += torch.sum(preds == labels.data) # 한 Batch에서의 Accuracy 값 저장

    # 한 epoch이 모두 종료되었을 때,
    epoch_loss = running_loss / len(dataloaders[phase].dataset)
    epoch_acc = running_acc / len(dataloaders[phase].dataset)

    print(f"현재 epoch-{epoch}{phase}-데이터 셋에서 평균 Loss : {epoch_loss:.3f}, 평균 Accuracy : {epoch_acc:.3f}")
    if phase == "test" and best_test_accuracy < epoch_acc: # phase가 test일 때, best accuracy 계산
      best_test_accuracy = epoch_acc
    if phase == "test" and best_test_loss > epoch_loss: # phase가 test일 때, best loss 계산
      best_test_loss = epoch_loss
print("학습 종료!")
print(f"최고 accuracy : {best_test_accuracy}, 최고 낮은 loss : {best_test_loss}")
profile
잘하자
post-custom-banner

0개의 댓글