*음악 분류 딥러닝을 만들자(36) - conv_branch 매서드와 각종 수정

응큼한포도·2024년 10월 20일
0

초기화를 커스텀하기 위해서 다음과 같은 파일을 수정

논문 레시피에 따르면 초기화를 kaiming으로 해야한다.

파이토치에서 좋으면서 안 좋은점이 그냥 nn.Conv2d로 레이어를 만들어 주면 자동으로 가중치를 초기화 한다. 그러면 kaiming 초기화가 안되서 초기화를 커스텀 해주기 위해서 가중치를 직접 초기화하고 F.conv2d를 이용해야 한다.

pointwise

import torch.nn as nn
from nas.common_ops import create_weight
import torch.nn.functional as F

class PointwiseConv(nn.Module):
    def __init__(self, in_channels, out_channels, initailizer=None):
        super(PointwiseConv, self).__init__()
        self.weight = create_weight([out_channels, in_channels, 1, 1], initializer=initailizer)

    def forward(self, x):
        return F.conv2d(x, self.weight)

depthwise

import torch.nn as nn
import torch.nn.functional as F
from nas.common_ops import create_weight

class DepthwiseConv(nn.Module):
    def __init__(self, in_channels, initializer=None):
        super(DepthwiseConv, self).__init__()
        self.weight = create_weight([in_channels, 1, 3, 3], initializer=initializer)

    def forward(self, x):
        return F.conv2d(x, self.weight, groups=self.in_channels, padding=1)

원래 pointwise와 depthwise 레이어를 만들때 자동 초기화를 했었는데 커스텀 가중치 초기화를 위해서 위와 같이 바꿔주자.

conv_branch

    def _conv_branch(self,
                     inputs,
                     filter_size,
                     count,
                     out_filters,
                     ch_mul=1,
                     start_idx=None,
                     seperable=False
                     ):
        if start_idx is None:
            assert self.fixed_arc is not None, "you need start_idx or fixed_arc"

        if self.data_format == "NHWC":
            c = inputs.get_shape()[3].value
        elif self.data_format == "NCHW":
            c = inputs.get_shape()[1].value

        x = PointwiseConv(c, out_filters)(inputs)
        x = batch_norm(x, data_format=self.data_format)
        x = F.relu(x)

        if start_idx is None:
            if seperable:
                depthwise_conv = DepthwiseConv(in_channels=out_filters)
                x = depthwise_conv(x)
                pointwise_conv = PointwiseConv(in_channels=out_filters * ch_mul, out_channels=count)
                x = pointwise_conv(x)
            else:
                x = nn.Conv2d(c, count, kernel_size=filter_size, padding=filter_size // 2)
                x = batch_norm(x, data_format=self.data_format)

        else:
            if seperable:
                depthwise_conv = DepthwiseConv(in_channels=out_filters)
                x = depthwise_conv(x)

                w_pointwise = create_weight([out_filters * ch_mul, out_filters])
                w_pointwise = w_pointwise[start_idx:start_idx + count, :]
                w_pointwise = w_pointwise(0, 1)
                w_pointwise = w_pointwise.view(1, 1, out_filters * ch_mul, count)
                x = F.conv2d(x, w_pointwise, stride=1, padding=filter_size // 2)
            else:
                w = create_weight([filter_size, filter_size, out_filters, out_filters])
                w = w.transpose(0, 3)
                w = w[start_idx:start_idx + count, :, :, :]
                x = F.conv2d(x, w, stride=1, padding=filter_size // 2)

        return x

macro_child.py에 브랜치를 만들어주는 _conv_branch 매서드를 만들어주자.

이 매서드의 특징은 크게 2가지이다.

  1. seperable를 할지 그냥 conv를 할 지 결정한다.
  2. 마스킹을 이용해서 특정 채널만 연산을 해준다.

우선 위에서 바꾼 depthwise, pointwise를 이용해서 가중치 초기화 없이 만들어 준다.

if문을 이용해서 마스킹 없이 레이어 만들기, 마스킹 써서 레이어 만들기 두 가지로 나뉜 뒤
다시 if 문을 써서 seperable, 그냥 conv 둘 중 하나를 선택 해준다.

주목할 부분은 마스킹에서 커스텀 weight를 사용해야한다는 점이다.

마스킹이란건 아래와 같은 목적이 있다.

즉 필터 중에서 특정 필터만을 선택해서 가중치를 초기화 해야하므로 당연히 내가 직접 커스텀 가중치를 만들고 필터 범위를 선택해줘야한다.

따라서 범위를 위 코드와 같이 count를 이용해 설정하고 F.conv2d를 이용해 직접 레이어를 만들어준다.

profile
미친 취준생

0개의 댓글