*음악 분류 딥러닝을 만들자(37) - enas_layer, fixed_layer

응큼한포도·2024년 10월 21일
0

논문 레시피

논문 레시피에 따르면 enas의 operation은 총 6개로
커널 크기가 3, 5인 conv, seperable conv 4개
커널 크기가 3인 max, average pool 2개 해서 총 6개이다.

따라서 구현할 브랜치는 총 6개의 경우의 수를 선택할 수 있게 한다.

구현은 크게 2가지로 층을 고정시키고 특정 층만 바꾸는 fixed_layer와 모든 층을 새롭게 만드는 enas_layer가 있다.

enas_layer

    def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters):

        inputs = prev_layers[-1]
        if self.whole_channels:
            if self.data_format == "NHWC":
                inp_h = inputs.get_shape()[1].value
                inp_w = inputs.get_shape()[2].value
                inp_c = inputs.get_shape()[3].value
            elif self.data_format == "NCHW":
                inp_c = inputs.get_shape()[1].value
                inp_h = inputs.get_shape()[2].value
                inp_w = inputs.get_shape()[3].value

            count = self.sample_arc[start_idx]
            branches = {}

            if count == 0:
                y = self._conv_branch(inputs, 3, out_filters=out_filters, seperable=False)
                branches[0] = y

            elif count == 1:
                y = self._conv_branch(inputs, 3, out_filters=out_filters, seperable=True)
                branches[1] = y

            elif count == 2:
                y = self._conv_branch(inputs, 5, out_filters=out_filters, seperable=False)
                branches[2] = y

            elif count == 3:
                y = self._conv_branch(inputs, 5, out_filters=out_filters, seperable=True)
                branches[3] = y

            elif count == 4:
                y = self._pool_branch(inputs, mode="avg")
                branches[4] = y

            elif count == 5:
                y = self._pool_branch(inputs, mode="max")
                branches[5] = y

            out =  branches.get(count, torch.zeros_like(inputs))

        else:
            count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches]
            branches = {}
            branches.append(self._conv_branch(inputs, 3, out_filters=out_filters, seperable=False))
            branches.append(self._conv_branch(inputs, 3, out_filters=out_filters, seperable=True))
            branches.append(self._conv_branch(inputs, 5, out_filters=out_filters, seperable=False))
            branches.append(self._conv_branch(inputs, 5, out_filters=out_filters, seperable=True))

            if self.num_branches >= 5:
                branches.append(self._pool_branch(inputs, mode="avg"))
            if self.num_branches >= 6:
                branches.append(self._pool_branch(inputs,mode="max"))

            branches = torch.cat(branches, dim=1)

            w = create_weight([self.num_branches * out_filters, out_filters])
            w_mask = torch.zeros(self.num_branches * out_filters, dtype=torch.bool)
            new_range = torch.arange(self.num_branches * self.out_filters)

            for i in range(self.num_branches):
                start = out_filters * i + count[2 * i]
                w_mask = torch.logical_or(w_mask, (new_range >= start) & (new_range < start + count[2 * i + 1]))

            w = w[w_mask].view(-1, out_filters)


            # Apply convolution
            out = F.conv2d(branches, w.unsqueeze(2).unsqueeze(3))

        # Skip connections
        if layer_id > 0:
            skip_start = start_idx + (1 if self.whole_channels else 2 * self.num_branches)
            skip = self.sample_arc[skip_start: skip_start + layer_id]
            res_layers = [torch.zeros_like(prev_layers[i]) if skip[i] == 0 else prev_layers[i] for i in range(layer_id)]
            res_layers.append(out)
            out = torch.stack(res_layers, dim=0).sum(dim=0)

        # Batch norm and relu activation
        out = self.batch_norm_layer(out)
        out = F.relu(out)

        return out

전체 채널을 모두 만들것인지 아닌지에 따라서 whole_channel로 선택할 수 있게 분기를 만들자. 마지막엔 skip connection을 해준다. 이때 skip connection은 어떤 레이어를 이을지 선택하는 것으로 resnet의 효과를 얻게 된다.

fixed layer

    def _fixed_layer(self,
                     layer_id,
                     prev_layers,
                     start_idx,
                     out_filters):

        inputs = prev_layers[-1]
        if self.whole_channels:
            if self.data_format == "NHWC":
                c = inputs.get_shape()[3].value
            elif self.data_format == "NCHW":
                c = inputs.get_shape()[1].value

            count = self.sample_arc[start_idx]
            if count in [0, 1, 2, 3]:
                size = [3, 3, 5, 5]
                filter_size = size[count]

                out = F.relu(PointwiseConv(c, out_filters)(inputs))
                out = self.batch_norm_layer(out, data_format=self.data_format)

                conv_filter = nn.Conv2d(out_filters, out_filters, kernel_size=filter_size,padding=filter_size//2)
                out = F.relu(conv_filter(out))
                out = self.batch_norm_layer(out, data_format=self.data_format)

            elif count == 4:
                out = self._pool_branch(inputs, out_filters, mode="avg")

            elif count == 5:
                out = self._pool_branch(inputs, out_filters, mode="max")

            else:
                raise ValueError(f"Invalid layer id {layer_id}")
        else:
            count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches] * self.filters_scale
            branches = []
            total_out_channels = 0

            total_out_channels += count[1]
            branches.append(self._conv_branch(inputs, 3, count[1]))

            total_out_channels += count[3]
            branches.append(self._conv_branch(inputs, 3, count[3], seperable=True))

            total_out_channels += count[5]
            branches.append(self._conv_branch(inputs, 5, count[5]))

            total_out_channels += count[7]
            branches.append(self._conv_branch(inputs, 5, count[7], seperable=True))

            if self.num_branches >= 5:
                total_out_channels += count[9]
                branches.append(self._pool_branch(inputs, count[9], mode="avg"))

            if self.num_branches >= 6:
                total_out_channels += count[11]
                branches.append(self._pool_branch(inputs, count[11], mode="max"))

            final_conv = nn.Conv2d(total_out_channels, out_filters, kernel_size=1)
            out = F.relu(final_conv(torch.cat(branches, dim=1)))
            out = self.batch_norm_layer(out_filters, data_format=self.data_format)

        if layer_id > 0:
            skip_start = start_idx + (1 if self.whole_channels else 2 * self.num_branches)
            skip = self.sample_arc[skip_start: skip_start + layer_id]
            total_skip_channels = sum(skip)

            res_layers = [prev_layers[i] for i in range(layer_id) if skip[i] == 1]
            prev = torch.cat(res_layers + [out], dim=0)

            skip_conv = nn.Conv2d(total_skip_channels * out_filters, out_filters, kernel_size=1)
            out = F.relu(skip_conv(prev))
            out = self.batch_norm_layer(out_filters)(out)

        return out

이미 고정된 레이어를 만드는 경우로 별 다른건 없다.
고정된 아키텍처에 따라서 아키텍처를 분기문으로 읽고 레이어를 만드는 매서드이다.

역시나 마지막은 skip을 하게 해줘서 resnet의 효과를 준다.

마무리

macro 탐색 방법은 이것으로 마치겠다. marco 방법을 통해 enas를 이해하는 정도로 구현했고 실제로 우리가 이걸 쓰진 않을것이다. 진짜 쓸 것은 micro 탐색 방법으로 이걸 mnas, tunas와 다양한 기법을 연결하여 발전시킬 것이다.

profile
미친 취준생

0개의 댓글