논문 레시피에 따르면 초기화를 kaiming으로 해야한다.
파이토치에서 좋으면서 안 좋은점이 그냥 nn.Conv2d로 레이어를 만들어 주면 자동으로 가중치를 초기화 한다. 그러면 kaiming 초기화가 안되서 초기화를 커스텀 해주기 위해서 가중치를 직접 초기화하고 F.conv2d를 이용해야 한다.
import torch.nn as nn
from nas.common_ops import create_weight
import torch.nn.functional as F
class PointwiseConv(nn.Module):
def __init__(self, in_channels, out_channels, initailizer=None):
super(PointwiseConv, self).__init__()
self.weight = create_weight([out_channels, in_channels, 1, 1], initializer=initailizer)
def forward(self, x):
return F.conv2d(x, self.weight)
import torch.nn as nn
import torch.nn.functional as F
from nas.common_ops import create_weight
class DepthwiseConv(nn.Module):
def __init__(self, in_channels, initializer=None):
super(DepthwiseConv, self).__init__()
self.weight = create_weight([in_channels, 1, 3, 3], initializer=initializer)
def forward(self, x):
return F.conv2d(x, self.weight, groups=self.in_channels, padding=1)
원래 pointwise와 depthwise 레이어를 만들때 자동 초기화를 했었는데 커스텀 가중치 초기화를 위해서 위와 같이 바꿔주자.
def _conv_branch(self,
inputs,
filter_size,
count,
out_filters,
ch_mul=1,
start_idx=None,
seperable=False
):
if start_idx is None:
assert self.fixed_arc is not None, "you need start_idx or fixed_arc"
if self.data_format == "NHWC":
c = inputs.get_shape()[3].value
elif self.data_format == "NCHW":
c = inputs.get_shape()[1].value
x = PointwiseConv(c, out_filters)(inputs)
x = batch_norm(x, data_format=self.data_format)
x = F.relu(x)
if start_idx is None:
if seperable:
depthwise_conv = DepthwiseConv(in_channels=out_filters)
x = depthwise_conv(x)
pointwise_conv = PointwiseConv(in_channels=out_filters * ch_mul, out_channels=count)
x = pointwise_conv(x)
else:
x = nn.Conv2d(c, count, kernel_size=filter_size, padding=filter_size // 2)
x = batch_norm(x, data_format=self.data_format)
else:
if seperable:
depthwise_conv = DepthwiseConv(in_channels=out_filters)
x = depthwise_conv(x)
w_pointwise = create_weight([out_filters * ch_mul, out_filters])
w_pointwise = w_pointwise[start_idx:start_idx + count, :]
w_pointwise = w_pointwise(0, 1)
w_pointwise = w_pointwise.view(1, 1, out_filters * ch_mul, count)
x = F.conv2d(x, w_pointwise, stride=1, padding=filter_size // 2)
else:
w = create_weight([filter_size, filter_size, out_filters, out_filters])
w = w.transpose(0, 3)
w = w[start_idx:start_idx + count, :, :, :]
x = F.conv2d(x, w, stride=1, padding=filter_size // 2)
return x
macro_child.py에 브랜치를 만들어주는 _conv_branch 매서드를 만들어주자.
이 매서드의 특징은 크게 2가지이다.
우선 위에서 바꾼 depthwise, pointwise를 이용해서 가중치 초기화 없이 만들어 준다.
if문을 이용해서 마스킹 없이 레이어 만들기, 마스킹 써서 레이어 만들기 두 가지로 나뉜 뒤
다시 if 문을 써서 seperable, 그냥 conv 둘 중 하나를 선택 해준다.
주목할 부분은 마스킹에서 커스텀 weight를 사용해야한다는 점이다.
마스킹이란건 아래와 같은 목적이 있다.
즉 필터 중에서 특정 필터만을 선택해서 가중치를 초기화 해야하므로 당연히 내가 직접 커스텀 가중치를 만들고 필터 범위를 선택해줘야한다.
따라서 범위를 위 코드와 같이 count를 이용해 설정하고 F.conv2d를 이용해 직접 레이어를 만들어준다.