Understand ResNet and implement it in PyTorch.
Network architecture
Residual Block (Figure 5)
Background
Residual learning: Building block
Bottleneck
Structural comparison with VGGNet
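For reference, the core formulation from He et al. (2015): rather than forcing a stack of layers to learn a target mapping H(x) directly, a residual block learns the residual F(x) = H(x) - x, so the block computes

y = F(x, {W_i}) + x

When F(x) and x differ in shape (a stride shrinks the spatial size, or the channel count grows), the identity shortcut is replaced by a linear projection W_s, in practice a 1x1 convolution: y = F(x, {W_i}) + W_s x. Both block classes below implement this pattern.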
import torch
from torch import nn
from torchinfo import summary
class BasicBlock(nn.Module):
    """Two 3x3 convolutions plus a skip connection (used in ResNet-18/34)."""
    expansion = 1

    def __init__(self, in_channels, inner_channels, stride=1, projection=None):
        super().__init__()
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, inner_channels, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(),
            nn.Conv2d(inner_channels, inner_channels * self.expansion, 3, padding=1, bias=False),
            nn.BatchNorm2d(inner_channels * self.expansion),
        )
        self.projection = projection  # 1x1 conv + BN that reshapes the skip path, if needed
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.residual(x)
        # Identity shortcut by default; a projection when the shapes differ
        if self.projection is not None:
            skip_connection = self.projection(x)
        else:
            skip_connection = x
        return self.relu(residual + skip_connection)
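A quick sanity check of the block above (illustrative tensor sizes, not part of the original code): with stride=1 and matching channels the shortcut is the identity, while a strided block needs a projection on the skip path to match shapes.

block = BasicBlock(64, 64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56])

# With stride=2 the residual path halves the spatial size, so the skip
# path needs a matching 1x1 strided projection
down = nn.Sequential(
    nn.Conv2d(64, 64, 1, stride=2, bias=False),
    nn.BatchNorm2d(64),
)
block = BasicBlock(64, 64, stride=2, projection=down)
print(block(x).shape)  # torch.Size([1, 64, 28, 28])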
class Bottleneck(nn.Module):
    """1x1 reduce -> 3x3 -> 1x1 expand, plus a skip connection (ResNet-50/101/152)."""
    expansion = 4

    def __init__(self, in_channels, inner_channels, stride=1, projection=None):
        super().__init__()
        self.residual = nn.Sequential(
            # 1x1 conv squeezes the channels; as in the original paper, it also
            # carries the stride when the block downsamples
            nn.Conv2d(in_channels, inner_channels, 1, stride=stride, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(),
            nn.Conv2d(inner_channels, inner_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(),
            # 1x1 conv expands the channels back by the expansion factor
            nn.Conv2d(inner_channels, inner_channels * self.expansion, 1, bias=False),
            nn.BatchNorm2d(inner_channels * self.expansion),
        )
        self.projection = projection
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = self.residual(x)
        if self.projection is not None:
            skip_connection = self.projection(x)
        else:
            skip_connection = x
        return self.relu(residual + skip_connection)
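The 1x1 convolutions exist to keep compute down: they squeeze the channels before the expensive 3x3 and expand them back afterwards. A direct parameter count (illustrative, using the two classes above) makes the saving concrete; the 70,400 figure reappears in the stage-1 blocks of the summary output below.

def count_params(m):
    return sum(p.numel() for p in m.parameters())

print(count_params(Bottleneck(256, 64)))   # 70400: the 3x3 runs on only 64 channels
print(count_params(BasicBlock(256, 256)))  # 1180672: two 3x3 convs on 256 channels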
class ResNet(nn.Module):
    def __init__(self, block, num_block_list, n_classes=1000):
        super().__init__()
        assert len(num_block_list) == 4
        # Stem: 7x7/2 conv + 3x3/2 max pool, so a 224x224 input reaches stage 1 at 56x56
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        self.in_channels = 64
        # Four stages; every stage after the first halves the spatial size
        self.stage1 = self.make_stage(block, 64, num_block_list[0], stride=1)
        self.stage2 = self.make_stage(block, 128, num_block_list[1], stride=2)
        self.stage3 = self.make_stage(block, 256, num_block_list[2], stride=2)
        self.stage4 = self.make_stage(block, 512, num_block_list[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, n_classes)

    def make_stage(self, block, inner_channels, num_blocks, stride=1):
        # The first block of a stage needs a projection shortcut whenever it
        # changes the spatial size (stride != 1) or the channel count
        if stride != 1 or self.in_channels != inner_channels * block.expansion:
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, inner_channels * block.expansion, 1, stride=stride, bias=False),
                nn.BatchNorm2d(inner_channels * block.expansion),
            )
        else:
            projection = None

        layers = []
        for idx in range(num_blocks):
            if idx == 0:
                layers.append(block(self.in_channels, inner_channels, stride, projection))
                self.in_channels = inner_channels * block.expansion
            else:
                layers.append(block(self.in_channels, inner_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
def resnet18(**kwargs):
return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
def resnet34(**kwargs):
return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
def resnet50(**kwargs):
return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
def resnet101(**kwargs):
return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
def resnet152(**kwargs):
return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
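The numbers in the names count weighted layers only: the stem conv, the convs inside every block, and the final fc (projection convs are not counted). For resnet34 that is 1 + 2 × (3 + 4 + 6 + 3) + 1 = 34, and for resnet50 it is 1 + 3 × (3 + 4 + 6 + 3) + 1 = 50.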
model = resnet50()
summary(model, input_size=(2, 3, 224, 224), device='cpu')
#### OUTPUT ####
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ResNet [2, 1000] --
├─Sequential: 1-1 [2, 64, 112, 112] --
│ └─Conv2d: 2-1 [2, 64, 112, 112] 9,408
│ └─BatchNorm2d: 2-2 [2, 64, 112, 112] 128
│ └─ReLU: 2-3 [2, 64, 112, 112] --
├─MaxPool2d: 1-2 [2, 64, 56, 56] --
├─Sequential: 1-3 [2, 256, 56, 56] --
│ └─Bottleneck: 2-4 [2, 256, 56, 56] --
│ │ └─Sequential: 3-1 [2, 256, 56, 56] 58,112
│ │ └─Sequential: 3-2 [2, 256, 56, 56] 16,896
│ │ └─ReLU: 3-3 [2, 256, 56, 56] --
│ └─Bottleneck: 2-5 [2, 256, 56, 56] --
│ │ └─Sequential: 3-4 [2, 256, 56, 56] 70,400
│ │ └─ReLU: 3-5 [2, 256, 56, 56] --
│ └─Bottleneck: 2-6 [2, 256, 56, 56] --
│ │ └─Sequential: 3-6 [2, 256, 56, 56] 70,400
│ │ └─ReLU: 3-7 [2, 256, 56, 56] --
├─Sequential: 1-4 [2, 512, 28, 28] --
│ └─Bottleneck: 2-7 [2, 512, 28, 28] --
│ │ └─Sequential: 3-8 [2, 512, 28, 28] 247,296
│ │ └─Sequential: 3-9 [2, 512, 28, 28] 132,096
│ │ └─ReLU: 3-10 [2, 512, 28, 28] --
│ └─Bottleneck: 2-8 [2, 512, 28, 28] --
│ │ └─Sequential: 3-11 [2, 512, 28, 28] 280,064
│ │ └─ReLU: 3-12 [2, 512, 28, 28] --
│ └─Bottleneck: 2-9 [2, 512, 28, 28] --
│ │ └─Sequential: 3-13 [2, 512, 28, 28] 280,064
│ │ └─ReLU: 3-14 [2, 512, 28, 28] --
│ └─Bottleneck: 2-10 [2, 512, 28, 28] --
│ │ └─Sequential: 3-15 [2, 512, 28, 28] 280,064
│ │ └─ReLU: 3-16 [2, 512, 28, 28] --
├─Sequential: 1-5 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-11 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-17 [2, 1024, 14, 14] 986,112
│ │ └─Sequential: 3-18 [2, 1024, 14, 14] 526,336
│ │ └─ReLU: 3-19 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-12 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-20 [2, 1024, 14, 14] 1,117,184
│ │ └─ReLU: 3-21 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-13 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-22 [2, 1024, 14, 14] 1,117,184
│ │ └─ReLU: 3-23 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-14 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-24 [2, 1024, 14, 14] 1,117,184
│ │ └─ReLU: 3-25 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-15 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-26 [2, 1024, 14, 14] 1,117,184
│ │ └─ReLU: 3-27 [2, 1024, 14, 14] --
│ └─Bottleneck: 2-16 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-28 [2, 1024, 14, 14] 1,117,184
│ │ └─ReLU: 3-29 [2, 1024, 14, 14] --
├─Sequential: 1-6 [2, 2048, 7, 7] --
│ └─Bottleneck: 2-17 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-30 [2, 2048, 7, 7] 3,938,304
│ │ └─Sequential: 3-31 [2, 2048, 7, 7] 2,101,248
│ │ └─ReLU: 3-32 [2, 2048, 7, 7] --
│ └─Bottleneck: 2-18 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-33 [2, 2048, 7, 7] 4,462,592
│ │ └─ReLU: 3-34 [2, 2048, 7, 7] --
│ └─Bottleneck: 2-19 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-35 [2, 2048, 7, 7] 4,462,592
│ │ └─ReLU: 3-36 [2, 2048, 7, 7] --
├─AdaptiveAvgPool2d: 1-7 [2, 2048, 1, 1] --
├─Linear: 1-8 [2, 1000] 2,049,000
==========================================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
Total mult-adds (G): 7.72
==========================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 338.80
Params size (MB): 102.23
Estimated Total Size (MB): 442.24
==========================================================================================
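The reported total can be cross-checked directly; the classifier alone contributes 2048 × 1000 weights plus 1000 biases, i.e. the 2,049,000 parameters in the Linear row above.

print(sum(p.numel() for p in model.parameters()))  # 25557032, matches the torchinfo total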
Wow, this is a really good article.
The paper "Visualizing the Loss Landscape of Neural Nets" is also intriguing; I found it a fun read.
I also enjoyed "Residual Networks Behave Like Ensembles of Relatively Shallow Networks", so I recommend that one too!