ResNeXt 을 이해하고 Pytorch로 구현할 수 있다.
in channels = 128
이다. 이를 32 개의 그룹으로 쪼갤 것이다. 각 그룹별 in_channels = 4
이다. (128 / 32 = 4)
out_channels = 128
이다. 각 그룹별로 4 개의 feature map을 뽑아 총 128개의 feature를 만든다. (32 x 4 = 128)
(Grouped Convolution을 적용할 때 out_channels
는 cardinarity의 배수여야 함.)
개선
파라미터 수
일반 Conv와 Grouped Conv를 비교해보자.
nn.Conv2d(64, 128, 3)
vs nn.Conv2d(64, 128, 3, groups = 32)
(1) nn.Conv2d(64, 128, 3)
: 64 x 128 x 3 x 3
(2) nn.Conv2d(64, 128, 3, groups = 32)
: 32 x (2 x 4 x 3 x 3) = 64 x 4 x 3 x 3
즉, 그룹 수만큼 파라미터 수가 줄어든다.
Bottleneck block 개선
결과
ResNet-101
과 ResNeXt-101
을 비교해보면 ResNeXt-101
의 Error가 더 낮은 것을 볼 수 있다. 이는 Grouped Conv로 파라미터 수를 줄이면서, 채널 수를 늘려 Bottleneck 현상을 어느정도 완화 시킨 것이 ResNet
의 Bottleneck 보다 효과적이라고 해석할 수 있다.
ResNet-200
과 ResNet-101, wider
를 비교해보면 ResNet-101, wider
의 Error가 조금 더 낮다. 이는 layer의 깊이도 중요하지만 layer를 구성하는 필터 수(너비)도 꽤 중요하다고 볼 수 있다.
ResNeXt-101, C = 2
와 ResNeXt-101, C = 64
를 비교해보면 그룹 수가 많은 것이 더 Error가 낮다. 또한 ResNet-200
과 ResNet-101, wider
에 비해 ResNeXt-101
이 압도적으로 성능이 좋다. 이는 다음과 같이 해석 할 수 있다. 만약, 모델의 파라미터를 2 배 키우고 싶다면.. 층을 깊이 쌓거나, 단순히 너비(채널)를 늘리는 것보다는 층을 유지하고 Grouped Conv를 사용하면서 채널 수를 늘리는 방법이 효율성이 뛰어다는 것을 시사한다.
ResNet과 매우 유사하다.
import torch
from torch import nn
from torchinfo import summary
class Bottleneck(nn.Module):
expansion = 2
def __init__(self, in_channels, inner_channels, cardinality, stride = 1, projection = None):
super().__init__()
self.residual = nn.Sequential(
nn.Conv2d(in_channels, inner_channels, 1, stride = stride, bias = False),
nn.BatchNorm2d(inner_channels),
nn.ReLU(),
nn.Conv2d(inner_channels, inner_channels, 3, stride = 1, padding = 1, groups = cardinality, bias = False),
nn.BatchNorm2d(inner_channels),
nn.ReLU(),
nn.Conv2d(inner_channels, inner_channels * self.expansion, 1, bias = False),
nn.BatchNorm2d(inner_channels * self.expansion),
)
self.relu = nn.ReLU()
self.projection = projection
def forward(self, x):
residual = self.residual(x)
if self.projection is not None:
skip_connection = self.projection(x)
else:
skip_connection = x
out = self.relu(residual + skip_connection)
return out
class ResNeXt(nn.Module):
def __init__(self, block, block_list, cardinality, n_classes = 1000):
super().__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, 7, stride = 2, padding = 3, bias = False),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.maxpool = nn.MaxPool2d(3, stride = 2, padding = 1)
self.in_channels = 64
self.stage1 = self.make_stage(block, 128, block_list[0], stride = 1, cardinality = cardinality)
self.stage2 = self.make_stage(block, 256, block_list[1], stride = 2, cardinality = cardinality)
self.stage3 = self.make_stage(block, 512, block_list[2], stride = 2, cardinality = cardinality)
self.stage4 = self.make_stage(block, 1024, block_list[3], stride = 2, cardinality = cardinality)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(2048, n_classes)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def make_stage(self, block, inner_channels, block_nums, stride, cardinality):
if self.in_channels != inner_channels * block.expansion or stride != 1:
projection = nn.Sequential(
nn.Conv2d(self.in_channels, inner_channels * block.expansion, 1, stride = stride, bias = False),
nn.BatchNorm2d(inner_channels * block.expansion)
)
else:
projection = None
layers = []
for idx in range(block_nums):
if idx == 0:
layers.append(block(self.in_channels, inner_channels, cardinality, stride, projection))
self.in_channels = inner_channels * block.expansion
else:
layers.append(block(self.in_channels, inner_channels, cardinality))
return nn.Sequential(*layers)
def ResNeXt50():
return ResNeXt(Bottleneck, [3, 4, 6, 3], cardinality = 32)
def ResNeXt101():
return ResNeXt(Bottleneck, [3, 4, 23, 3], cardinality = 32)
model = ResNeXt50()
summary(model, input_size=(2,3,224,224), device='cpu')
#### OUTPUT ####
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ResNeXt [2, 1000] --
├─Sequential: 1-1 [2, 64, 112, 112] --
│ └─Conv2d: 2-1 [2, 64, 112, 112] 9,408
│ └─BatchNorm2d: 2-2 [2, 64, 112, 112] 128
│ └─ReLU: 2-3 [2, 64, 112, 112] --
├─MaxPool2d: 1-2 [2, 64, 56, 56] --
├─Sequential: 1-3 [2, 256, 56, 56] --
│ └─Bottleneck: 2-4 [2, 256, 56, 56] --
│ │ └─Sequential: 3-1 [2, 256, 56, 56] 46,592
│ │ └─Sequential: 3-2 [2, 256, 56, 56] 16,896
│ │ └─ReLU: 3-3 [2, 256, 56, 56] --
│ └─Bottleneck: 2-5 [2, 256, 56, 56] --
│ │ └─Sequential: 3-4 [2, 256, 56, 56] 71,168
│ │ └─ReLU: 3-5 [2, 256, 56, 56] --
│ └─Bottleneck: 2-6 [2, 256, 56, 56] --
│ │ └─Sequential: 3-6 [2, 256, 56, 56] 71,168
│ │ └─ReLU: 3-7 [2, 256, 56, 56] --
├─Sequential: 1-4 [2, 512, 28, 28] --
│ └─Bottleneck: 2-7 [2, 512, 28, 28] --
│ │ └─Sequential: 3-8 [2, 512, 28, 28] 217,088
│ │ └─Sequential: 3-9 [2, 512, 28, 28] 132,096
│ │ └─ReLU: 3-10 [2, 512, 28, 28] --
│ └─Bottleneck: 2-8 [2, 512, 28, 28] --
│ │ └─Sequential: 3-11 [2, 512, 28, 28] 282,624
│ │ └─ReLU: 3-12 [2, 512, 28, 28] --
│ └─Bottleneck: 2-9 [2, 512, 28, 28] --
│ │ └─Sequential: 3-13 [2, 512, 28, 28] 282,624
│ │ └─ReLU: 3-14 [2, 512, 28, 28] --
│ └─Bottleneck: 2-10 [2, 512, 28, 28] --
│ │ └─Sequential: 3-15 [2, 512, 28, 28] 282,624
│ │ └─ReLU: 3-16 [2, 512, 28, 28] --
├─Sequential: 1-5 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-11 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-17 [2, 1024, 14, 14] 864,256
│ │ └─Sequential: 3-18 [2, 1024, 14, 14] 526,336
│ │ └─ReLU: 3-19 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-12 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-20 [2, 1024, 14, 14] 1,126,400
│ │ └─ReLU: 3-21 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-13 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-22 [2, 1024, 14, 14] 1,126,400
│ │ └─ReLU: 3-23 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-14 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-24 [2, 1024, 14, 14] 1,126,400
│ │ └─ReLU: 3-25 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-15 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-26 [2, 1024, 14, 14] 1,126,400
│ │ └─ReLU: 3-27 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-16 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-28 [2, 1024, 14, 14] 1,126,400
│ │ └─ReLU: 3-29 [2, 1024, 14, 14] --
├─Sequential: 1-6 [2, 2048, 7, 7] --
│ └─Bottleneck: 2-17 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-30 [2, 2048, 7, 7] 3,448,832
│ │ └─Sequential: 3-31 [2, 2048, 7, 7] 2,101,248
│ │ └─ReLU: 3-32 [2, 2048, 7, 7] --
│ └─Bottleneck: 2-18 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-33 [2, 2048, 7, 7] 4,497,408
│ │ └─ReLU: 3-34 [2, 2048, 7, 7] --
│ └─Bottleneck: 2-19 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-35 [2, 2048, 7, 7] 4,497,408
│ │ └─ReLU: 3-36 [2, 2048, 7, 7] --
├─AdaptiveAvgPool2d: 1-7 [2, 2048, 1, 1] --
├─Linear: 1-8 [2, 1000] 2,049,000
==========================================================================================
Total params: 25,028,904
Trainable params: 25,028,904
Non-trainable params: 0
Total mult-adds (G): 7.54
==========================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 427.11
Params size (MB): 100.12
Estimated Total Size (MB): 528.43
==========================================================================================