전이학습: 모델 동결 (Model Freezing)

pppanghyun·2022년 7월 31일
1

Pytorch 기본

목록 보기
16/21

전이학습(Transfer Learning)은 이미 학습된 모델을 가져와 사용하는 방법. 데이터가 유사한 경우에는 새로운 학습 없이도 좋은 성능이 나옴.
이런 경우 데이터의 특징을 추출(피처 추출)하는 부분의 변수는 동결하고(freeze), 분류 파트에 해당되는 fully connected layer의 변수만 업데이트 할 수 있음. 이런 방법을 모델 동결(model freezing)이라고 함.

*이전 방법에서는 fully connected 부분도 모두 동일하게 사용

1. 라이브러리 불러오기

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

2. CIFAR10 데이터 불러오기

transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True) 

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16,shuffle=False)

3. Pretrained 모델 불러오기

model = torchvision.models.alexnet(pretrained=True)

4. 모델 fully connected layer 구조 바꿔주기 (1000 to 10)

num_ftrs = model.classifier[6].in_features # fc의 입력 노드 수를 산출 
model.features[0] = nn.Conv2d(3, 64, kernel_size=5, stride=1)
model.classifier[6] = nn.Linear(num_ftrs, 10) # fc를 nn.Linear(num_ftrs, 10)로 대체
model = model.to(device) # 출력층 확인하기

5. 모델 프리징1 (출력층 파라미터 확인하기)

i = 0
for name, param in model.named_parameters():  # named_parameters 가중치 뽑아주는 함수
    
    print(i,name)
    i+= 1

# result
0 features.0.weight
1 features.0.bias
2 features.3.weight
3 features.3.bias
4 features.6.weight
5 features.6.bias
6 features.8.weight
7 features.8.bias
8 features.10.weight
9 features.10.bias
10 classifier.1.weight
11 classifier.1.bias
12 classifier.4.weight
13 classifier.4.bias
14 classifier.6.weight
15 classifier.6.bias

합성곱 층은 0~9, fully connected layer는 10~15임을 알 수 있음.

6. 모델 프리징2

# 합성곱 층은 0~9까지이다. 
# 따라서 9번째 변수까지 역추적을 비활성화 = freeze (학습안함) 한 후 for문을 종료한다.

for i, (name, param) in enumerate(model.named_parameters()):
    
    param.requires_grad = False
    if i == 9:
        print('end')
        break  
        
# requires_grad 확인 (classifer만 True로 나옴)
print(model.features[0].weight.requires_grad)
print(model.features[0].bias.requires_grad)
print(model.features[3].weight.requires_grad)
print(model.features[3].bias.requires_grad)
print(model.features[6].weight.requires_grad)
print(model.features[6].bias.requires_grad)
print(model.features[8].weight.requires_grad)
print(model.features[8].bias.requires_grad)
print(model.features[10].weight.requires_grad)
print(model.features[10].bias.requires_grad)
print(model.classifier[1].weight.requires_grad)
print(model.classifier[1].bias.requires_grad)
print(model.classifier[4].weight.requires_grad)
print(model.classifier[4].bias.requires_grad)
print(model.classifier[6].weight.requires_grad)
print(model.classifier[6].bias.requires_grad)

(이렇게 되면 feature 추출하는 부분은 grad=False 이기 때문에 학습이 안되고(freeze), classifier 부분만 새로운 데이터셋에 맞게 학습이 된다 grad=True)

7. 손실함수 및 최적화 정의 및 학습

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-2)
for epoch in range(20):

    running_loss = 0.0
    for data in trainloader:
        
        inputs, labels = data[0].to(device), data[1].to(device)
          
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    cost = running_loss / len(trainloader)        
    print('[%d] loss: %.3f' %(epoch + 1, cost))  
   

print('Finished Training')

8. 모델 평가

correct = 0
total = 0
with torch.no_grad():
    model.eval()
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

#result
Accuracy of the network on the 10000 test images: 39 %

결론: 전이학습에서 fully connected layer 부분만 냅두고 (학습시키고) 나머지 부분은 학습시키지 않는 (freeze) 방법을 model freezing 이라고 함. 전에 비해서 성능이 낮게 나왔는데, 데이터 마다 다른지 아니면 일반적으로 이런지 논문 찾아봐야겠다...

profile
pppanghyun

1개의 댓글

comment-user-thumbnail
2023년 1월 10일

안녕하세요 전이학습 공부하다가 들리게 됐습니다!
저도 작성자님과 마찬가지로 CNN층을 Freezing 시키는 것보다 파라미터 업데이트를 하는게 오히려
F.C layer만 학습 시키는 것보다 학습 결과가 좋게 나오더라구요..
혹시 어떤 이유로 학습이 좀 더 잘된 것 같다고 생각이 드시는지 의견 여쭤봐도 괜찮을까요?

답글 달기