음악 분류 딥러닝을 만들자(12) - depthwise, pointwise 구현

응큼한포도·2024년 8월 1일
0
post-thumbnail
import torch.nn as nn

class DepthwiseConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DepthwiseConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,
                                   groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)

        return x

저번 시간에 depthwise의 개념에 대해 알았고 그걸 구현해보자.
depthwise의 핵심은 입력 채널이 커널의 채널과 일대일 대응한다는 점이다.

직접 만들 수 있지만 torch에 이미 conv에 대한 구현이 있다. 우리가 주목할건

nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,
                                   groups=in_channels)

공식문서를 찾아보면 group에 대한 설명이 아래와 같이 나타난다.

At groups = in_channels, each input channel is convolved with its own set of filters. The size of these filters is:

Filter size=(height,width,out_channelsin_channels)\text{Filter size} = \left(\text{height}, \text{width}, \frac{\text{out\_channels}}{\text{in\_channels}}\right)

Where:

  • height: Height of the filter
  • width: Width of the filter
  • in_channels: Number of input channels
  • out_channels: Number of output channels

filter size를 입력 채널의 수와 같이 하는게 목표인데 여기서 out_channels의 수를 in_channels로 하면 결국 filter_size는 1이 된다. 즉 input 채널에 대응되는 filter의 수는 1로 일대일 대응을 이루게 된다.

pointwise의 경우엔 별게 없다. 그냥 kernel_size를 1로 하는 보통의 conv와 같다.

테스트 코드

import torch
import torch.nn as nn
import torchprofile

class DepthwiseConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DepthwiseConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

def test_flops():
    in_channels = 3
    out_channels = 6
    input_res = (3, 3)  # (height, width)

    # 모델 정의
    model = DepthwiseConv(in_channels, out_channels)
    conv_model = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
    )

    # 입력 텐서 정의 (4D 형식으로 수정)
    input_tensor = torch.randn(1, in_channels, *input_res)  # 배치 크기 1

    # Flops 측정
    flops = torchprofile.profile_macs(model, input_tensor)
    print(f"DepthwiseConv FLOPs: {flops}")

    conv_flops = torchprofile.profile_macs(conv_model, input_tensor)
    print(f"Conv2d FLOPs: {conv_flops}")

    # Depthwise Convolution의 Flops가 더 작은지 확인
    assert flops < conv_flops, "Depthwise Convolution should have fewer FLOPs than standard convolution."

# Flops 테스트 실행
test_flops()

pycharm에서 오류가 나 코랩에 돌렸다.

DepthwiseConv FLOPs: 405
Conv2d FLOPs: 1458

결과는 이 정도로 나왔다.

이 만큼 줄어드는 데 내껀 이 값이 0.28정도인가 실제로 결과값도 비슷하게 나왔다.

profile
미친 취준생

0개의 댓글