논문 레시피에 따르면 enas의 operation은 총 6개로
커널 크기가 3, 5인 conv, seperable conv 4개
커널 크기가 3인 max, average pool 2개 해서 총 6개이다.
따라서 구현할 브랜치는 총 6개의 경우의 수를 선택할 수 있게 한다.
구현은 크게 2가지로 층을 고정시키고 특정 층만 바꾸는 fixed_layer와 모든 층을 새롭게 만드는 enas_layer가 있다.
def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters):
inputs = prev_layers[-1]
if self.whole_channels:
if self.data_format == "NHWC":
inp_h = inputs.get_shape()[1].value
inp_w = inputs.get_shape()[2].value
inp_c = inputs.get_shape()[3].value
elif self.data_format == "NCHW":
inp_c = inputs.get_shape()[1].value
inp_h = inputs.get_shape()[2].value
inp_w = inputs.get_shape()[3].value
count = self.sample_arc[start_idx]
branches = {}
if count == 0:
y = self._conv_branch(inputs, 3, out_filters=out_filters, seperable=False)
branches[0] = y
elif count == 1:
y = self._conv_branch(inputs, 3, out_filters=out_filters, seperable=True)
branches[1] = y
elif count == 2:
y = self._conv_branch(inputs, 5, out_filters=out_filters, seperable=False)
branches[2] = y
elif count == 3:
y = self._conv_branch(inputs, 5, out_filters=out_filters, seperable=True)
branches[3] = y
elif count == 4:
y = self._pool_branch(inputs, mode="avg")
branches[4] = y
elif count == 5:
y = self._pool_branch(inputs, mode="max")
branches[5] = y
out = branches.get(count, torch.zeros_like(inputs))
else:
count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches]
branches = {}
branches.append(self._conv_branch(inputs, 3, out_filters=out_filters, seperable=False))
branches.append(self._conv_branch(inputs, 3, out_filters=out_filters, seperable=True))
branches.append(self._conv_branch(inputs, 5, out_filters=out_filters, seperable=False))
branches.append(self._conv_branch(inputs, 5, out_filters=out_filters, seperable=True))
if self.num_branches >= 5:
branches.append(self._pool_branch(inputs, mode="avg"))
if self.num_branches >= 6:
branches.append(self._pool_branch(inputs,mode="max"))
branches = torch.cat(branches, dim=1)
w = create_weight([self.num_branches * out_filters, out_filters])
w_mask = torch.zeros(self.num_branches * out_filters, dtype=torch.bool)
new_range = torch.arange(self.num_branches * self.out_filters)
for i in range(self.num_branches):
start = out_filters * i + count[2 * i]
w_mask = torch.logical_or(w_mask, (new_range >= start) & (new_range < start + count[2 * i + 1]))
w = w[w_mask].view(-1, out_filters)
# Apply convolution
out = F.conv2d(branches, w.unsqueeze(2).unsqueeze(3))
# Skip connections
if layer_id > 0:
skip_start = start_idx + (1 if self.whole_channels else 2 * self.num_branches)
skip = self.sample_arc[skip_start: skip_start + layer_id]
res_layers = [torch.zeros_like(prev_layers[i]) if skip[i] == 0 else prev_layers[i] for i in range(layer_id)]
res_layers.append(out)
out = torch.stack(res_layers, dim=0).sum(dim=0)
# Batch norm and relu activation
out = self.batch_norm_layer(out)
out = F.relu(out)
return out
전체 채널을 모두 만들것인지 아닌지에 따라서 whole_channel로 선택할 수 있게 분기를 만들자. 마지막엔 skip connection을 해준다. 이때 skip connection은 어떤 레이어를 이을지 선택하는 것으로 resnet의 효과를 얻게 된다.
def _fixed_layer(self,
layer_id,
prev_layers,
start_idx,
out_filters):
inputs = prev_layers[-1]
if self.whole_channels:
if self.data_format == "NHWC":
c = inputs.get_shape()[3].value
elif self.data_format == "NCHW":
c = inputs.get_shape()[1].value
count = self.sample_arc[start_idx]
if count in [0, 1, 2, 3]:
size = [3, 3, 5, 5]
filter_size = size[count]
out = F.relu(PointwiseConv(c, out_filters)(inputs))
out = self.batch_norm_layer(out, data_format=self.data_format)
conv_filter = nn.Conv2d(out_filters, out_filters, kernel_size=filter_size,padding=filter_size//2)
out = F.relu(conv_filter(out))
out = self.batch_norm_layer(out, data_format=self.data_format)
elif count == 4:
out = self._pool_branch(inputs, out_filters, mode="avg")
elif count == 5:
out = self._pool_branch(inputs, out_filters, mode="max")
else:
raise ValueError(f"Invalid layer id {layer_id}")
else:
count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches] * self.filters_scale
branches = []
total_out_channels = 0
total_out_channels += count[1]
branches.append(self._conv_branch(inputs, 3, count[1]))
total_out_channels += count[3]
branches.append(self._conv_branch(inputs, 3, count[3], seperable=True))
total_out_channels += count[5]
branches.append(self._conv_branch(inputs, 5, count[5]))
total_out_channels += count[7]
branches.append(self._conv_branch(inputs, 5, count[7], seperable=True))
if self.num_branches >= 5:
total_out_channels += count[9]
branches.append(self._pool_branch(inputs, count[9], mode="avg"))
if self.num_branches >= 6:
total_out_channels += count[11]
branches.append(self._pool_branch(inputs, count[11], mode="max"))
final_conv = nn.Conv2d(total_out_channels, out_filters, kernel_size=1)
out = F.relu(final_conv(torch.cat(branches, dim=1)))
out = self.batch_norm_layer(out_filters, data_format=self.data_format)
if layer_id > 0:
skip_start = start_idx + (1 if self.whole_channels else 2 * self.num_branches)
skip = self.sample_arc[skip_start: skip_start + layer_id]
total_skip_channels = sum(skip)
res_layers = [prev_layers[i] for i in range(layer_id) if skip[i] == 1]
prev = torch.cat(res_layers + [out], dim=0)
skip_conv = nn.Conv2d(total_skip_channels * out_filters, out_filters, kernel_size=1)
out = F.relu(skip_conv(prev))
out = self.batch_norm_layer(out_filters)(out)
return out
이미 고정된 레이어를 만드는 경우로 별 다른건 없다.
고정된 아키텍처에 따라서 아키텍처를 분기문으로 읽고 레이어를 만드는 매서드이다.
역시나 마지막은 skip을 하게 해줘서 resnet의 효과를 준다.
macro 탐색 방법은 이것으로 마치겠다. marco 방법을 통해 enas를 이해하는 정도로 구현했고 실제로 우리가 이걸 쓰진 않을것이다. 진짜 쓸 것은 micro 탐색 방법으로 이걸 mnas, tunas와 다양한 기법을 연결하여 발전시킬 것이다.