ResNet์ ๋ง์ดํฌ๋ก์ํํธ์์ ๊ฐ๋ฐํ ์๊ณ ๋ฆฌ์ฆ์ผ๋ก "Deep Residual Learning for Image Recognition"์ด๋ผ๋ ๋ ผ๋ฌธ์์ ๋ฐํ๋์๋ค. ResNet ํต์ฌ์ ๊น์ด์ง ์ ๊ฒฝ๋ง์ ํจ๊ณผ์ ์ผ๋ก ํ์ตํ๊ธฐ ์ํ ๋ฐฉ๋ฒ์ผ๋ก ๋ ์ง๋์ผ(residual) ๊ฐ๋ ์ ๊ณ ์ํ ๊ฒ์ด๋ค.
일반적으로 신경망 깊이가 깊어질수록 성능이 좋아지다가 일정한 단계에 다다르면 오히려 성능이 나빠진다는 단점이 있다.
ResNet์ ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋ ์ง๋์ผ ๋ธ๋ก(residual block)์ ๋์
ํ๋ค. ๋ ์ง๋์ผ ๋ธ๋ก์ ๊ธฐ์ธ๊ธฐ๊ฐ ์ ์ ํ๋ ์ ์๋๋ก ์ผ์ข
์ ์์ปท(shortcut, skip connection)
์ ๋ง๋ค์ด ๊ธฐ์ธ๊ธฐ ์๋ฉธ ๋ฌธ์ ๋ฅผ ๋ฐฉ์งํ๋ค.
👉🏻 블록(block)이란 계층의 묶음이다. 합성곱층을 하나의 블록으로 묶은 것이다.
์๋์ ๊ทธ๋ฆผ์์ ๊ฐ์ ์์ผ๋ก ๋ฌถ์ธ ๊ณ์ธต๋ค์ ํ๋์ ๋ ์ง๋์ผ ๋ธ๋ก์ด๋ผ๊ณ ํ๋ฉฐ ๋ ์ง๋์ผ ๋ธ๋ก์ ์ฌ๋ฌ ๊ฐ ์์ ๊ฒ์ ResNet์ด๋ผ๊ณ ํ๋ค.
계층의 깊이가 깊어질수록 파라미터가 무제한으로 커지기 때문에 이를 해결하기 위해 병목 블록(bottleneck block)이라는 것을 두었다.
▶️ 병목 블록을 사용하는 ResNet50을 살펴보자!
ResNet50에서는 3x3 합성곱층 앞뒤로 1x1 합성곱층이 붙어 있는데, 1x1 합성곱층의 채널 수를 조절하면서 차원을 줄였다 늘리는 것이 가능하기 때문에 파라미터 수를 줄일 수 있다.
아이덴티티 매핑(identity mapping)이란 입력 x가 어떤 함수를 통과하더라도 다시 x라는 형태로 출력되도록 하는 것이다.
다운샘플(downsample)이란 특성 맵 크기를 줄이기 위한 것으로 풀링과 같은 역할을 한다. 서로 다른 레지듀얼 블록 간의 형태를 맞추지 않으면 아이덴티티 매핑을 할 수 없기 때문에 아이덴티티에 대해 다운샘플이 필요하다.
입력과 출력의 형태를 같도록 맞추어 주기 위해서는 스트라이드 2를 가진 1x1 합성곱 계층을 하나 연결해주면 된다.
▶️ 입력과 출력의 차원이 같은 것을 아이덴티티 블록이라고 하며, 입력 및 출력 차원이 동일하지 않고 입력의 차원을 출력에 맞추어 변경해야 하는 것을 합성곱 블록이라고 한다.
๐๐ป ResNet์ ์ ์ฒด ๋คํธ์ํฌ ๊ตฌ์ฑ์ ์ํด ๊ทธ๊ฒ์ ๊ตฌ์ฑํ๋ ๊ธฐ๋ณธ ๋ธ๋ก
๊ณผ ๋ณ๋ชฉ ๋ธ๋ก
์ ๋ํ ์ฝ๋๋ฅผ ์ดํด๋ณด์
1) 기본 블록
# Basic residual block.
# Used by ResNet18 and ResNet34.
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm stages.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first 3x3 convolution; also used by the
            shortcut projection when ``downsample`` is true.
        downsample: When true, the shortcut is passed through a strided 1x1
            convolution followed by batch norm instead of being used as-is.
    """

    # Multiplier applied to ``out_channels`` by the code that stacks blocks
    # to obtain this block's output channel count.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        if downsample:
            # Downsampling is applied: used when the size of the input
            # differs from the size of the output of the main branch, so
            # the two can still be added together.
            conv = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                             stride=stride, bias=False)
            bn = nn.BatchNorm2d(out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None

        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity  # identity mapping (skip connection)
        out = self.relu(out)

        return out
2) 병목 블록
# Bottleneck residual block.
# Used by ResNet50, ResNet101 and ResNet152.
class Bottleneck(nn.Module):
    """Residual block of 1x1 -> 3x3 -> 1x1 convolutions.

    The last 1x1 convolution widens the output to
    ``expansion * out_channels`` channels.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Channel width of the first two convolutions.
        stride: Stride of the 3x3 convolution; also used by the shortcut
            projection when ``downsample`` is true.
        downsample: When true, the shortcut is passed through a strided 1x1
            convolution followed by batch norm instead of being used as-is.
    """

    # Output channels of the block are ``expansion * out_channels``.
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=False):
        super().__init__()

        # 1x1 convolution
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                               stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 convolution
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution.
        # self.expansion * out_channels: chosen so the output matches the
        # number of input channels of the next layer.
        self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)

        self.relu = nn.ReLU(inplace=True)

        if downsample:
            # Project the shortcut to the expanded channel count (and the
            # strided spatial size) so it can be added to the main branch.
            conv = nn.Conv2d(in_channels, self.expansion * out_channels,
                             kernel_size=1, stride=stride, bias=False)
            bn = nn.BatchNorm2d(self.expansion * out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None

        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(identity)

        out += identity  # skip connection
        out = self.relu(out)

        return out
1) 직접 layer 만들기
class ResNet(nn.Module):
    """ResNet assembled from a ``(block, n_blocks, channels)`` configuration.

    Args:
        config: Triple ``(block, n_blocks, channels)``. ``block`` is the
            residual block class (it must expose an ``expansion``
            attribute), ``n_blocks`` the number of blocks in each of the
            four stages and ``channels`` the base channel width of each
            stage.
        output_dim: Number of outputs of the final linear layer.
        zero_init_residual: When true, the weight of the last batch norm of
            every ``Bottleneck`` / ``BasicBlock`` is initialised to zero.
    """

    def __init__(self, config, output_dim, zero_init_residual=False):
        super().__init__()

        block, n_blocks, channels = config  # ① unpack the architecture config
        self.in_channels = channels[0]

        assert len(n_blocks) == len(channels) == 4, \
            "config must describe exactly four stages"

        # Stem: 7x7 strided convolution, batch norm, ReLU and max pooling.
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ② build the four residual stages; every stage after the first
        # starts with a stride-2 block.
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # get_resnet_layer has updated self.in_channels to the channel
        # count produced by the last stage.
        self.fc = nn.Linear(self.in_channels, output_dim)

        if zero_init_residual:  # ③
            # Zero the scale of the last batch norm in each residual branch
            # so that the branch initially contributes nothing to the sum.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):  # ④
        """Build one stage of ``n_blocks`` blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when needed, a projection
        shortcut; the remaining blocks use the block's defaults. Updates
        ``self.in_channels`` to ``block.expansion * channels``.
        """
        out_channels = block.expansion * channels

        # The shortcut must be projected whenever the first block changes
        # the tensor shape: either the channel count or, through a stride
        # other than 1, the spatial size.
        downsample = stride != 1 or self.in_channels != out_channels

        layers = [block(self.in_channels, channels, stride, downsample)]
        for _ in range(1, n_blocks):
            layers.append(block(out_channels, channels))

        self.in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(logits, features)``.

        ``features`` is the globally average-pooled output of the last
        stage flattened to ``(batch, -1)``; ``logits`` is ``fc(features)``.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)
        x = self.fc(h)

        return x, h
โ config ํ๊ฒฝ ์ค์ (์ด๋ค resnet์ ์ฌ์ฉํ ์ง์ ๋ฐ๋ผ ๊ฐ์ด ๋ฌ๋ผ์ง๋ค.)
โก ์์ ์ฐ๊ฒฐ ์ธต์ผ๋ก get_resnet_layerํจ์๋ฅผ ์ด์ฉํด layer๋ฅผ ์์ฑํ๋ค.
โข ๊ฐ ๋ ์ง๋์ผ ๋ถ๊ธฐ์ ์๋ ๋ง์ง๋ง Batch Normalization์ 0์ผ๋ก ์ด๊ธฐํํด์ ๋ค์ ๋ ์ง๋์ผ ๋ถ๊ธฐ๋ฅผ 0์์ ์๊ฐ์ ์ ์๋๋กํ๋ค.
-> ํ์๋ ์๋๋ 0์ผ๋ก ์ด๊ธฐํ๋ฅผ ํ ๊ฒฝ์ฐ ๋ชจ๋ธ์ ์ฑ๋ฅ์ด 0.2 ~ 0.3% ์ ๋ ํฅ์๋๋ค.
โฃ ์๋์ ์ฝ๋๋ฅผ ํตํด ๊ตฌํํ๊ณ ์ ํ๋ ResNet์ ์ ์ํ๋ค.
# Architecture description consumed by ResNet: the block class to stack,
# the number of blocks in each stage and the channel width of each stage.
ResNetConfig = namedtuple("ResNetConfig", ["block", "n_blocks", "channels"])

# ResNets built from the basic block.
resnet18_config = ResNetConfig(BasicBlock, [2, 2, 2, 2], [64, 128, 256, 512])
resnet34_config = ResNetConfig(BasicBlock, [3, 4, 6, 3], [64, 128, 256, 512])

# ResNets built from the bottleneck block.
resnet50_config = ResNetConfig(Bottleneck, [3, 4, 6, 3], [64, 128, 256, 512])
resnet101_config = ResNetConfig(Bottleneck, [3, 4, 23, 3], [64, 128, 256, 512])
resnet152_config = ResNetConfig(Bottleneck, [3, 8, 36, 3], [64, 128, 256, 512])
2) 사전 훈련된 모델 사용하기
# Load a ResNet-50 with pretrained weights (presumably `models` is
# torchvision.models -- confirm against the file's imports).
# NOTE(review): recent torchvision releases deprecate `pretrained=` in
# favour of `weights=`; verify against the installed version.
pretrained_model = models.resnet50(pretrained=True)