import torch.nn as nn
class DepthwiseConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DepthwiseConv, self).__init__()
self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,
groups=in_channels)
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
return x
저번 시간에 depthwise의 개념에 대해 알았고 그걸 구현해보자.
depthwise의 핵심은 입력 채널이 커널의 채널과 일대일 대응한다는 점이다.
직접 만들 수 있지만 torch에 이미 conv에 대한 구현이 있다. 우리가 주목할건
nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,
groups=in_channels)
공식문서를 찾아보면 group에 대한 설명이 아래와 같이 나타난다.
At groups = in_channels
, each input channel is convolved with its own set of filters. The size of these filters is:
Where:
filter size를 입력 채널의 수와 같이 하는게 목표인데 여기서 out_channels의 수를 in_channels로 하면 결국 filter_size는 1이 된다. 즉 input 채널에 대응되는 filter의 수는 1로 일대일 대응을 이루게 된다.
pointwise의 경우엔 별게 없다. 그냥 kernel_size를 1로 하는 보통의 conv와 같다.
import torch
import torch.nn as nn
import torchprofile
class DepthwiseConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(DepthwiseConv, self).__init__()
self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
x = self.depthwise(x)
x = self.pointwise(x)
return x
def test_flops():
in_channels = 3
out_channels = 6
input_res = (3, 3) # (height, width)
# 모델 정의
model = DepthwiseConv(in_channels, out_channels)
conv_model = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
)
# 입력 텐서 정의 (4D 형식으로 수정)
input_tensor = torch.randn(1, in_channels, *input_res) # 배치 크기 1
# Flops 측정
flops = torchprofile.profile_macs(model, input_tensor)
print(f"DepthwiseConv FLOPs: {flops}")
conv_flops = torchprofile.profile_macs(conv_model, input_tensor)
print(f"Conv2d FLOPs: {conv_flops}")
# Depthwise Convolution의 Flops가 더 작은지 확인
assert flops < conv_flops, "Depthwise Convolution should have fewer FLOPs than standard convolution."
# Flops 테스트 실행
test_flops()
pycharm에서 오류가 나 코랩에 돌렸다.
DepthwiseConv FLOPs: 405
Conv2d FLOPs: 1458
결과는 이 정도로 나왔다.
이 만큼 줄어드는 데 내껀 이 값이 0.28정도인가 실제로 결과값도 비슷하게 나왔다.