https://stackoverflow.com/questions/63785319/pytorch-torch-no-grad-versus-requires-grad-false
with torch.no_grad()
와 requires_grad = False
의 차이(?)
기본적인 code부터 차근차근 나아가 보자.
transfer learning == freeze some layers 라고 생각해도 괜찮을 거 같다.
기본적으로 전이 학습이란 성능이 좋다고 알려진 즉, 잘 학습되어있는 parameter를 가지고 오는 것이다. 이후 특정 몇몇의 layer만 학습시켜(fine tuning이라고 보면 된다.) 내가 원하는 dataset에 맞게 바꿔주면 된다. 이게 왜 좋은지는 다음과 같이 설명 할 수 있을거 같다.
일단 time cost가 급격하게 감소한다. 상당한 강점이다. 처음에 이 개념을 알았을 때는 다른 dataset 또는 심지어는 classification 에서 trained parameter를 segmenatation에서도 이용하기도 하는데 이게 왜 잘되는지 오히려 방해되는거 아닌가 생각했었는데, CNN이 하는 역할 즉, low - high pixel 인식 이므로 low parameter만 잘 가지고 온다면 feature 추출을 잘한다고 볼 수 있다. 따라서 합리적이다.
auto grad
에 관해서는 이미 다룸.
import torch import torch.nn as nn train_input.shape ---------------------------------------- torch.Size([64, 3, 224, 224])
class test_model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(in_channels = 3, out_channels = 30, kernel_size = 7, stride=2, padding=10, bias = False) self.conv2 = nn.Conv2d(in_channels = 30, out_channels = 30, kernel_size = 3, stride=1, padding=10, bias = False) def forward(self, inputs): feature_map1 = self.conv1(inputs) feature_map2 = self.conv2(feature_map1) return feature_map2
model shape 확인
import pytorch_model_summary from torchinfo import summary model = test_model().to(device) x = train_input print(pytorch_model_summary.summary(model, x, show_input=True)) print('!@#'*40) summary(model, input_size = x.shape )
parameter 확인하기 & 추출
parameters={} for n,i in enumerate(model.parameters()): parameters[n] = i print(n) parameters.keys()
또는 아래가 더 편할듯
parameters={} for n,i in model.named_parameters(): parameters[n] = i parameters.keys()
차이점은 아래 것은 model class에서 정의한 이름이 추출됨. 편함.
제일 편한방법
model.state_dict()
다만 차이점이 있다.
state_dict()
는 requires_grad = False
이다.
크게 두 단계로 나누어 보면 된다.
첫번째로 initial 함수를 만들어 준다.
def initialize_weight(m): if isinstance(m, nn.Conv2D): nn.init.kaiming_uniform_(m.weight.data,nonlinearity='relu')
다음으로는 이 함수를 model에 적용하면 된다.
model.apply(initialize_weight)
이때
apply
는 model에 있는 layer들을 모두 도는 것을 알 수 있다.
즉, for문 이라고 보며된다.
len(model.state_dict())
전체 parameter의 수
model.state_dict().keys()
찾기 쉽게
model.conv1.weight.shape
위의 key를 통해 확인
parameter_test = torch.nn.parameter.Parameter(data = torch.zeros(32,1,3,3), requires_grad = True) model.conv1.weight = parameter_test
parameter로 지정해준다음에 model에 parameter에 접근해서 직접 변경
model.state_dict()
변경됨을 확인
더 추가해보자
실제 모델에서 초기화 하는 방법.
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
class test(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = conv_bn_relu(in_channel = 3, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
self.conv2 = conv_bn_relu(in_channel = 32, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))
self.conv3 = conv_bn_relu(in_channel = 32, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
self.pool3 = nn.MaxPool2d(kernel_size=(2, 2))
self.conv4 = conv_bn_relu(in_channel = 32, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
self.pool4 = nn.MaxPool2d(kernel_size=(2, 2))
self.conv5 = conv_bn_relu(in_channel = 32, out_channel = 32, kernel = (3,3), s = (1,1), p = (1,1))
self.pool5 = nn.MaxPool2d(kernel_size=(2, 2))
self.flatten = Flatten()
self.fc1 = nn.Linear(32*3*3, 10)
self.fc2 = nn.Linear(10, 10)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.pool3(x)
x = self.conv4(x)
x = self.pool4(x)
x = self.conv5(x)
x = self.pool5(x)
x = self.flatten(x)
return x
위와같은 model을 만든다.
model = test() model.state_dict()
parameter를 확인한다.
보면 원래 default로 초기화가 잘 되어 있다.
이에 다음 초기화 함수를
apply
할꺼다.def _initialize_weights(m): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0)
잘되는지 test 해보기 위해 conv의 bias를 제외한 부분만 0으로 만들어 보자.
def _initialize_weights_test(m): if isinstance(m, nn.Conv2d): torch.nn.init.uniform_(m.weight, a=0.0, b=0.0) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) nn.init.constant_(m.bias, 0) model.apply(_initialize_weights_test) model.state_dict()
정확히 conv의 weight부분(bias제외)만 0 으로 바뀌었다.
즉, 위의 함수를 사용하여 model에 apply 하면 된다.
grad 여부 확인
for n,i in enumerate(model.parameters()): print(i.requires_grad)
이때 다음과 같이 grad를 끌 수 있다.
for n,i in enumerate(model.parameters()): if n == 0: i.requires_grad = False for i,j in model.named_parameters(): print(i,j.requires_grad)