Inception Net을 이해하고 PyTorch로 구현할 수 있다.
Inception module
1x1 conv로 dimension reduction -> 파라미터 수 감소 (Figure 2b)
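예를 들어, 3x3 conv를 이용해서 필터 수를 192 -> 128로 줄일 때, 3x3 conv를 바로 쓰면 3·3·192·128 = 221,184개의 가중치가 필요하지만, 먼저 1x1 conv로 192 -> 96으로 줄인 뒤 3x3 conv를 적용하면 1·1·192·96 + 3·3·96·128 = 18,432 + 110,592 = 129,024개로 줄어든다 (inception 3a의 3x3 branch).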
Auxiliary classifier 사용
모델이 깊어짐에 따라 vanishing gradient를 막기 위해 사용. 학습 시에만 사용하고 (논문에서는 loss에 0.3의 가중치로 더함), 추론 시에는 사용하지 않는다.
원 논문은 LRN(Local Response Normalization)을 사용 (요즘은 잘 쓰이지 않음. 이 구현에서는 LRN 대신 각 conv 뒤에 BatchNorm을 사용.)
import torch
from torch import nn
from torchinfo import summary
class conv_block(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    The convolution is created without a bias term because the BatchNorm2d
    that follows it already learns a per-channel shift.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Apply the three stages in order: convolution, normalization, activation.
        return self.relu(self.batchnorm(self.conv(x)))
class Inception_block(nn.Module):
    """Inception module with 1x1 dimension reduction (GoogLeNet, Figure 2b).

    Four parallel branches read the same input. Every padding is chosen so
    that H and W are preserved, and the branch outputs are concatenated along
    dim 1 (channels). The block therefore emits
    ``out_1x1 + out_3x3 + out_5x5 + out_1x1pool`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        out_1x1: output channels of the plain 1x1 branch.
        red_3x3: 1x1 reduction width in front of the 3x3 conv.
        out_3x3: output channels of the 3x3 branch.
        red_5x5: 1x1 reduction width in front of the 5x5 conv.
        out_5x5: output channels of the 5x5 branch.
        out_1x1pool: output channels of the 1x1 conv after the 3x3 max-pool.
    """

    def __init__(self, in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool):
        super().__init__()
        # Branch 1: a single 1x1 convolution.
        self.branch1 = conv_block(in_channels, out_1x1, kernel_size=1)
        # Branch 2: 1x1 reduction, then 3x3 convolution.
        self.branch2 = nn.Sequential(
            conv_block(in_channels, red_3x3, kernel_size=1),
            conv_block(red_3x3, out_3x3, kernel_size=3, padding=1),
        )
        # Branch 3: 1x1 reduction, then 5x5 convolution.
        self.branch3 = nn.Sequential(
            conv_block(in_channels, red_5x5, kernel_size=1),
            conv_block(red_5x5, out_5x5, kernel_size=5, padding=2),
        )
        # Branch 4: stride-1 3x3 max-pool, then 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            conv_block(in_channels, out_1x1pool, kernel_size=1),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        # Concatenate along the channel dimension.
        return torch.cat([branch(x) for branch in branches], dim=1)
class Inception_Aux(nn.Module):
    """Auxiliary classifier attached to an intermediate Inception output.

    Layout: 5x5/stride-3 average pool -> 1x1 conv (128 channels) -> FC 1024
    -> ReLU -> dropout(p=0.7) -> FC num_classes.

    Args:
        in_channels: channels of the feature map the head is attached to
            (Inception_V1 uses 512 after inception_4a and 528 after
            inception_4d).
        num_classes: number of output logits.

    Note:
        ``fc1`` is sized for a 4x4 map after pooling, which is what a 14x14
        input feature map produces. Other spatial sizes will fail at ``fc1``.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.avgpool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = conv_block(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
        self.relu = nn.ReLU()
        # Plain element-wise dropout on the flattened (N, 1024) features.
        # nn.Dropout2d is meant for (N, C, H, W) / (C, H, W) maps. Given a
        # 2-D input, recent PyTorch treats it as an unbatched map, so whole
        # samples are zeroed (and a deprecation warning is emitted) instead
        # of individual features.
        self.dropout = nn.Dropout(p=0.7)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.avgpool(x)      # N x C x 14 x 14 -> N x C x 4 x 4 (C = in_channels)
        x = self.conv(x)         # N x C x 4 x 4 -> N x 128 x 4 x 4
        x = torch.flatten(x, 1)  # N x 128 x 4 x 4 -> N x 2048
        x = self.fc1(x)          # N x 2048 -> N x 1024
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)          # N x 1024 -> N x num_classes
        return x
class Inception_V1(nn.Module):
    """GoogLeNet (Inception v1) built from Conv-BN-ReLU blocks.

    Args:
        num_classes: number of logits produced by the main head and by the
            auxiliary heads.
        use_aux: if True, build the two auxiliary classifiers. They are only
            evaluated in training mode; otherwise their outputs are ``None``.
        in_channels: channels of the input image. Defaults to 3 (RGB).

    ``forward(x)`` returns the tuple ``(logits, aux2, aux1)``. The shape
    comments below assume a 224x224 input. The auxiliary heads require the
    14x14 feature maps that this input size produces.
    """

    def __init__(self, num_classes=1000, use_aux=True, in_channels=3):
        super().__init__()
        # Stem.
        self.conv1 = conv_block(in_channels, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, padding=1)
        # Not listed in the paper's table but present in its figure. The
        # figure gives no out_channels, so 64 is used here.
        self.conv2a = conv_block(64, 64, kernel_size=1)
        self.conv2b = conv_block(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, padding=1)

        # Argument order: in_channels, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, out_1x1pool
        self.inception_3a = Inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception_3b = Inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception_4a = Inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception_4b = Inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception_4c = Inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception_4d = Inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception_4e = Inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception_5a = Inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception_5b = Inception_block(832, 384, 192, 384, 48, 128, 128)

        if use_aux:
            # Attached after inception_4a (512 ch) and inception_4d (528 ch).
            self.aux1 = Inception_Aux(512, num_classes)
            self.aux2 = Inception_Aux(528, num_classes)
        else:
            self.aux1 = None
            self.aux2 = None

        # Head: global average pooling -> dropout -> linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.4)
        self.linear = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.conv1(x)         # N x C_in x 224 x 224 -> N x 64 x 112 x 112
        x = self.maxpool1(x)      # N x 64 x 112 x 112 -> N x 64 x 56 x 56
        x = self.conv2a(x)        # N x 64 x 56 x 56 -> N x 64 x 56 x 56
        x = self.conv2b(x)        # N x 64 x 56 x 56 -> N x 192 x 56 x 56
        x = self.maxpool2(x)      # N x 192 x 56 x 56 -> N x 192 x 28 x 28
        x = self.inception_3a(x)  # N x 192 x 28 x 28 -> N x 256 x 28 x 28
        x = self.inception_3b(x)  # N x 256 x 28 x 28 -> N x 480 x 28 x 28
        x = self.maxpool3(x)      # N x 480 x 28 x 28 -> N x 480 x 14 x 14
        x = self.inception_4a(x)  # N x 480 x 14 x 14 -> N x 512 x 14 x 14
        # Auxiliary head 1: only in training mode and only if it was built.
        aux1 = self.aux1(x) if self.aux1 is not None and self.training else None  # -> N x num_classes
        x = self.inception_4b(x)  # N x 512 x 14 x 14 -> N x 512 x 14 x 14
        x = self.inception_4c(x)  # N x 512 x 14 x 14 -> N x 512 x 14 x 14
        x = self.inception_4d(x)  # N x 512 x 14 x 14 -> N x 528 x 14 x 14
        # Auxiliary head 2: same gating as head 1.
        aux2 = self.aux2(x) if self.aux2 is not None and self.training else None  # -> N x num_classes
        x = self.inception_4e(x)  # N x 528 x 14 x 14 -> N x 832 x 14 x 14
        x = self.maxpool4(x)      # N x 832 x 14 x 14 -> N x 832 x 7 x 7
        x = self.inception_5a(x)  # N x 832 x 7 x 7 -> N x 832 x 7 x 7
        x = self.inception_5b(x)  # N x 832 x 7 x 7 -> N x 1024 x 7 x 7
        x = self.avgpool(x)       # N x 1024 x 7 x 7 -> N x 1024 x 1 x 1
        x = torch.flatten(x, 1)   # N x 1024 x 1 x 1 -> N x 1024
        x = self.dropout(x)
        x = self.linear(x)        # N x 1024 -> N x num_classes
        # Order is (main, aux2, aux1); callers unpack positionally.
        return x, aux2, aux1
# Build the network with its defaults and print a layer-by-layer summary
# for a batch of two 3x224x224 inputs, computed on the CPU.
model = Inception_V1()
summary(
    model,
    input_size=(2, 3, 224, 224),
    device="cpu",
)
#### Output ####
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Inception_V1 [2, 1000] 6,379,984
├─conv_block: 1-1 [2, 64, 112, 112] --
│ └─Conv2d: 2-1 [2, 64, 112, 112] 9,408
│ └─BatchNorm2d: 2-2 [2, 64, 112, 112] 128
│ └─ReLU: 2-3 [2, 64, 112, 112] --
├─MaxPool2d: 1-2 [2, 64, 56, 56] --
├─conv_block: 1-3 [2, 64, 56, 56] --
│ └─Conv2d: 2-4 [2, 64, 56, 56] 4,096
│ └─BatchNorm2d: 2-5 [2, 64, 56, 56] 128
│ └─ReLU: 2-6 [2, 64, 56, 56] --
├─conv_block: 1-4 [2, 192, 56, 56] --
│ └─Conv2d: 2-7 [2, 192, 56, 56] 110,592
│ └─BatchNorm2d: 2-8 [2, 192, 56, 56] 384
│ └─ReLU: 2-9 [2, 192, 56, 56] --
├─MaxPool2d: 1-5 [2, 192, 28, 28] --
├─Inception_block: 1-6 [2, 256, 28, 28] --
│ └─conv_block: 2-10 [2, 64, 28, 28] --
│ │ └─Conv2d: 3-1 [2, 64, 28, 28] 12,288
│ │ └─BatchNorm2d: 3-2 [2, 64, 28, 28] 128
│ │ └─ReLU: 3-3 [2, 64, 28, 28] --
│ └─Sequential: 2-11 [2, 128, 28, 28] --
│ │ └─conv_block: 3-4 [2, 96, 28, 28] 18,624
│ │ └─conv_block: 3-5 [2, 128, 28, 28] 110,848
│ └─Sequential: 2-12 [2, 32, 28, 28] --
│ │ └─conv_block: 3-6 [2, 16, 28, 28] 3,104
│ │ └─conv_block: 3-7 [2, 32, 28, 28] 12,864
│ └─Sequential: 2-13 [2, 32, 28, 28] --
│ │ └─MaxPool2d: 3-8 [2, 192, 28, 28] --
│ │ └─conv_block: 3-9 [2, 32, 28, 28] 6,208
├─Inception_block: 1-7 [2, 480, 28, 28] --
│ └─conv_block: 2-14 [2, 128, 28, 28] --
│ │ └─Conv2d: 3-10 [2, 128, 28, 28] 32,768
│ │ └─BatchNorm2d: 3-11 [2, 128, 28, 28] 256
│ │ └─ReLU: 3-12 [2, 128, 28, 28] --
│ └─Sequential: 2-15 [2, 192, 28, 28] --
│ │ └─conv_block: 3-13 [2, 128, 28, 28] 33,024
│ │ └─conv_block: 3-14 [2, 192, 28, 28] 221,568
│ └─Sequential: 2-16 [2, 96, 28, 28] --
│ │ └─conv_block: 3-15 [2, 32, 28, 28] 8,256
│ │ └─conv_block: 3-16 [2, 96, 28, 28] 76,992
│ └─Sequential: 2-17 [2, 64, 28, 28] --
│ │ └─MaxPool2d: 3-17 [2, 256, 28, 28] --
│ │ └─conv_block: 3-18 [2, 64, 28, 28] 16,512
├─MaxPool2d: 1-8 [2, 480, 14, 14] --
├─Inception_block: 1-9 [2, 512, 14, 14] --
│ └─conv_block: 2-18 [2, 192, 14, 14] --
│ │ └─Conv2d: 3-19 [2, 192, 14, 14] 92,160
│ │ └─BatchNorm2d: 3-20 [2, 192, 14, 14] 384
│ │ └─ReLU: 3-21 [2, 192, 14, 14] --
│ └─Sequential: 2-19 [2, 208, 14, 14] --
│ │ └─conv_block: 3-22 [2, 96, 14, 14] 46,272
│ │ └─conv_block: 3-23 [2, 208, 14, 14] 180,128
│ └─Sequential: 2-20 [2, 48, 14, 14] --
│ │ └─conv_block: 3-24 [2, 16, 14, 14] 7,712
│ │ └─conv_block: 3-25 [2, 48, 14, 14] 19,296
│ └─Sequential: 2-21 [2, 64, 14, 14] --
│ │ └─MaxPool2d: 3-26 [2, 480, 14, 14] --
│ │ └─conv_block: 3-27 [2, 64, 14, 14] 30,848
├─Inception_block: 1-10 [2, 512, 14, 14] --
│ └─conv_block: 2-22 [2, 160, 14, 14] --
│ │ └─Conv2d: 3-28 [2, 160, 14, 14] 81,920
│ │ └─BatchNorm2d: 3-29 [2, 160, 14, 14] 320
│ │ └─ReLU: 3-30 [2, 160, 14, 14] --
│ └─Sequential: 2-23 [2, 224, 14, 14] --
│ │ └─conv_block: 3-31 [2, 112, 14, 14] 57,568
│ │ └─conv_block: 3-32 [2, 224, 14, 14] 226,240
│ └─Sequential: 2-24 [2, 64, 14, 14] --
│ │ └─conv_block: 3-33 [2, 24, 14, 14] 12,336
│ │ └─conv_block: 3-34 [2, 64, 14, 14] 38,528
│ └─Sequential: 2-25 [2, 64, 14, 14] --
│ │ └─MaxPool2d: 3-35 [2, 512, 14, 14] --
│ │ └─conv_block: 3-36 [2, 64, 14, 14] 32,896
├─Inception_block: 1-11 [2, 512, 14, 14] --
│ └─conv_block: 2-26 [2, 128, 14, 14] --
│ │ └─Conv2d: 3-37 [2, 128, 14, 14] 65,536
│ │ └─BatchNorm2d: 3-38 [2, 128, 14, 14] 256
│ │ └─ReLU: 3-39 [2, 128, 14, 14] --
│ └─Sequential: 2-27 [2, 256, 14, 14] --
│ │ └─conv_block: 3-40 [2, 128, 14, 14] 65,792
│ │ └─conv_block: 3-41 [2, 256, 14, 14] 295,424
│ └─Sequential: 2-28 [2, 64, 14, 14] --
│ │ └─conv_block: 3-42 [2, 24, 14, 14] 12,336
│ │ └─conv_block: 3-43 [2, 64, 14, 14] 38,528
│ └─Sequential: 2-29 [2, 64, 14, 14] --
│ │ └─MaxPool2d: 3-44 [2, 512, 14, 14] --
│ │ └─conv_block: 3-45 [2, 64, 14, 14] 32,896
├─Inception_block: 1-12 [2, 528, 14, 14] --
│ └─conv_block: 2-30 [2, 112, 14, 14] --
│ │ └─Conv2d: 3-46 [2, 112, 14, 14] 57,344
│ │ └─BatchNorm2d: 3-47 [2, 112, 14, 14] 224
│ │ └─ReLU: 3-48 [2, 112, 14, 14] --
│ └─Sequential: 2-31 [2, 288, 14, 14] --
│ │ └─conv_block: 3-49 [2, 144, 14, 14] 74,016
│ │ └─conv_block: 3-50 [2, 288, 14, 14] 373,824
│ └─Sequential: 2-32 [2, 64, 14, 14] --
│ │ └─conv_block: 3-51 [2, 32, 14, 14] 16,448
│ │ └─conv_block: 3-52 [2, 64, 14, 14] 51,328
│ └─Sequential: 2-33 [2, 64, 14, 14] --
│ │ └─MaxPool2d: 3-53 [2, 512, 14, 14] --
│ │ └─conv_block: 3-54 [2, 64, 14, 14] 32,896
├─Inception_block: 1-13 [2, 832, 14, 14] --
│ └─conv_block: 2-34 [2, 256, 14, 14] --
│ │ └─Conv2d: 3-55 [2, 256, 14, 14] 135,168
│ │ └─BatchNorm2d: 3-56 [2, 256, 14, 14] 512
│ │ └─ReLU: 3-57 [2, 256, 14, 14] --
│ └─Sequential: 2-35 [2, 320, 14, 14] --
│ │ └─conv_block: 3-58 [2, 160, 14, 14] 84,800
│ │ └─conv_block: 3-59 [2, 320, 14, 14] 461,440
│ └─Sequential: 2-36 [2, 128, 14, 14] --
│ │ └─conv_block: 3-60 [2, 32, 14, 14] 16,960
│ │ └─conv_block: 3-61 [2, 128, 14, 14] 102,656
│ └─Sequential: 2-37 [2, 128, 14, 14] --
│ │ └─MaxPool2d: 3-62 [2, 528, 14, 14] --
│ │ └─conv_block: 3-63 [2, 128, 14, 14] 67,840
├─MaxPool2d: 1-14 [2, 832, 7, 7] --
├─Inception_block: 1-15 [2, 832, 7, 7] --
│ └─conv_block: 2-38 [2, 256, 7, 7] --
│ │ └─Conv2d: 3-64 [2, 256, 7, 7] 212,992
│ │ └─BatchNorm2d: 3-65 [2, 256, 7, 7] 512
│ │ └─ReLU: 3-66 [2, 256, 7, 7] --
│ └─Sequential: 2-39 [2, 320, 7, 7] --
│ │ └─conv_block: 3-67 [2, 160, 7, 7] 133,440
│ │ └─conv_block: 3-68 [2, 320, 7, 7] 461,440
│ └─Sequential: 2-40 [2, 128, 7, 7] --
│ │ └─conv_block: 3-69 [2, 32, 7, 7] 26,688
│ │ └─conv_block: 3-70 [2, 128, 7, 7] 102,656
│ └─Sequential: 2-41 [2, 128, 7, 7] --
│ │ └─MaxPool2d: 3-71 [2, 832, 7, 7] --
│ │ └─conv_block: 3-72 [2, 128, 7, 7] 106,752
├─Inception_block: 1-16 [2, 1024, 7, 7] --
│ └─conv_block: 2-42 [2, 384, 7, 7] --
│ │ └─Conv2d: 3-73 [2, 384, 7, 7] 319,488
│ │ └─BatchNorm2d: 3-74 [2, 384, 7, 7] 768
│ │ └─ReLU: 3-75 [2, 384, 7, 7] --
│ └─Sequential: 2-43 [2, 384, 7, 7] --
│ │ └─conv_block: 3-76 [2, 192, 7, 7] 160,128
│ │ └─conv_block: 3-77 [2, 384, 7, 7] 664,320
│ └─Sequential: 2-44 [2, 128, 7, 7] --
│ │ └─conv_block: 3-78 [2, 48, 7, 7] 40,032
│ │ └─conv_block: 3-79 [2, 128, 7, 7] 153,856
│ └─Sequential: 2-45 [2, 128, 7, 7] --
│ │ └─MaxPool2d: 3-80 [2, 832, 7, 7] --
│ │ └─conv_block: 3-81 [2, 128, 7, 7] 106,752
├─AdaptiveAvgPool2d: 1-17 [2, 1024, 1, 1] --
├─Dropout: 1-18 [2, 1024] --
├─Linear: 1-19 [2, 1000] 1,025,000
==========================================================================================
Total params: 13,385,816
Trainable params: 13,385,816
Non-trainable params: 0
Total mult-adds (G): 3.17
==========================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 103.25
Params size (MB): 28.02
Estimated Total Size (MB): 132.48
==========================================================================================