음악 분류 딥러닝을 만들자(34) - factorized reduction 구현 및 아키텍처 경우의 수 계산

응큼한포도·2024년 10월 2일
0
post-thumbnail

factorized reduction 압축 설명

앞서 설명했지만 압축해서 설명하자면

경로 1: 평균 풀링(average pooling)을 적용한 후, 1x1 필터 크기의 컨볼루션(conv2d)을 사용하여 필터 개수를 조정

경로 2: 입력 이미지를 패딩한 후, 우측 하단으로 시프트하여 동일하게 평균 풀링을 적용하고, 1x1 컨볼루션을 사용

두 경로를 합쳐 하나의 출력으로 만들어, 깊이 감소(stride 2 등)를 수행하면서도 채널 정보 손실을 최소화

이 개념은 resnet과 비슷한 목적을 가졌다. 깊은 네트워크에 학습이 잘 되게 하나의 이미지를 2가지 연산으로 나눠서 이미지의 특징을 각각 맡아서 잘 살려 합쳐서 전달하자. 결국엔 좀 더 많은 특징을 잘 전달할 수 있다는 소리

import torch
import torch.nn as nn
import torch.nn.functional as F
from common_ops import create_weight
from common_ops import create_bias
from image_ops import batch_norm


class MacroChild():
    def __init__(self,
                 images,
                 labels,
                 whole_channels,
                 data_format="NHWC",
                 fixed_arc=None,
                 filters_scale=1,
                 num_layres=2,
                 num_branches=6,
                 filters=24,
                 keep_prob=1.0,
                 batch_size=32,
                 clip_mode=None,
                 grad_bound=None,
                 l2_reg=1e-4,
                 lr_init=0.1,
                 lr_dec_start=0,
                 lr_dec_every=10000,
                 lr_dec_rate=0.1,
                 lr_cosine=False,
                 lr_max=None,
                 lr_min=None,
                 lr_T_num=None,
                 optim_algo=None,
                 sync_replicas=False,
                 num_aggregate=None,
                 num_replicas=None,
                 name="child",
                 *args,
                 **kwargs
                 ):
        self.images = images
        self.labels = labels
        self.whole_channels = whole_channels
        self.data_format = data_format
        self.fixed_arc = fixed_arc
        self.filters_scale = filters_scale
        self.num_layres = num_layres
        self.num_branches = num_branches
        self.filters = filters
        self.keep_prob = keep_prob
        self.batch_size = batch_size
        self.clip_mode = clip_mode
        self.grad_bound = grad_bound
        self.l2_reg = l2_reg
        self.lr_init = lr_init
        self.lr_dec_start = lr_dec_start
        self.lr_dec_every = lr_dec_every
        self.lr_dec_rate = lr_dec_rate
        self.lr_cosine = lr_cosine
        self.lr_max = lr_max
        self.lr_min = lr_min
        self.lr_T_num = lr_T_num
        self.optim_algo = optim_algo
        self.sync_replicas = sync_replicas
        self.num_aggregate = num_aggregate
        self.num_replicas = num_replicas
        self.name = name

    def _get_C(self, x):
        if self.data_format == "NHWC":
            return x.get_shape()[3].value
        elif self.data_format == "NCHW":
            return x.get_shape()[1].value
        else:
            raise ValueError("Unknown data_format '{0}'".format(self.data_format))

    def _get_stride(self, stride):
        if self.data_format == "NHWC":
            return [1, stride, stride, 1]
        elif self.data_format == "NCHW":
            return [1, 1, stride, stride]
        else:
            raise ValueError("Unknown data_format '{0}'".format(self.data_format))

    def _factorized_reduction(self, x, filters, stride):
        assert filters % 2 == 0, (
            "filters must be a even number."
        )
        if stride == 1:
            c = self._get_C(x)
            w = create_weight("w", [1, 1, c, filters])
            path1 = F.conv2d(x, w, stride=1)
            path1 = batch_norm(path1, data_format=self.data_format)
            return path1

        # path1: AvgPooling + Conv
        path1 = F.avg_pool2d(x, kernel_size=1, stride=stride)
        w1 = create_weight("w1", [1, 1, self._get_C(path1), filters // 2])
        path1 = F.conv2d(path1, w1, stride=1)

        # path2: padding + shifting + AvgPooling + Conv
        if self.data_format == "NHWC":
            pad_arr = [0, 1, 0, 1]
            x_padded = F.pad(x, pad_arr)
            path2 = x_padded[:, 1:, 1:, :]
        else:
            pad_arr = [0, 0, 0, 1, 0, 1]
            x_padded = F.pad(x, pad_arr)
            path2 = x_padded[:, :, 1:, 1:]

        path2 = F.avg_pool2d(path2, kernel_size=1, stride=stride)
        w2 = create_weight("w2", [1, 1, self._get_C(path2), filters // 2])
        path2 = F.conv2d(path2, w2, stride=1)

        # Apply BatchNorm after concat two path
        final_path = torch.cat([path1, path2], dim=1 if self.data_format == 'NCHW' else 3)
        final_path = batch_norm(final_path, data_format=self.data_format)

        return final_path

코드에서 stride=1은 그냥 conv를 만드는 코드고 바뀔수 있음

path1 채널 절반, path2 채널 절반씩 나눠서 각각 다른 연산을 거쳐서 합친다. stride=2에선 이미지를 보존하기 위해서 패딩을 쳐줌.

profile
미친 취준생

0개의 댓글