ResNet ๐ŸŒ…

์„œ์€์„œยท2023๋…„ 8์›” 15์ผ
0

PyTorch

๋ชฉ๋ก ๋ณด๊ธฐ
4/5
post-thumbnail
post-custom-banner

ResNet์€ ๋งˆ์ดํฌ๋กœ์†Œํ”„ํŠธ์—์„œ ๊ฐœ๋ฐœํ•œ ์•Œ๊ณ ๋ฆฌ์ฆ˜์œผ๋กœ "Deep Residual Learning for Image Recognition"์ด๋ผ๋Š” ๋…ผ๋ฌธ์—์„œ ๋ฐœํ‘œ๋˜์—ˆ๋‹ค. ResNet ํ•ต์‹ฌ์€ ๊นŠ์–ด์ง„ ์‹ ๊ฒฝ๋ง์„ ํšจ๊ณผ์ ์œผ๋กœ ํ•™์Šตํ•˜๊ธฐ ์œ„ํ•œ ๋ฐฉ๋ฒ•์œผ๋กœ ๋ ˆ์ง€๋“€์–ผ(residual) ๊ฐœ๋…์„ ๊ณ ์•ˆํ•œ ๊ฒƒ์ด๋‹ค.


ResNet

์ผ๋ฐ˜์ ์œผ๋กœ ์‹ ๊ฒฝ๋ง ๊นŠ์ด๊ฐ€ ๊นŠ์–ด์งˆ์ˆ˜๋ก ์„ฑ๋Šฅ์ด ์ข‹์•„์ง€๋‹ค๊ฐ€ ์ผ์ •ํ•œ ๋‹จ๊ณ„์— ๋‹ค๋‹ค๋ฅด๋ฉด ์˜คํžˆ๋ ค ์„ฑ๋Šฅ์ด ๋‚˜๋น ์ง„๋‹ค๋Š” ๋‹จ์ ์ด ์žˆ๋‹ค.

ResNet์€ ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ๋ ˆ์ง€๋“€์–ผ ๋ธ”๋ก(residual block)์„ ๋„์ž…ํ–ˆ๋‹ค. ๋ ˆ์ง€๋“€์–ผ ๋ธ”๋ก์€ ๊ธฐ์šธ๊ธฐ๊ฐ€ ์ž˜ ์ „ํŒŒ๋  ์ˆ˜ ์žˆ๋„๋ก ์ผ์ข…์˜ ์ˆ์ปท(shortcut, skip connection)์„ ๋งŒ๋“ค์–ด ๊ธฐ์šธ๊ธฐ ์†Œ๋ฉธ ๋ฌธ์ œ๋ฅผ ๋ฐฉ์ง€ํ–ˆ๋‹ค.
๐Ÿ‘‰๐Ÿป ๋ธ”๋ก(block)์ด๋ž€ ๊ณ„์ธต์˜ ๋ฌถ์Œ์ด๋‹ค. ํ•ฉ์„ฑ๊ณฑ์ธต์„ ํ•˜๋‚˜์˜ ๋ธ”๋ก์œผ๋กœ ๋ฌถ์€ ๊ฒƒ์ด๋‹ค.

์•„๋ž˜์˜ ๊ทธ๋ฆผ์—์„œ ๊ฐ™์€ ์ƒ‰์œผ๋กœ ๋ฌถ์ธ ๊ณ„์ธต๋“ค์„ ํ•˜๋‚˜์˜ ๋ ˆ์ง€๋“€์–ผ ๋ธ”๋ก์ด๋ผ๊ณ  ํ•˜๋ฉฐ ๋ ˆ์ง€๋“€์–ผ ๋ธ”๋ก์„ ์—ฌ๋Ÿฌ ๊ฐœ ์Œ“์€ ๊ฒƒ์„ ResNet์ด๋ผ๊ณ  ํ•œ๋‹ค.

๊ณ„์ธต์˜ ๊นŠ์ด๊ฐ€ ๊นŠ์–ด์งˆ์ˆ˜๋ก ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ๋ฌด์ œํ•œ์œผ๋กœ ์ปค์ง€๊ธฐ ๋•Œ๋ฌธ์— ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ๋ณ‘๋ชฉ ๋ธ”๋ก(bottleneck block)์ด๋ผ๋Š” ๊ฒƒ์„ ๋‘์—ˆ๋‹ค.

โ–ถ๏ธŽ ๋ณ‘๋ชฉ ๋ธ”๋ก์„ ์‚ฌ์šฉํ•˜๋Š” ResNet50์„ ์‚ดํŽด๋ณด์ž!
ResNet50์—์„œ๋Š” 3x3 ํ•ฉ์„ฑ๊ณฑ์ธต ์•ž๋’ค๋กœ 1x1 ํ•ฉ์„ฑ๊ณฑ์ธต์ด ๋ถ™์–ด ์žˆ๋Š”๋ฐ, 1x1 ํ•ฉ์„ฑ๊ณฑ์ธต์˜ ์ฑ„๋„ ์ˆ˜๋ฅผ ์กฐ์ ˆํ•˜๋ฉด์„œ ์ฐจ์›์„ ์ค„์˜€๋‹ค ๋Š˜๋ฆฌ๋Š” ๊ฒƒ์ด ๊ฐ€๋Šฅํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋ฅผ ์ค„์ผ ์ˆ˜ ์žˆ๋‹ค.

ResNet ์šฉ์–ด ์ •๋ฆฌ

  • ์•„์ด๋ดํ‹ฐํ‹ฐ ๋งคํ•‘(identity mapping)์ด๋ž€ ์ž…๋ ฅ x๊ฐ€ ์–ด๋–ค ํ•จ์ˆ˜๋ฅผ ํ†ต๊ณผํ•˜๋”๋ผ๋„ ๋‹ค์‹œ x๋ผ๋Š” ํ˜•ํƒœ๋กœ ์ถœ๋ ฅ๋˜๋„๋ก ํ•œ๋‹ค.

  • ๋‹ค์šด์ƒ˜ํ”Œ(downsample)์ด๋ž€ ํŠน์„ฑ๋ฑ ํฌ๊ธฐ๋ฅผ ์ค„์ด๊ธฐ ์œ„ํ•œ ๊ฒƒ์œผ๋กœ ํ’€๋ง๊ณผ ๊ฐ™์€ ์—ญํ• ์„ ํ•œ๋‹ค. ๋‹ค๋ฅธ ๋ ˆ์ง€๋“€์–ผ ๋ธ”๋Ÿญ๊ฐ„์˜ ํ˜•ํƒœ๋ฅผ ๋งž์ถ”์ง€ ์•Š์œผ๋ฉด ์•„์ด๋ดํ‹ฐํ‹ฐ ๋งคํ•‘์„ ํ•  ์ˆ˜ ์—†๊ธฐ ๋•Œ๋ฌธ์— ์•„์ด๋ดํ‹ฐํ‹ฐ์— ๋Œ€ํ•ด ๋‹ค์šด์ƒ˜ํ”Œ์ด ํ•„์š”ํ•˜๋‹ค.

    ๋งŒ์•ฝ ์ž…๋ ฅ๊ณผ ์ถœ๋ ฅ์˜ ํ˜•ํƒœ๋ฅผ ๊ฐ™๋„๋ก ๋งž์ถ”์–ด ์ฃผ๊ธฐ ์œ„ํ•ด์„œ๋Š” ์ŠคํŠธ๋ผ์ด๋“œ 2๋ฅผ ๊ฐ€์ง„ 1x1 ํ•ฉ์„ฑ๊ณฑ ๊ณ„์ธต์„ ํ•˜๋‚˜ ์—ฐ๊ฒฐํ•ด์ฃผ๋ฉด ๋œ๋‹ค.

    โ–ถ๏ธŽ ์ž…๋ ฅ๊ณผ ์ถœ๋ ฅ์˜ ์ฐจ์›์ด ๊ฐ™์€ ๊ฒƒ์„ ์•„์ดํ…๋””๋”” ๋ธ”๋ก์ด๋ผ๊ณ  ํ•˜๋ฉฐ, ์ž…๋ ฅ ๋ฐ ์ถœ๋ ฅ ์ฐจ์›์ด ๋™์ผํ•˜์ง€ ์•Š๊ณ  ์ž…๋ ฅ์˜ ์ฐจ์›์„ ์ถœ๋ ฅ์— ๋งž์ถ”์–ด ๋ณ€๊ฒฝํ•ด์•ผ ํ•˜๋Š” ๊ฒƒ์„ ํ•ฉ์„ฑ๊ณฑ ๋ธ”๋ก์ด๋ผ๊ณ  ํ•œ๋‹ค.

Code๋กœ ์‚ดํŽด๋ณด๊ธฐ

Block

๐Ÿ‘‰๐Ÿป ResNet์˜ ์ „์ฒด ๋„คํŠธ์›Œํฌ ๊ตฌ์„ฑ์„ ์œ„ํ•ด ๊ทธ๊ฒƒ์„ ๊ตฌ์„ฑํ•˜๋Š” ๊ธฐ๋ณธ ๋ธ”๋ก๊ณผ ๋ณ‘๋ชฉ ๋ธ”๋ก์— ๋Œ€ํ•œ ์ฝ”๋“œ๋ฅผ ์‚ดํŽด๋ณด์ž
1) ๊ธฐ๋ณธ ๋ธ”๋ก

# ๊ธฐ๋ณธ ๋ธ”๋ก
# ResNet18, ResNet34์—์„œ ์‚ฌ์šฉ๋จ
class BasicBlock(nn.Module):    
    expansion = 1
    
    def __init__(self, in_channels, out_channels, stride = 1, downsample = False):
        super().__init__()                
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, 
                               stride = stride, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(out_channels)        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, 
                               stride = 1, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_channels)        
        self.relu = nn.ReLU(inplace = True)
        
        if downsample:					# ๋‹ค์šด์ƒ˜ํ”Œ์ด ์ ์šฉ๋จ -> ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ์™€ ๋„คํŠธ์›Œํฌ๋ฅผ ํ†ต๊ณผํ•œ ํ›„ ์ถœ๋ ฅ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ๊ฐ€ ๋‹ค๋ฅผ ๊ฒฝ์šฐ์— ์‚ฌ์šฉํ•œ๋‹ค.
            conv = nn.Conv2d(in_channels, out_channels, kernel_size = 1, 
                             stride = stride, bias = False)
            bn = nn.BatchNorm2d(out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None        
        self.downsample = downsample
        
    def forward(self, x):       
        i = x       
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)        
        x = self.conv2(x)
        x = self.bn2(x)
        
        if self.downsample is not None:
            i = self.downsample(i)
                        
        x += i						# identity mapping (skip connection)
        x = self.relu(x)
        
        return x

2) ๋ณ‘๋ชฉ ๋ธ”๋ก

  • 1x1ํ•ฉ์„ฑ๊ณฑ์ธต, 3x3ํ•ฉ์„ฑ๊ณฑ์ธต, 1x1ํ•ฉ์„ฑ๊ณฑ์ธต์œผ๋กœ ๊ตฌ์„ฑ๋œ๋‹ค.
  • ๊ณ„์ธต์„ ๋” ๊นŠ๊ฒŒ ์Œ“์œผ๋ฉด์„œ ๊ณ„์‚ฐ์— ๋Œ€ํ•œ ๋น„์šฉ์„ ์ค„์ผ ์ˆ˜ ์žˆ๋‹ค.
  • ํ™œ์„ฑํ™” ํ•จ์ˆ˜๊ฐ€ ๊ธฐ์กด๋ณด๋‹ค ๋” ๋งŽ์ด ํฌํ•จ๋˜๊ธฐ ๊นจ๋ฌธ์— ๋” ๋งŽ์€ ๋น„์„ ํ˜•์„ฑ์„ ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๋‹ค.(๋‹ค์–‘ํ•œ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ ์ฒ˜๋ฆฌ๊ฐ€ ๊ฐ€๋Šฅํ•˜๋‹ค.)
# ๋ณ‘๋ชฉ ๋ธ”๋ก
# ResNet50,ResNet101,ResNet152์—์„œ ์ ์šฉ๋จ 
class Bottleneck(nn.Module):    
    expansion = 4
    
    def __init__(self, in_channels, out_channels, stride = 1, downsample = False):
        super().__init__()    
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1, bias = False) # 1x1ํ•ฉ์„ฑ๊ณฑ์ธต
        self.bn1 = nn.BatchNorm2d(out_channels)        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = stride, padding = 1, bias = False) # 3x3ํ•ฉ์„ฑ๊ณฑ์ธต
        self.bn2 = nn.BatchNorm2d(out_channels)        
        self.conv3 = nn.Conv2d(out_channels, self.expansion * out_channels, kernel_size = 1,
                               stride = 1, bias = False) 
# 1x1ํ•ฉ์„ฑ๊ณฑ์ธต
# self.expansion * out_channels : ๋‹ค์Œ ๊ณ„์ธต์˜ ์ž…๋ ฅ ์ฑ„๋„ ์ˆ˜์™€ ์ผ์น˜ํ•˜๋„๋กํ•˜๊ธฐ ์œ„ํ•จ
        self.bn3 = nn.BatchNorm2d(self.expansion * out_channels)  
        	   
        self.relu = nn.ReLU(inplace = True)
        
        if downsample:
            conv = nn.Conv2d(in_channels, self.expansion * out_channels, kernel_size = 1, 
                             stride = stride, bias = False)
            bn = nn.BatchNorm2d(self.expansion * out_channels)
            downsample = nn.Sequential(conv, bn)
        else:
            downsample = None            
        self.downsample = downsample
        
    def forward(self, x):        
        i = x        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)        
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)        
        x = self.conv3(x)
        x = self.bn3(x)
                
        if self.downsample is not None:
            i = self.downsample(i)
            
        x += i
        x = self.relu(x)
    
        return x

Model

1) ์ง์ ‘ layer๋งŒ๋“ค๊ธฐ

  • class ์ •์˜
class ResNet(nn.Module):
    def __init__(self, config, output_dim, zero_init_residual=False):
        super().__init__()
                
        block, n_blocks, channels = config # - โ‘ 
        self.in_channels = channels[0]            
        assert len(n_blocks) == len(channels) == 4
        
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        
        self.layer1 = self.get_resnet_layer(block,  n_blocks[0], channels[0]) # - โ‘ก
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride = 2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride = 2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride = 2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(self.in_channels, output_dim)

        if zero_init_residual: # - โ‘ข
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
        
    def get_resnet_layer(self, block, n_blocks, channels, stride = 1):   # - โ‘ฃ
        layers = []        
        if self.in_channels != block.expansion * channels:
            downsample = True
        else:
            downsample = False
        
        layers.append(block(self.in_channels, channels, stride, downsample))
        
        for i in range(1, n_blocks):
            layers.append(block(block.expansion * channels, channels))

        self.in_channels = block.expansion * channels            
        return nn.Sequential(*layers)
        
    def forward(self, x):        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)        
        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)
        x = self.fc(h)        
        return x, h

โ‘  config ํ™˜๊ฒฝ ์„ค์ •(์–ด๋–ค resnet์„ ์‚ฌ์šฉํ• ์ง€์— ๋”ฐ๋ผ ๊ฐ’์ด ๋‹ฌ๋ผ์ง„๋‹ค.)
โ‘ก ์™„์ „ ์—ฐ๊ฒฐ ์ธต์œผ๋กœ get_resnet_layerํ•จ์ˆ˜๋ฅผ ์ด์šฉํ•ด layer๋ฅผ ์ƒ์„ฑํ•œ๋‹ค.
โ‘ข ๊ฐ ๋ ˆ์ง€๋“€์–ผ ๋ถ„๊ธฐ์— ์žˆ๋Š” ๋งˆ์ง€๋ง‰ Batch Normalization์„ 0์œผ๋กœ ์ดˆ๊ธฐํ™”ํ•ด์„œ ๋‹ค์Œ ๋ ˆ์ง€๋“€์–ผ ๋ถ„๊ธฐ๋ฅผ 0์—์„œ ์‹œ๊ฐ์ž˜ ์ˆ˜ ์žˆ๋„๋กํ•œ๋‹ค.
-> ํ•„์ˆ˜๋Š” ์•„๋‹ˆ๋‚˜ 0์œผ๋กœ ์ดˆ๊ธฐํ™”๋ฅผ ํ•  ๊ฒฝ์šฐ ๋ชจ๋ธ์˜ ์„ฑ๋Šฅ์ด 0.2 ~ 0.3% ์ •๋„ ํ–ฅ์ƒ๋œ๋‹ค.
โ‘ฃ ์•„๋ž˜์˜ ์ฝ”๋“œ๋ฅผ ํ†ตํ•ด ๊ตฌํ˜„ํ•˜๊ณ ์ž ํ•˜๋Š” ResNet์„ ์ •์˜ํ•œ๋‹ค.

ResNetConfig = namedtuple('ResNetConfig', ['block', 'n_blocks', 'channels'])
  • block : ๊ธฐ๋ณธ ๋ธ”๋Ÿญ(BasicBlock)์ธ์ง€ ๋ณ‘๋ชฉ ๋ธ”๋Ÿญ(Bottleneck)์ธ์ง€๋ฅผ ๊ฒฐ์ •ํ•œ๋‹ค.
  • n_blocks : ๋ธ”๋Ÿญ์˜ ์ˆ˜๋ฅผ ๊ฒฐ์ •ํ•œ๋‹ค.
  • channels : ์ฑ„๋„์˜ ์ˆ˜๋ฅผ ๊ฒฐ์ •ํ•œ๋‹ค.
# ๊ธฐ๋ณธ ๋ธ”๋Ÿญ์„ ์‚ฌ์šฉํ•œ ResNet
resnet18_config = ResNetConfig(block = BasicBlock,
                               n_blocks = [2,2,2,2],
                               channels = [64, 128, 256, 512])

resnet34_config = ResNetConfig(block = BasicBlock,
                               n_blocks = [3,4,6,3],
                               channels = [64, 128, 256, 512])
                               
# ๋ณ‘๋ชฉ ๋ธ”๋Ÿญ์„ ์‚ฌ์šฉํ•œ ResNet
resnet50_config = ResNetConfig(block = Bottleneck,
                               n_blocks = [3, 4, 6, 3],
                               channels = [64, 128, 256, 512])

resnet101_config = ResNetConfig(block = Bottleneck,
                                n_blocks = [3, 4, 23, 3],
                                channels = [64, 128, 256, 512])

resnet152_config = ResNetConfig(block = Bottleneck,
                                n_blocks = [3, 8, 36, 3],
                                channels = [64, 128, 256, 512])

2) ์‚ฌ์ „ ํ›ˆ๋ จ๋œ ๋ชจ๋ธ ์‚ฌ์šฉํ•˜๊ธฐ

pretrained_model = models.resnet50(pretrained = True)

์ถœ์ฒ˜

  • ๋”ฅ๋Ÿฌ๋‹ ํŒŒ์ดํ† ์น˜ ๊ต๊ณผ์„œ
profile
๋‚ด์ผ์˜ ๋‚˜๋Š” ์˜ค๋Š˜๋ณด๋‹ค ๋” ๋‚˜์•„์ง€๊ธฐ๋ฅผ :D
post-custom-banner

0๊ฐœ์˜ ๋Œ“๊ธ€