음악 분류 딥러닝을 만들자(25) - search space를 객체지향적으로 refactor

응큼한포도·2024년 8월 30일
0

search space 리팩터링

저번 시간에 구성한 mobileNetv3 search space를 확장성을 위해서 리팩터링하자.

from abc import ABC, abstractmethod

class SearchSpace:
    def __init__(self):
        pass

    @abstractmethod
    def get_search_space_ops(self) -> dict:
        """
        Returns search space ops as a dictionary of lists.
        """
        pass

많은 논문들과 구현체들을 살펴본 결과 ops의 형태로 search space를 관리한다.

프로젝트의 구현에서 다양한 nas를 구현하기 때문에 search space의 원소는 언제든지 바뀔 수 있으니 abc 클래스를 이용해 인터페이스 형태로 구현해주자.

@abstractmethod를 이용해서 꼭 구현해야 되는 매서드를 만들어주고 ops를 dictionary list로 반환하게 강제해주자.

이제 이 클래스를 이용해서

class MobilenetV3SearchSpace(SearchSpace):
    def __init__(self):
        super().__init__()
        self.conv_ops = ['conv', 'dconv', 'mbconv', 'pconv']  # Convolution operations
        self.kernel_sizes = [3, 5]  # Kernel sizes
        self.se_ratios = [0, 0.25]  # Squeeze-and-Excitation ratios
        self.skip_ops = ['none', 'identity', 'pool']  # Skip operations
        self.filter_sizes = [0.75, 1.0, 1.25]  # Filter sizes
        self.num_layers = [-1, 0, 1]  # Number of layers per block

    def get_search_space_ops(self) -> dict:
        """
        Returns the complete search space as a dictionary of lists.
        """
        return {
            'ConvOp': self.conv_ops,
            'KernelSize': self.kernel_sizes,
            'SERatio': self.se_ratios,
            'SkipOp': self.skip_ops,
            'FilterSize': self.filter_sizes,
            'NumLayers': self.num_layers
        }

이런식의 ops들을 내 맘대로 만들어주고 실제 구현에선

        if conv_op == 'mbconv':
            layers.append(InvertedResidual(in_channels, out_channels, stride, expand_ratio))
        elif conv_op == 'conv':
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size//2))
        elif conv_op == 'dconv':
            layers.append(DepthwiseConv(in_channels))
        elif conv_op == 'pconv':
            layers.append(PointwiseConv(in_channels, out_channels))

이런식으로 사용하면 된다. CNN에선 어느정도 구조가 잡혀 있으니
conv_branch, pool_branch 등을 만들고 ops에서 관리해서 사용하면 된다.

나중에 실제 구현을 해서 보여주겠다.

profile
미친 취준생

0개의 댓글