전이학습 (Transfer Learning)

pppanghyun·2022년 7월 31일
0

Pytorch 기본

목록 보기
15/21

전이학습(Transfer Learning)은 가존의 잘 알려진 데이터 혹은 사전학습(pretrained)된 모델을 도메인 확장을 위해 사용하는 학습을 의미함 (ex 이미지넷으로 학습한 모델을 실생활에 사용하는 경우..)

일반화 + 실생활 적용을 중요하게 생각하는 인공지능 분야에서는 매우 중요한 연구 분야.

TASKE: 'MNIST' 데이터로 학습된 ResNet 모델로 CIFAR 데이터 분류하기!

1. 라이브러리 불러오기

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

2. CIFAR 10 데이터 불러오기 및 전처리 (trasform)

10개의 class ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True) 

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16,shuffle=False)

3. Pretrained ResNet18 모델 불러오기

# pretrained=True를 하면 모델 구조와 + 사전 학습 된 파라메타를 모두 불러옴
# pretrained=False를 하면 구조만
# 모델과 텐서에 .to(device)를 붙여야만 GPU 연산이 가능!!!

model = torchvision.models.resnet18(pretrained=True)
print(model)

#result (중간 생략, 마지막 부분)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

모델 구조를 보면 linear 부분의 out feature가 1000인 것을 알 수 있음.
이는 resnet18이 imagenet 데이터로 학습되었기 때문 !!!
따라서, cifar10 데이터의 분류를 위해선 출력층을 수정해야하ㅁ

4. 출력층 수정

num_ftrs = model.fc.in_features # fc의 입력 노드 수를 산출 (512개)
model.fc = nn.Linear(num_ftrs, 10) # fc를 nn.Linear(num_ftrs, 10)로 대체
model = model.to(device)

print(model)

#result (중간 생략, 마지막 부분)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=10, bias=True)
)
# 10개로 잘 바뀐것을 알 수 있음

5. 손실함수 및 최적화 방법 및 모델 학습 (pretrained model !)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-2)

for epoch in range(20):

    running_loss = 0.0
    for data in trainloader:
        
        inputs, labels = data[0].to(device), data[1].to(device)
          
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    cost = running_loss / len(trainloader)        
    print('[%d] loss: %.3f' %(epoch + 1, cost))  

torch.save(model.state_dict(), './models/cifar10_resnet18.pth')      

# 학습된 모델 상태(stae_dict()) .pth 에 저장 ! 
print('Finished Training')

6. CIFAR10 데이터로 학습한 ResNet18 모델 불러오기!

model = torchvision.models.resnet18(pretrained=False) # 구조만 불러오고
num_ftrs = model.fc.in_features # fc의 입력 노드 수를 산출 (512개)
model.fc = nn.Linear(num_ftrs, 10) # fc를 nn.Linear(num_ftrs, 10)로 대체
model = model.to(device)
model.load_state_dict(torch.load('./models/cifar10_resnet18.pth'))

# 중요
<All keys matched successfully> # 가 나와야 구조 + 파라미터 잘 들어온 것 

7. 모델 평가

correct = 0
total = 0

with torch.no_grad():
    model.eval()
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# result
Accuracy of the network on the 10000 test images: 81 %

결론: MNIST 데이터로 pretrained 된 resnet18 모델을 사용해서 cifar 10 데이터의 분류 성능은 대략 81% 정도가 나옴. 전이학습을 위해선 모델 구조를 손봐야되는 번거로움이 존재하지만 '어느 정도'는 괜찮은 성능이 나온다는 것을 확인할 수 있음!

profile
pppanghyun

0개의 댓글