check, initialize and freeze parameters

TEMP·2021년 10월 29일
0

Pytorch

목록 보기
5/11

https://stackoverflow.com/questions/63785319/pytorch-torch-no-grad-versus-requires-grad-false
with torch.no_grad()requires_grad = False의 차이(?)

기본적인 code부터 차근차근 나아가 보자.
transfer learning == freeze some layers 라고 생각해도 괜찮을 거 같다.
기본적으로 전이 학습이란 성능이 좋다고 알려진 즉, 잘 학습되어있는 parameter를 가지고 오는 것이다. 이후 특정 몇몇의 layer만 학습시켜(fine tuning이라고 보면 된다.) 내가 원하는 dataset에 맞게 바꿔주면 된다. 이게 왜 좋은지는 다음과 같이 설명 할 수 있을거 같다.
일단 time cost가 급격하게 감소한다. 상당한 강점이다. 처음에 이 개념을 알았을 때는 다른 dataset 또는 심지어는 classification 에서 trained parameter를 segmenatation에서도 이용하기도 하는데 이게 왜 잘되는지 오히려 방해되는거 아닌가 생각했었는데, CNN이 하는 역할 즉, low - high pixel 인식 이므로 low parameter만 잘 가지고 온다면 feature 추출을 잘한다고 볼 수 있다. 따라서 합리적이다.

auto grad에 관해서는 이미 다룸.

1. parameter 확인

import torch
import torch.nn as nn
train_input.shape
----------------------------------------
torch.Size([64, 3, 224, 224])
class test_model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 30,
        kernel_size = 7, stride=2, padding=10, bias = False)
        self.conv2 = nn.Conv2d(in_channels = 30, out_channels = 30, 
        kernel_size = 3, stride=1, padding=10, bias = False)        
    def forward(self, inputs):
        feature_map1 = self.conv1(inputs)
        feature_map2 = self.conv2(feature_map1)        
        return feature_map2

model shape 확인

import pytorch_model_summary
from torchinfo import summary
model = test_model().to(device)
x = train_input
print(pytorch_model_summary.summary(model, x, show_input=True))
print('!@#'*40)
summary(model, input_size = x.shape )

parameter 확인하기 & 추출

parameters={}
for n,i in enumerate(model.parameters()):
    parameters[n] = i
    print(n)
parameters.keys()

또는 아래가 더 편할듯

parameters={}
for n,i in model.named_parameters():
    parameters[n] = i
parameters.keys()

차이점은 아래 것은 model class에서 정의한 이름이 추출됨. 편함.

제일 편한방법

model.state_dict()

다만 차이점이 있다.
state_dict()requires_grad = False이다.

2. parameter initialize

크게 두 단계로 나누어 보면 된다.
첫번째로 initial 함수를 만들어 준다.

def initialize_weight(m):
    if isinstance(m, nn.Conv2D):
        nn.init.kaiming_uniform_(m.weight.data,nonlinearity='relu')

다음으로는 이 함수를 model에 적용하면 된다.

model.apply(initialize_weight)

이때 apply는 model에 있는 layer들을 모두 도는 것을 알 수 있다.
즉, for문 이라고 보며된다.

  • 추가
len(model.state_dict())

전체 parameter의 수

model.state_dict().keys()

찾기 쉽게

model.conv1.weight.shape

위의 key를 통해 확인

parameter_test = torch.nn.parameter.Parameter(data = torch.zeros(32,1,3,3), requires_grad = True)
model.conv1.weight = parameter_test

parameter로 지정해준다음에 model에 parameter에 접근해서 직접 변경

model.state_dict()

변경됨을 확인

더 추가해보자
실제 모델에서 초기화 하는 방법.

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class test(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = conv_bn_relu(in_channel = 3, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        
        self.conv2 = conv_bn_relu(in_channel = 32, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2)) 
        
        self.conv3 = conv_bn_relu(in_channel = 32, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv4 = conv_bn_relu(in_channel = 32, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 2))
        
        self.conv5 = conv_bn_relu(in_channel = 32, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
        self.pool5 = nn.MaxPool2d(kernel_size=(2, 2))
        
        self.flatten = Flatten()
        
        self.fc1 = nn.Linear(32*3*3, 10)
        self.fc2 = nn.Linear(10, 10)
        
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        
        x = self.conv2(x)
        x = self.pool2(x)
        
        x = self.conv3(x)
        x = self.pool3(x)
        
        x = self.conv4(x)
        x = self.pool4(x)
        
        x = self.conv5(x)
        x = self.pool5(x)   
        
        x = self.flatten(x)

        return x

위와같은 model을 만든다.

model = test()
model.state_dict()

parameter를 확인한다.
보면 원래 default로 초기화가 잘 되어 있다.

이에 다음 초기화 함수를 apply할꺼다.

def _initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)

잘되는지 test 해보기 위해 conv의 bias를 제외한 부분만 0으로 만들어 보자.

def _initialize_weights_test(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.uniform_(m.weight, a=0.0, b=0.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)
model.apply(_initialize_weights_test)
model.state_dict()

정확히 conv의 weight부분(bias제외)만 0 으로 바뀌었다.
즉, 위의 함수를 사용하여 model에 apply 하면 된다.

3.freeze

grad 여부 확인

for n,i in enumerate(model.parameters()):
    print(i.requires_grad)

이때 다음과 같이 grad를 끌 수 있다.

for n,i in enumerate(model.parameters()):
    if n == 0:
        i.requires_grad = False
for i,j in model.named_parameters():
    print(i,j.requires_grad)

0개의 댓글

관련 채용 정보