[Pytorch] Inception Net v1 Implementation

도룩 · November 12, 2023

Purpose

Understand Inception Net and be able to implement it in PyTorch.

Architecture

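  • Notes on the architecture table
    • All conv layers inside the inception modules use stride 1.
    • #3x3 reduce: number of filters in the 1x1 conv placed before the 3x3 conv (#5x5 reduce is defined the same way for the 5x5 conv).
    • pool proj: number of filters in the 1x1 conv placed after the pooling layer.
    • The number of output channels of an inception module equals #1x1 + #3x3 + #5x5 + pool proj, because the branch outputs are concatenated.
    • Among the convs in an inception module, #5x5 has notably few channels.
      -> This can be read as a consequence of the 5x5 conv having many more parameters per filter than the other convs (25 weights per input channel, vs. 9 for 3x3 and 1 for 1x1).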
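As a quick check of the channel-sum rule above, the sketch below takes the per-module filter counts used in the implementation further down (ordered as in_channels, #1x1, #3x3 reduce, #3x3, #5x5 reduce, #5x5, pool proj) and verifies that each module's output channels equal #1x1 + #3x3 + #5x5 + pool proj and match the next module's input. The names `INCEPTION_CFG` and `out_channels` are illustrative only.

```python
# Filter counts per inception module, copied from the Inception_V1 code below:
# (in_channels, #1x1, #3x3 reduce, #3x3, #5x5 reduce, #5x5, pool proj)
INCEPTION_CFG = {
    "3a": (192, 64, 96, 128, 16, 32, 32),
    "3b": (256, 128, 128, 192, 32, 96, 64),
    "4a": (480, 192, 96, 208, 16, 48, 64),
    "4b": (512, 160, 112, 224, 24, 64, 64),
    "4c": (512, 128, 128, 256, 24, 64, 64),
    "4d": (512, 112, 144, 288, 32, 64, 64),
    "4e": (528, 256, 160, 320, 32, 128, 128),
    "5a": (832, 256, 160, 320, 32, 128, 128),
    "5b": (832, 384, 192, 384, 48, 128, 128),
}


def out_channels(cfg):
    """Output channels of an inception module = sum of the concatenated branches."""
    _, n1x1, _, n3x3, _, n5x5, pool_proj = cfg
    return n1x1 + n3x3 + n5x5 + pool_proj


names = list(INCEPTION_CFG)
outs = {name: out_channels(INCEPTION_CFG[name]) for name in names}
print(outs)
# {'3a': 256, '3b': 480, '4a': 512, '4b': 512, '4c': 512, '4d': 528, '4e': 832, '5a': 832, '5b': 1024}

# Each module's output must match the next module's input. Between 3b -> 4a and
# 4e -> 5a there is a max pool, which keeps the channel count unchanged.
for prev, nxt in zip(names, names[1:]):
    assert outs[prev] == INCEPTION_CFG[nxt][0], (prev, nxt)
```

Note that the reduce columns (#3x3 reduce, #5x5 reduce) do not appear in the sum: they only set the width of the intermediate 1x1 layers.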

Key features

  • Inception module

    • A module that applies several convs in parallel and concatenates the resulting feature maps, which have different receptive fields.
    • The input passes through conv layers of different kernel sizes, and the resulting feature maps are concatenated along the channel (depth) axis.
      (Concatenation is possible because every branch keeps the spatial size H x W of the module's input.)
      In this paper, the stride of every conv layer inside the inception module is fixed to 1, and the padding is chosen so that the size is preserved:
      -> 1x1 conv: padding = 0
      -> 3x3 conv: padding = 1
      -> 5x5 conv: padding = 2
      -> 3x3 maxpool: padding = 1
  • Dimension reduction with 1x1 convs -> fewer parameters (Figure 2b)
    For example, when a 3x3 conv is used to reduce the number of channels from 192 to 128 (biases ignored):

    • Parameters without a 1x1 conv: 128 x 192 x 3 x 3 = 221,184
    • Parameters with a 1x1 conv (assuming it has 96 filters): (96 x 192 x 1 x 1) + (128 x 96 x 3 x 3) = 18,432 + 110,592 = 129,024
  • Auxiliary classifiers
    Used to counter vanishing gradients as the model gets deeper.

    • Loss = out_loss + 0.3(aux1_loss + aux2_loss)
    • Present only during training (discarded at inference time).
    • One is attached after inception 4a and one after inception 4d.
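  • Uses LRN (Local Response Normalization) (rarely used nowadays).

    • Normalizes each pixel value of a feature map
      -> each value is divided by a term built from the sum of squares of the values at the same position in neighboring feature maps (channels)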
  • Performs better than VGGNet while having far fewer parameters.
  • Also known as GoogLeNet.
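The padding values listed under Inception module above follow from the conv output-size formula H_out = floor((H + 2p - k) / s) + 1: with s = 1 and p = (k - 1) / 2 this gives H_out = H, and for the 3x3 max pool that means padding 1 (the value the implementation below uses). The sketch runs each branch layer on an assumed 1 x 192 x 28 x 28 input (the shape entering inception 3a, with channel counts borrowed from 3a) and checks that H x W is unchanged.

```python
import torch
from torch import nn

x = torch.randn(1, 192, 28, 28)  # assumed input: the feature map entering inception 3a

# Stride 1 with padding (k - 1) // 2 keeps the spatial size, so the branches can be concatenated.
layers = {
    "1x1 conv (padding=0)": nn.Conv2d(192, 64, kernel_size=1, stride=1, padding=0),
    "3x3 conv (padding=1)": nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1),
    "5x5 conv (padding=2)": nn.Conv2d(192, 32, kernel_size=5, stride=1, padding=2),
    "3x3 maxpool (padding=1)": nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
}

with torch.no_grad():
    for name, layer in layers.items():
        y = layer(x)
        print(f"{name}: {tuple(y.shape)}")
        assert y.shape[-2:] == x.shape[-2:]  # H x W unchanged; only the channel count differs
```

Only the channel dimension differs between branches, which is exactly what `torch.cat(..., dim=1)` needs.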
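The 1x1-reduction parameter counts given above can be reproduced by building the two variants with `nn.Conv2d` (with `bias=False` to match the hand count) and counting their parameters. The 96-filter bottleneck is the same assumption as in the example; a real inception branch also has a ReLU between the two convs, which adds no parameters and is left out here.

```python
from torch import nn


def n_params(module):
    """Count the learnable parameters of a module."""
    return sum(p.numel() for p in module.parameters())


# Direct: a single 3x3 conv mapping 192 -> 128 channels (no bias, as in the hand count).
direct = nn.Conv2d(192, 128, kernel_size=3, padding=1, bias=False)

# Bottleneck: 1x1 conv 192 -> 96 (assumed reduction size), then 3x3 conv 96 -> 128.
bottleneck = nn.Sequential(
    nn.Conv2d(192, 96, kernel_size=1, bias=False),
    nn.Conv2d(96, 128, kernel_size=3, padding=1, bias=False),
)

print(n_params(direct))      # 221184  (128 * 192 * 3 * 3)
print(n_params(bottleneck))  # 129024  (96 * 192 * 1 * 1 + 128 * 96 * 3 * 3)
```

The saving comes from running the expensive 3x3 kernels over 96 input channels instead of 192.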
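To make the LRN description above concrete: the value a_c in channel c is divided by (k + alpha / n * sum of a_c'^2 over the n neighboring channels at the same position)^beta. The sketch implements this by hand, following the formula in the PyTorch docs, and compares it with PyTorch's built-in `nn.LocalResponseNorm`. The hyperparameters (size=5, alpha=1.0, beta=0.75, k=2.0) are illustrative assumptions, not values from the paper; alpha is set large only so the normalization has a visible effect.

```python
import torch
from torch import nn


def lrn_manual(x, size=5, alpha=1.0, beta=0.75, k=2.0):
    """Cross-channel LRN: b_c = a_c / (k + alpha / size * sum_{c' in window(c)} a_{c'}^2) ** beta.

    window(c) covers channels c - size // 2 ... c + (size - 1) // 2, clipped to the valid range.
    """
    sq = x.pow(2)
    num_channels = x.shape[1]
    out = torch.empty_like(x)
    for c in range(num_channels):
        lo = max(0, c - size // 2)
        hi = min(num_channels, c + (size - 1) // 2 + 1)
        # Sum of squares of the same pixel position across the neighboring channels
        s = sq[:, lo:hi].sum(dim=1)
        out[:, c] = x[:, c] / (k + alpha / size * s).pow(beta)
    return out


torch.manual_seed(0)
x = torch.randn(2, 8, 4, 4)  # N x C x H x W
builtin = nn.LocalResponseNorm(size=5, alpha=1.0, beta=0.75, k=2.0)
assert torch.allclose(lrn_manual(x), builtin(x), atol=1e-6)
```

The implementation below does not include LRN; `nn.LocalResponseNorm` is the layer one would reach for to add it.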

Code

Environment

  • python 3.8.16
  • pytorch 2.1.0
  • torchinfo 1.8.0

Implementation

import torch
from torch import nn
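from torchinfo import summary


class conv_block(nn.Module):
    # Conv -> BatchNorm -> ReLU. BatchNorm was not part of the original GoogLeNet paper (it came later);
    # bias=False because BatchNorm already adds a learnable shift.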
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, bias = False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        x = self.relu(x)
        return x

class Inception_block(nn.Module):
    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool):
        super().__init__()

        self.branch1 = conv_block(in_channels, out_1x1, kernel_size = 1)
        self.branch2 = nn.Sequential(
            conv_block(in_channels, red_3x3, kernel_size = 1),
            conv_block(red_3x3, out_3x3, kernel_size = 3, padding = 1)
            )
        self.branch3 = nn.Sequential(
            conv_block(in_channels, red_5x5, kernel_size = 1),
            conv_block(red_5x5, out_5x5, kernel_size = 5, padding = 2)
            )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1),
            conv_block(in_channels, out_1x1pool, kernel_size = 1)
        )
    
    def forward(self, x):
        return torch.cat([self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)], dim = 1) # concat along the channel dimension

class Inception_Aux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.avgpool = nn.AvgPool2d(kernel_size = 5, stride = 3)
        self.conv = conv_block(in_channels, 128, kernel_size = 1)
        
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p = 0.7)  # plain Dropout: the input here is a 2D (N x 1024) tensor, not a feature map

        
        self.fc2 = nn.Linear(1024, num_classes)
    
    def forward(self, x):

        x = self.avgpool(x)     # Aux1: N x 512 x 14 x 14 -> N x 512 x 4 x 4
                                # Aux2: N x 528 x 14 x 14 -> N x 528 x 4 x 4

        x = self.conv(x)        # Aux1: N x 512 x 4 x 4 -> N x 128 x 4 x 4
                                # Aux2: N x 528 x 4 x 4 -> N x 128 x 4 x 4

        x = torch.flatten(x, 1) # N x 128 x 4 x 4 -> N x 2048
        x = self.fc1(x)         # N x 2048 -> N x 1024
        x = self.relu(x)
        x = self.dropout(x)

        
        x = self.fc2(x)         # N x 1024 -> N x 1000

        return x

  
class Inception_V1(nn.Module):
    def __init__(self, num_classes = 1000, use_aux = True):
        super().__init__()
        in_channels = 3 #RGB

        self.conv1 = conv_block(in_channels, 64, kernel_size = 7, stride = 2, padding = 3)
        self.maxpool1 = nn.MaxPool2d(3, stride = 2, padding = 1)
        self.conv2a = conv_block(64, 64, kernel_size = 1) # 1x1 reduction before the 3x3 conv (drawn in the figure; listed as "#3x3 reduce" = 64 in the conv2 row of the table)
        self.conv2b = conv_block(64, 192, kernel_size = 3, stride = 1, padding = 1)
        self.maxpool2 = nn.MaxPool2d(3, stride = 2, padding = 1)

        # In this order: in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool
        self.inception_3a = Inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception_3b = Inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        self.inception_4a = Inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception_4b = Inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception_4c = Inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception_4d = Inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception_4e = Inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        self.inception_5a = Inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception_5b = Inception_block(832, 384, 192, 384, 48, 128, 128)

        if use_aux:
            self.aux1 = Inception_Aux(512, num_classes)
            self.aux2 = Inception_Aux(528, num_classes)
        else:
            self.aux1 = None
            self.aux2 = None

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) #GAP
        self.dropout = nn.Dropout(p = 0.4)
        self.linear = nn.Linear(1024, num_classes)
    
    def forward(self, x):
        x = self.conv1(x)           # N x 3 x 224 x 224 -> N x 64 x 112 x 112
        x = self.maxpool1(x)        # N x 64 x 112 x 112 -> N x 64 x 56 x 56
        x = self.conv2a(x)          # N x 64 x 56 x 56 -> N x 64 x 56 x 56
        x = self.conv2b(x)          # N x 64 x 56 x 56 -> N x 192 x 56 x 56
        x = self.maxpool2(x)        # N x 192 x 56 x 56 -> N x 192 x 28 x 28

        x = self.inception_3a(x)    # N x 192 x 28 x 28 -> N x 256 x 28 x 28
        x = self.inception_3b(x)    # N x 256 x 28 x 28 -> N x 480 x 28 x 28
        x = self.maxpool3(x)        # N x 480 x 28 x 28 -> N x 480 x 14 x 14
        x = self.inception_4a(x)    # N x 480 x 14 x 14 -> N x 512 x 14 x 14
        if self.aux1 is not None and self.training:
            aux1 = self.aux1(x)     # N x 512 x 14 x 14 -> N x 1000
        else:
            aux1 = None

        x = self.inception_4b(x)    # N x 512 x 14 x 14 -> N x 512 x 14 x 14
        x = self.inception_4c(x)    # N x 512 x 14 x 14 -> N x 512 x 14 x 14
        x = self.inception_4d(x)    # N x 512 x 14 x 14 -> N x 528 x 14 x 14
        if self.aux2 is not None and self.training:
            aux2 = self.aux2(x)     # N x 528 x 14 x 14 -> N x 1000
        else:
            aux2 = None

        x = self.inception_4e(x)     # N x 528 x 14 x 14 -> N x 832 x 14 x 14
        x = self.maxpool4(x)         # N x 832 x 14 x 14 -> N x 832 x 7 x 7
        x = self.inception_5a(x)     # N x 832 x 7 x 7 -> N x 832 x 7 x 7
        x = self.inception_5b(x)     # N x 832 x 7 x 7 -> N x 1024 x 7 x 7

        x = self.avgpool(x)         # N x 1024 x 7 x 7 -> N x 1024 x 1 x 1
        x = torch.flatten(x, 1)     # N x 1024 x 1 x 1 -> N x 1024
        x = self.dropout(x)         # Dropout
        x = self.linear(x)          # N x 1024 -> N x 1000

        return x, aux2, aux1  # aux2 and aux1 are None in eval mode or when use_aux=False


model = Inception_V1()
summary(model, input_size=(2,3,224,224), device='cpu')
#### Output ####
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Inception_V1                             [2, 1000]                 6,379,984
├─conv_block: 1-1                        [2, 64, 112, 112]         --
│    └─Conv2d: 2-1                       [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                  [2, 64, 112, 112]         128
│    └─ReLU: 2-3                         [2, 64, 112, 112]         --
├─MaxPool2d: 1-2                         [2, 64, 56, 56]           --
├─conv_block: 1-3                        [2, 64, 56, 56]           --
│    └─Conv2d: 2-4                       [2, 64, 56, 56]           4,096
│    └─BatchNorm2d: 2-5                  [2, 64, 56, 56]           128
│    └─ReLU: 2-6                         [2, 64, 56, 56]           --
├─conv_block: 1-4                        [2, 192, 56, 56]          --
│    └─Conv2d: 2-7                       [2, 192, 56, 56]          110,592
│    └─BatchNorm2d: 2-8                  [2, 192, 56, 56]          384
│    └─ReLU: 2-9                         [2, 192, 56, 56]          --
├─MaxPool2d: 1-5                         [2, 192, 28, 28]          --
├─Inception_block: 1-6                   [2, 256, 28, 28]          --
│    └─conv_block: 2-10                  [2, 64, 28, 28]           --
│    │    └─Conv2d: 3-1                  [2, 64, 28, 28]           12,288
│    │    └─BatchNorm2d: 3-2             [2, 64, 28, 28]           128
│    │    └─ReLU: 3-3                    [2, 64, 28, 28]           --
│    └─Sequential: 2-11                  [2, 128, 28, 28]          --
│    │    └─conv_block: 3-4              [2, 96, 28, 28]           18,624
│    │    └─conv_block: 3-5              [2, 128, 28, 28]          110,848
│    └─Sequential: 2-12                  [2, 32, 28, 28]           --
│    │    └─conv_block: 3-6              [2, 16, 28, 28]           3,104
│    │    └─conv_block: 3-7              [2, 32, 28, 28]           12,864
│    └─Sequential: 2-13                  [2, 32, 28, 28]           --
│    │    └─MaxPool2d: 3-8               [2, 192, 28, 28]          --
│    │    └─conv_block: 3-9              [2, 32, 28, 28]           6,208
├─Inception_block: 1-7                   [2, 480, 28, 28]          --
│    └─conv_block: 2-14                  [2, 128, 28, 28]          --
│    │    └─Conv2d: 3-10                 [2, 128, 28, 28]          32,768
│    │    └─BatchNorm2d: 3-11            [2, 128, 28, 28]          256
│    │    └─ReLU: 3-12                   [2, 128, 28, 28]          --
│    └─Sequential: 2-15                  [2, 192, 28, 28]          --
│    │    └─conv_block: 3-13             [2, 128, 28, 28]          33,024
│    │    └─conv_block: 3-14             [2, 192, 28, 28]          221,568
│    └─Sequential: 2-16                  [2, 96, 28, 28]           --
│    │    └─conv_block: 3-15             [2, 32, 28, 28]           8,256
│    │    └─conv_block: 3-16             [2, 96, 28, 28]           76,992
│    └─Sequential: 2-17                  [2, 64, 28, 28]           --
│    │    └─MaxPool2d: 3-17              [2, 256, 28, 28]          --
│    │    └─conv_block: 3-18             [2, 64, 28, 28]           16,512
├─MaxPool2d: 1-8                         [2, 480, 14, 14]          --
├─Inception_block: 1-9                   [2, 512, 14, 14]          --
│    └─conv_block: 2-18                  [2, 192, 14, 14]          --
│    │    └─Conv2d: 3-19                 [2, 192, 14, 14]          92,160
│    │    └─BatchNorm2d: 3-20            [2, 192, 14, 14]          384
│    │    └─ReLU: 3-21                   [2, 192, 14, 14]          --
│    └─Sequential: 2-19                  [2, 208, 14, 14]          --
│    │    └─conv_block: 3-22             [2, 96, 14, 14]           46,272
│    │    └─conv_block: 3-23             [2, 208, 14, 14]          180,128
│    └─Sequential: 2-20                  [2, 48, 14, 14]           --
│    │    └─conv_block: 3-24             [2, 16, 14, 14]           7,712
│    │    └─conv_block: 3-25             [2, 48, 14, 14]           19,296
│    └─Sequential: 2-21                  [2, 64, 14, 14]           --
│    │    └─MaxPool2d: 3-26              [2, 480, 14, 14]          --
│    │    └─conv_block: 3-27             [2, 64, 14, 14]           30,848
├─Inception_block: 1-10                  [2, 512, 14, 14]          --
│    └─conv_block: 2-22                  [2, 160, 14, 14]          --
│    │    └─Conv2d: 3-28                 [2, 160, 14, 14]          81,920
│    │    └─BatchNorm2d: 3-29            [2, 160, 14, 14]          320
│    │    └─ReLU: 3-30                   [2, 160, 14, 14]          --
│    └─Sequential: 2-23                  [2, 224, 14, 14]          --
│    │    └─conv_block: 3-31             [2, 112, 14, 14]          57,568
│    │    └─conv_block: 3-32             [2, 224, 14, 14]          226,240
│    └─Sequential: 2-24                  [2, 64, 14, 14]           --
│    │    └─conv_block: 3-33             [2, 24, 14, 14]           12,336
│    │    └─conv_block: 3-34             [2, 64, 14, 14]           38,528
│    └─Sequential: 2-25                  [2, 64, 14, 14]           --
│    │    └─MaxPool2d: 3-35              [2, 512, 14, 14]          --
│    │    └─conv_block: 3-36             [2, 64, 14, 14]           32,896
├─Inception_block: 1-11                  [2, 512, 14, 14]          --
│    └─conv_block: 2-26                  [2, 128, 14, 14]          --
│    │    └─Conv2d: 3-37                 [2, 128, 14, 14]          65,536
│    │    └─BatchNorm2d: 3-38            [2, 128, 14, 14]          256
│    │    └─ReLU: 3-39                   [2, 128, 14, 14]          --
│    └─Sequential: 2-27                  [2, 256, 14, 14]          --
│    │    └─conv_block: 3-40             [2, 128, 14, 14]          65,792
│    │    └─conv_block: 3-41             [2, 256, 14, 14]          295,424
│    └─Sequential: 2-28                  [2, 64, 14, 14]           --
│    │    └─conv_block: 3-42             [2, 24, 14, 14]           12,336
│    │    └─conv_block: 3-43             [2, 64, 14, 14]           38,528
│    └─Sequential: 2-29                  [2, 64, 14, 14]           --
│    │    └─MaxPool2d: 3-44              [2, 512, 14, 14]          --
│    │    └─conv_block: 3-45             [2, 64, 14, 14]           32,896
├─Inception_block: 1-12                  [2, 528, 14, 14]          --
│    └─conv_block: 2-30                  [2, 112, 14, 14]          --
│    │    └─Conv2d: 3-46                 [2, 112, 14, 14]          57,344
│    │    └─BatchNorm2d: 3-47            [2, 112, 14, 14]          224
│    │    └─ReLU: 3-48                   [2, 112, 14, 14]          --
│    └─Sequential: 2-31                  [2, 288, 14, 14]          --
│    │    └─conv_block: 3-49             [2, 144, 14, 14]          74,016
│    │    └─conv_block: 3-50             [2, 288, 14, 14]          373,824
│    └─Sequential: 2-32                  [2, 64, 14, 14]           --
│    │    └─conv_block: 3-51             [2, 32, 14, 14]           16,448
│    │    └─conv_block: 3-52             [2, 64, 14, 14]           51,328
│    └─Sequential: 2-33                  [2, 64, 14, 14]           --
│    │    └─MaxPool2d: 3-53              [2, 512, 14, 14]          --
│    │    └─conv_block: 3-54             [2, 64, 14, 14]           32,896
├─Inception_block: 1-13                  [2, 832, 14, 14]          --
│    └─conv_block: 2-34                  [2, 256, 14, 14]          --
│    │    └─Conv2d: 3-55                 [2, 256, 14, 14]          135,168
│    │    └─BatchNorm2d: 3-56            [2, 256, 14, 14]          512
│    │    └─ReLU: 3-57                   [2, 256, 14, 14]          --
│    └─Sequential: 2-35                  [2, 320, 14, 14]          --
│    │    └─conv_block: 3-58             [2, 160, 14, 14]          84,800
│    │    └─conv_block: 3-59             [2, 320, 14, 14]          461,440
│    └─Sequential: 2-36                  [2, 128, 14, 14]          --
│    │    └─conv_block: 3-60             [2, 32, 14, 14]           16,960
│    │    └─conv_block: 3-61             [2, 128, 14, 14]          102,656
│    └─Sequential: 2-37                  [2, 128, 14, 14]          --
│    │    └─MaxPool2d: 3-62              [2, 528, 14, 14]          --
│    │    └─conv_block: 3-63             [2, 128, 14, 14]          67,840
├─MaxPool2d: 1-14                        [2, 832, 7, 7]            --
├─Inception_block: 1-15                  [2, 832, 7, 7]            --
│    └─conv_block: 2-38                  [2, 256, 7, 7]            --
│    │    └─Conv2d: 3-64                 [2, 256, 7, 7]            212,992
│    │    └─BatchNorm2d: 3-65            [2, 256, 7, 7]            512
│    │    └─ReLU: 3-66                   [2, 256, 7, 7]            --
│    └─Sequential: 2-39                  [2, 320, 7, 7]            --
│    │    └─conv_block: 3-67             [2, 160, 7, 7]            133,440
│    │    └─conv_block: 3-68             [2, 320, 7, 7]            461,440
│    └─Sequential: 2-40                  [2, 128, 7, 7]            --
│    │    └─conv_block: 3-69             [2, 32, 7, 7]             26,688
│    │    └─conv_block: 3-70             [2, 128, 7, 7]            102,656
│    └─Sequential: 2-41                  [2, 128, 7, 7]            --
│    │    └─MaxPool2d: 3-71              [2, 832, 7, 7]            --
│    │    └─conv_block: 3-72             [2, 128, 7, 7]            106,752
├─Inception_block: 1-16                  [2, 1024, 7, 7]           --
│    └─conv_block: 2-42                  [2, 384, 7, 7]            --
│    │    └─Conv2d: 3-73                 [2, 384, 7, 7]            319,488
│    │    └─BatchNorm2d: 3-74            [2, 384, 7, 7]            768
│    │    └─ReLU: 3-75                   [2, 384, 7, 7]            --
│    └─Sequential: 2-43                  [2, 384, 7, 7]            --
│    │    └─conv_block: 3-76             [2, 192, 7, 7]            160,128
│    │    └─conv_block: 3-77             [2, 384, 7, 7]            664,320
│    └─Sequential: 2-44                  [2, 128, 7, 7]            --
│    │    └─conv_block: 3-78             [2, 48, 7, 7]             40,032
│    │    └─conv_block: 3-79             [2, 128, 7, 7]            153,856
│    └─Sequential: 2-45                  [2, 128, 7, 7]            --
│    │    └─MaxPool2d: 3-80              [2, 832, 7, 7]            --
│    │    └─conv_block: 3-81             [2, 128, 7, 7]            106,752
├─AdaptiveAvgPool2d: 1-17                [2, 1024, 1, 1]           --
├─Dropout: 1-18                          [2, 1024]                 --
├─Linear: 1-19                           [2, 1000]                 1,025,000
==========================================================================================
Total params: 13,385,816
Trainable params: 13,385,816
Non-trainable params: 0
Total mult-adds (G): 3.17
==========================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 103.25
Params size (MB): 28.02
Estimated Total Size (MB): 132.48
==========================================================================================
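The loss weighting described under Auxiliary classifiers (Loss = out_loss + 0.3(aux1_loss + aux2_loss)) could be applied to the `(x, aux2, aux1)` tuple returned by `Inception_V1.forward` roughly as follows. This is a minimal sketch: the helper name `googlenet_loss` is hypothetical, cross-entropy is assumed as the per-head loss, and random tensors stand in for real model outputs. In eval mode the model returns `None` for both aux outputs, so the helper falls back to the main loss alone.

```python
import torch
from torch import nn

criterion = nn.CrossEntropyLoss()


def googlenet_loss(outputs, targets, aux_weight=0.3):
    """Hypothetical helper: main loss + aux_weight * (aux1 loss + aux2 loss).

    `outputs` is expected in the (x, aux2, aux1) order returned by Inception_V1.forward.
    """
    logits, aux2, aux1 = outputs
    loss = criterion(logits, targets)
    if aux1 is not None and aux2 is not None:  # aux heads only run in training mode
        loss = loss + aux_weight * (criterion(aux1, targets) + criterion(aux2, targets))
    return loss


# Dummy outputs standing in for a batch of 4 images and 1000 classes
torch.manual_seed(0)
targets = torch.randint(0, 1000, (4,))
logits, aux1, aux2 = (torch.randn(4, 1000) for _ in range(3))

train_loss = googlenet_loss((logits, aux2, aux1), targets)  # training: all three heads
eval_loss = googlenet_loss((logits, None, None), targets)   # eval: main head only
print(train_loss.item(), eval_loss.item())
```

In a real training loop, `outputs` would simply be `model(images)` with the model in `model.train()` mode.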
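The tree above lists no `Inception_Aux` modules because torchinfo's `summary` evaluates the model in eval mode by default (its `mode` argument), so the `self.training` branches are skipped. The aux parameters are still registered, though, and are included in "Total params". The sketch below rebuilds only the parameterized layers of the two auxiliary heads with plain `nn` modules, mirroring `Inception_Aux` (pooling, ReLU and dropout have no weights), and counts them. The helper names are illustrative.

```python
from torch import nn


def n_params(module):
    """Count the learnable parameters of a module."""
    return sum(p.numel() for p in module.parameters())


def aux_head_layers(in_channels, num_classes=1000):
    """Parameterized layers of one Inception_Aux head."""
    return nn.ModuleList([
        nn.Conv2d(in_channels, 128, kernel_size=1, bias=False),  # conv_block's conv
        nn.BatchNorm2d(128),                                      # conv_block's BN (weight + bias)
        nn.Linear(128 * 4 * 4, 1024),                             # fc1
        nn.Linear(1024, num_classes),                             # fc2
    ])


aux1 = n_params(aux_head_layers(512))  # attached after inception 4a
aux2 = n_params(aux_head_layers(528))  # attached after inception 4d
print(aux1, aux2, aux1 + aux2)         # 3188968 3191016 6379984
print(13_385_816 - (aux1 + aux2))      # 7005832 -> main network only
```

The sum matches the 6,379,984 shown on the top `Inception_V1` row, so about 7.0M of the 13.4M total parameters belong to the network that is actually used at inference time.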
