Inception Net v2, v3를 이해하고 PyTorch로 구현할 수 있다.
Modules
Auxiliary classifier
Label Smoothing
Inception Net v2와 v3의 차이
import torch
from torch import nn
from torchinfo import summary
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU: the basic unit every Inception block is built from.

    The convolution is created without a bias term because the BatchNorm that
    follows it applies its own learnable shift.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        # eps=0.001 rather than PyTorch's default of 1e-5.
        self.batchnorm = nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.batchnorm(self.conv(x)))
class InceptionF5(nn.Module):  # Figure 5
    """Inception module of Figure 5: four parallel branches concatenated on the channel axis.

    Branches (output channels in parentheses):
      1. 1x1 (64) -> 3x3 (96) -> 3x3 (96)
      2. 1x1 (48) -> 3x3 (64)
      3. 3x3 max-pool, stride 1 -> 1x1 (64)
      4. 1x1 (64)
    Every branch keeps the spatial size, so the result always has
    96 + 64 * 3 = 288 channels, independent of ``in_channels``.

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Two stacked 3x3 convolutions (together they cover a 5x5 receptive field).
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, padding=1),
        )
        # 1x1 channel reduction followed by a single 3x3.
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, 48, kernel_size=1),
            BasicConv2d(48, 64, kernel_size=3, padding=1),
        )
        # Pooling path; stride 1 with padding 1 leaves the grid size unchanged.
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, 64, kernel_size=1),
        )
        # Plain 1x1 projection.
        self.branch4 = BasicConv2d(in_channels, 64, kernel_size=1)

    def forward(self, x):
        paths = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([path(x) for path in paths], dim=1)
class InceptionF6(nn.Module):  # Figure 6
    """Inception module of Figure 6: 7x7 convolutions factorized into 1x7 and 7x1 pairs.

    Each of the four branches ends in 192 channels and preserves the spatial
    size, so the output always has 192 * 4 = 768 channels.

    Args:
        in_channels: number of channels of the input feature map.
        f_7x7: width of the intermediate 1x1 / 1x7 / 7x1 layers in branches 1 and 2.
    """

    def __init__(self, in_channels, f_7x7):
        super().__init__()
        width = f_7x7
        # 1x1 reduction, then two (1x7, 7x1) pairs.
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, width, kernel_size=1),
            BasicConv2d(width, width, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(width, width, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(width, width, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(width, 192, kernel_size=(7, 1), padding=(3, 0)),
        )
        # 1x1 reduction, then a single (1x7, 7x1) pair.
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, width, kernel_size=1),
            BasicConv2d(width, width, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(width, 192, kernel_size=(7, 1), padding=(3, 0)),
        )
        # Size-preserving max-pool followed by a 1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            BasicConv2d(in_channels, 192, kernel_size=1),
        )
        # Plain 1x1 projection.
        self.branch4 = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        outputs = [b(x) for b in (self.branch1, self.branch2, self.branch3, self.branch4)]
        return torch.cat(outputs, dim=1)
class InceptionF7(nn.Module):  # Figure 7
    """Inception module of Figure 7: branches that fan out into parallel 1x3 and 3x1 convolutions.

    Output channels: branch 1 gives 384 + 384, branch 2 gives 384 + 384,
    the pooling branch gives 192 and the 1x1 branch gives 320, for a total of
    2048. All branches preserve the spatial size.

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # Branch 1: 1x1 (448) -> 3x3 (384), then split into 1x3 and 3x1 heads.
        self.branch1_stem = nn.Sequential(
            BasicConv2d(in_channels, 448, kernel_size=1),
            BasicConv2d(448, 384, kernel_size=3, padding=1),
        )
        self.branch1_left = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch1_right = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
        # Branch 2: 1x1 (384), then split into 1x3 and 3x1 heads.
        self.branch2_stem = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch2_left = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch2_right = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
        # Branch 3: size-preserving max-pool followed by a 1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            BasicConv2d(in_channels, 192, kernel_size=1),
        )
        # Branch 4: plain 1x1 projection.
        self.branch4 = BasicConv2d(in_channels, 320, kernel_size=1)

    def forward(self, x):
        # Both stems run before any of the split heads (same call order as before).
        stem_a = self.branch1_stem(x)
        stem_b = self.branch2_stem(x)
        split_outputs = [
            torch.cat([left(stem), right(stem)], dim=1)
            for stem, left, right in (
                (stem_a, self.branch1_left, self.branch1_right),
                (stem_b, self.branch2_left, self.branch2_right),
            )
        ]
        return torch.cat(split_outputs + [self.branch3(x), self.branch4(x)], dim=1)
class Inception_ReduceA(nn.Module):  # Figure 10: conv (stride 2) -> pooling operation
    """Grid-size reduction block (Figure 10): stride-2 convolutions run in parallel with pooling.

    Implementations of this block vary from author to author, so this one
    follows the PyTorch source code.

    Every branch ends in a stride-2, unpadded 3x3 operation, which halves the
    grid (e.g. 35 x 35 -> 17 x 17). The output has 96 + 384 + ``in_channels``
    channels, because the pooling branch keeps its input channel count.

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # 1x1 (64) -> 3x3 (96) -> stride-2 3x3 (96).
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, 64, kernel_size=1),
            BasicConv2d(64, 96, kernel_size=3, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=2),
        )
        # Single stride-2 3x3 (384).
        self.branch2 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
        # Stride-2 max-pool.
        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        return torch.cat([op(x) for op in (self.branch1, self.branch2, self.branch3)], dim=1)
class Inception_ReduceB(nn.Module):  # Figure 10: conv (stride 2) -> pooling operation
    """Second grid-size reduction block (Figure 10 style), e.g. 17 x 17 -> 8 x 8.

    Implementations of this block vary from author to author, so this one
    follows the PyTorch source code.

    The output has 192 + 320 + ``in_channels`` channels, because the pooling
    branch keeps its input channel count.

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # 1x1 (192) -> 1x7 -> 7x1 -> stride-2 3x3 (192).
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(192, 192, kernel_size=3, stride=2),
        )
        # 1x1 (192) -> stride-2 3x3 (320).
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, 192, kernel_size=1),
            BasicConv2d(192, 320, kernel_size=3, stride=2),
        )
        # Stride-2 max-pool.
        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        return torch.cat([op(x) for op in (self.branch1, self.branch2, self.branch3)], dim=1)
class Inception_Aux(nn.Module):
    """Auxiliary classifier head: pool to 5 x 5, 1x1 conv to 128 channels, then two FC layers.

    Args:
        in_channels: number of channels of the incoming feature map
            (e.g. N x 768 x 17 x 17).
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # The paper uses nn.AvgPool2d(kernel_size=5, stride=3); the adaptive
        # pool always produces a 5 x 5 map, whatever the input size.
        self.avgpool = nn.AdaptiveAvgPool2d((5, 5))
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(5 * 5 * 128, 1024)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.7)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        pooled = self.avgpool(x)                         # N x C x 5 x 5
        reduced = self.conv(pooled)                      # N x 128 x 5 x 5
        flat = torch.flatten(reduced, 1)                 # N x 3200
        hidden = self.dropout(self.relu(self.fc1(flat)))  # N x 1024
        return self.fc2(hidden)                          # N x num_classes
class Inception_V3(nn.Module):
    """Inception-v3 network.

    Shape comments in ``forward`` assume an N x 3 x 299 x 299 input.

    Args:
        num_classes: number of output logits (main and auxiliary classifier).
        use_aux: if True, build an auxiliary classifier on the 17 x 17 x 768
            feature map. Its logits are computed only in training mode.
        drop_p: dropout probability applied before the final linear layer.

    ``forward`` returns a tuple ``(logits, aux_logits)``. ``aux_logits`` is
    ``None`` when the model is in eval mode or was built with ``use_aux=False``.
    """

    def __init__(self, num_classes=1000, use_aux=True, drop_p=0.5):
        super().__init__()
        in_channels = 3  # RGB

        # Stem: 3 x 299 x 299 -> 288 x 35 x 35.
        self.conv1a = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        self.conv1b = BasicConv2d(32, 32, kernel_size=3)
        self.conv1c = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(3, stride=2)
        self.conv2a = BasicConv2d(64, 80, kernel_size=3)
        self.conv2b = BasicConv2d(80, 192, kernel_size=3, stride=2)
        self.conv2c = BasicConv2d(192, 288, kernel_size=3, padding=1)

        # 35 x 35 stage, then reduction to 768 x 17 x 17.
        self.inception3a = InceptionF5(288)
        self.inception3b = InceptionF5(288)
        self.inception3c = InceptionF5(288)
        self.inception_red1 = Inception_ReduceA(288)

        # 17 x 17 stage.
        self.inception4a = InceptionF6(768, f_7x7=128)
        self.inception4b = InceptionF6(768, f_7x7=160)
        self.inception4c = InceptionF6(768, f_7x7=160)
        self.inception4d = InceptionF6(768, f_7x7=160)
        self.inception4e = InceptionF6(768, f_7x7=192)
        # Always define self.aux, using None when disabled, so forward() can
        # test it. Leaving it unset made use_aux=False raise AttributeError.
        self.aux = Inception_Aux(768, num_classes=num_classes) if use_aux else None
        self.inception_red2 = Inception_ReduceB(768)

        # 8 x 8 stage.
        self.inception5a = InceptionF7(1280)
        self.inception5b = InceptionF7(2048)

        # Classifier head.
        self.pool6 = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=drop_p)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.conv1a(x)  # -> N x 32 x 149 x 149
        x = self.conv1b(x)  # -> N x 32 x 147 x 147
        x = self.conv1c(x)  # -> N x 64 x 147 x 147
        x = self.pool1(x)  # -> N x 64 x 73 x 73
        x = self.conv2a(x)  # -> N x 80 x 71 x 71
        x = self.conv2b(x)  # -> N x 192 x 35 x 35
        x = self.conv2c(x)  # -> N x 288 x 35 x 35
        x = self.inception3a(x)  # -> N x (96 + 64 * 3) x 35 x 35 = N x 288 x 35 x 35
        x = self.inception3b(x)  # -> N x 288 x 35 x 35
        x = self.inception3c(x)  # -> N x 288 x 35 x 35
        x = self.inception_red1(x)  # -> N x 768 x 17 x 17
        x = self.inception4a(x)  # -> N x (192 * 4) x 17 x 17 = N x 768 x 17 x 17
        x = self.inception4b(x)  # -> N x 768 x 17 x 17
        x = self.inception4c(x)  # -> N x 768 x 17 x 17
        x = self.inception4d(x)  # -> N x 768 x 17 x 17
        x = self.inception4e(x)  # -> N x 768 x 17 x 17
        # Auxiliary logits are only produced while training; otherwise None.
        aux = self.aux(x) if self.aux is not None and self.training else None
        x = self.inception_red2(x)  # -> N x 1280 x 8 x 8
        x = self.inception5a(x)  # -> N x (384 * 2 * 2 + 192 + 320) x 8 x 8 = N x 2048 x 8 x 8
        x = self.inception5b(x)  # -> N x 2048 x 8 x 8
        x = self.pool6(x)  # -> N x 2048 x 1 x 1
        x = torch.flatten(x, 1)  # -> N x 2048
        x = self.dropout(x)
        x = self.fc(x)  # -> N x num_classes
        return x, aux
if __name__ == "__main__":
    # Build the network and show a layer-by-layer summary for a batch of two
    # 3 x 299 x 299 images. Guarded so importing this module has no side effects.
    model = Inception_V3()
    # verbose=0 plus print() shows the table exactly once, both when run as a
    # script and in a notebook. A bare summary() call inside an `if` block
    # would not be auto-displayed in a notebook.
    print(summary(model, input_size=(2, 3, 299, 299), device="cpu", verbose=0))
#### Output ####
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
Inception_V3 [2, 1000] 4,401,384
├─BasicConv2d: 1-1 [2, 32, 149, 149] --
│ └─Conv2d: 2-1 [2, 32, 149, 149] 864
│ └─BatchNorm2d: 2-2 [2, 32, 149, 149] 64
│ └─ReLU: 2-3 [2, 32, 149, 149] --
├─BasicConv2d: 1-2 [2, 32, 147, 147] --
│ └─Conv2d: 2-4 [2, 32, 147, 147] 9,216
│ └─BatchNorm2d: 2-5 [2, 32, 147, 147] 64
│ └─ReLU: 2-6 [2, 32, 147, 147] --
├─BasicConv2d: 1-3 [2, 64, 147, 147] --
│ └─Conv2d: 2-7 [2, 64, 147, 147] 18,432
│ └─BatchNorm2d: 2-8 [2, 64, 147, 147] 128
│ └─ReLU: 2-9 [2, 64, 147, 147] --
├─MaxPool2d: 1-4 [2, 64, 73, 73] --
├─BasicConv2d: 1-5 [2, 80, 71, 71] --
│ └─Conv2d: 2-10 [2, 80, 71, 71] 46,080
│ └─BatchNorm2d: 2-11 [2, 80, 71, 71] 160
│ └─ReLU: 2-12 [2, 80, 71, 71] --
├─BasicConv2d: 1-6 [2, 192, 35, 35] --
│ └─Conv2d: 2-13 [2, 192, 35, 35] 138,240
│ └─BatchNorm2d: 2-14 [2, 192, 35, 35] 384
│ └─ReLU: 2-15 [2, 192, 35, 35] --
├─BasicConv2d: 1-7 [2, 288, 35, 35] --
│ └─Conv2d: 2-16 [2, 288, 35, 35] 497,664
│ └─BatchNorm2d: 2-17 [2, 288, 35, 35] 576
│ └─ReLU: 2-18 [2, 288, 35, 35] --
├─InceptionF5: 1-8 [2, 288, 35, 35] --
│ └─Sequential: 2-19 [2, 96, 35, 35] --
│ │ └─BasicConv2d: 3-1 [2, 64, 35, 35] 18,560
│ │ └─BasicConv2d: 3-2 [2, 96, 35, 35] 55,488
│ │ └─BasicConv2d: 3-3 [2, 96, 35, 35] 83,136
│ └─Sequential: 2-20 [2, 64, 35, 35] --
│ │ └─BasicConv2d: 3-4 [2, 48, 35, 35] 13,920
│ │ └─BasicConv2d: 3-5 [2, 64, 35, 35] 27,776
│ └─Sequential: 2-21 [2, 64, 35, 35] --
│ │ └─MaxPool2d: 3-6 [2, 288, 35, 35] --
│ │ └─BasicConv2d: 3-7 [2, 64, 35, 35] 18,560
│ └─BasicConv2d: 2-22 [2, 64, 35, 35] --
│ │ └─Conv2d: 3-8 [2, 64, 35, 35] 18,432
│ │ └─BatchNorm2d: 3-9 [2, 64, 35, 35] 128
│ │ └─ReLU: 3-10 [2, 64, 35, 35] --
├─InceptionF5: 1-9 [2, 288, 35, 35] --
│ └─Sequential: 2-23 [2, 96, 35, 35] --
│ │ └─BasicConv2d: 3-11 [2, 64, 35, 35] 18,560
│ │ └─BasicConv2d: 3-12 [2, 96, 35, 35] 55,488
│ │ └─BasicConv2d: 3-13 [2, 96, 35, 35] 83,136
│ └─Sequential: 2-24 [2, 64, 35, 35] --
│ │ └─BasicConv2d: 3-14 [2, 48, 35, 35] 13,920
│ │ └─BasicConv2d: 3-15 [2, 64, 35, 35] 27,776
│ └─Sequential: 2-25 [2, 64, 35, 35] --
│ │ └─MaxPool2d: 3-16 [2, 288, 35, 35] --
│ │ └─BasicConv2d: 3-17 [2, 64, 35, 35] 18,560
│ └─BasicConv2d: 2-26 [2, 64, 35, 35] --
│ │ └─Conv2d: 3-18 [2, 64, 35, 35] 18,432
│ │ └─BatchNorm2d: 3-19 [2, 64, 35, 35] 128
│ │ └─ReLU: 3-20 [2, 64, 35, 35] --
├─InceptionF5: 1-10 [2, 288, 35, 35] --
│ └─Sequential: 2-27 [2, 96, 35, 35] --
│ │ └─BasicConv2d: 3-21 [2, 64, 35, 35] 18,560
│ │ └─BasicConv2d: 3-22 [2, 96, 35, 35] 55,488
│ │ └─BasicConv2d: 3-23 [2, 96, 35, 35] 83,136
│ └─Sequential: 2-28 [2, 64, 35, 35] --
│ │ └─BasicConv2d: 3-24 [2, 48, 35, 35] 13,920
│ │ └─BasicConv2d: 3-25 [2, 64, 35, 35] 27,776
│ └─Sequential: 2-29 [2, 64, 35, 35] --
│ │ └─MaxPool2d: 3-26 [2, 288, 35, 35] --
│ │ └─BasicConv2d: 3-27 [2, 64, 35, 35] 18,560
│ └─BasicConv2d: 2-30 [2, 64, 35, 35] --
│ │ └─Conv2d: 3-28 [2, 64, 35, 35] 18,432
│ │ └─BatchNorm2d: 3-29 [2, 64, 35, 35] 128
│ │ └─ReLU: 3-30 [2, 64, 35, 35] --
├─Inception_ReduceA: 1-11 [2, 768, 17, 17] --
│ └─Sequential: 2-31 [2, 96, 17, 17] --
│ │ └─BasicConv2d: 3-31 [2, 64, 35, 35] 18,560
│ │ └─BasicConv2d: 3-32 [2, 96, 35, 35] 55,488
│ │ └─BasicConv2d: 3-33 [2, 96, 17, 17] 83,136
│ └─BasicConv2d: 2-32 [2, 384, 17, 17] --
│ │ └─Conv2d: 3-34 [2, 384, 17, 17] 995,328
│ │ └─BatchNorm2d: 3-35 [2, 384, 17, 17] 768
│ │ └─ReLU: 3-36 [2, 384, 17, 17] --
│ └─MaxPool2d: 2-33 [2, 288, 17, 17] --
├─InceptionF6: 1-12 [2, 768, 17, 17] --
│ └─Sequential: 2-34 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-37 [2, 128, 17, 17] 98,560
│ │ └─BasicConv2d: 3-38 [2, 128, 17, 17] 114,944
│ │ └─BasicConv2d: 3-39 [2, 128, 17, 17] 114,944
│ │ └─BasicConv2d: 3-40 [2, 128, 17, 17] 114,944
│ │ └─BasicConv2d: 3-41 [2, 192, 17, 17] 172,416
│ └─Sequential: 2-35 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-42 [2, 128, 17, 17] 98,560
│ │ └─BasicConv2d: 3-43 [2, 128, 17, 17] 114,944
│ │ └─BasicConv2d: 3-44 [2, 192, 17, 17] 172,416
│ └─Sequential: 2-36 [2, 192, 17, 17] --
│ │ └─MaxPool2d: 3-45 [2, 768, 17, 17] --
│ │ └─BasicConv2d: 3-46 [2, 192, 17, 17] 147,840
│ └─BasicConv2d: 2-37 [2, 192, 17, 17] --
│ │ └─Conv2d: 3-47 [2, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-48 [2, 192, 17, 17] 384
│ │ └─ReLU: 3-49 [2, 192, 17, 17] --
├─InceptionF6: 1-13 [2, 768, 17, 17] --
│ └─Sequential: 2-38 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-50 [2, 160, 17, 17] 123,200
│ │ └─BasicConv2d: 3-51 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-52 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-53 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-54 [2, 192, 17, 17] 215,424
│ └─Sequential: 2-39 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-55 [2, 160, 17, 17] 123,200
│ │ └─BasicConv2d: 3-56 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-57 [2, 192, 17, 17] 215,424
│ └─Sequential: 2-40 [2, 192, 17, 17] --
│ │ └─MaxPool2d: 3-58 [2, 768, 17, 17] --
│ │ └─BasicConv2d: 3-59 [2, 192, 17, 17] 147,840
│ └─BasicConv2d: 2-41 [2, 192, 17, 17] --
│ │ └─Conv2d: 3-60 [2, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-61 [2, 192, 17, 17] 384
│ │ └─ReLU: 3-62 [2, 192, 17, 17] --
├─InceptionF6: 1-14 [2, 768, 17, 17] --
│ └─Sequential: 2-42 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-63 [2, 160, 17, 17] 123,200
│ │ └─BasicConv2d: 3-64 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-65 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-66 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-67 [2, 192, 17, 17] 215,424
│ └─Sequential: 2-43 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-68 [2, 160, 17, 17] 123,200
│ │ └─BasicConv2d: 3-69 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-70 [2, 192, 17, 17] 215,424
│ └─Sequential: 2-44 [2, 192, 17, 17] --
│ │ └─MaxPool2d: 3-71 [2, 768, 17, 17] --
│ │ └─BasicConv2d: 3-72 [2, 192, 17, 17] 147,840
│ └─BasicConv2d: 2-45 [2, 192, 17, 17] --
│ │ └─Conv2d: 3-73 [2, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-74 [2, 192, 17, 17] 384
│ │ └─ReLU: 3-75 [2, 192, 17, 17] --
├─InceptionF6: 1-15 [2, 768, 17, 17] --
│ └─Sequential: 2-46 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-76 [2, 160, 17, 17] 123,200
│ │ └─BasicConv2d: 3-77 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-78 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-79 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-80 [2, 192, 17, 17] 215,424
│ └─Sequential: 2-47 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-81 [2, 160, 17, 17] 123,200
│ │ └─BasicConv2d: 3-82 [2, 160, 17, 17] 179,520
│ │ └─BasicConv2d: 3-83 [2, 192, 17, 17] 215,424
│ └─Sequential: 2-48 [2, 192, 17, 17] --
│ │ └─MaxPool2d: 3-84 [2, 768, 17, 17] --
│ │ └─BasicConv2d: 3-85 [2, 192, 17, 17] 147,840
│ └─BasicConv2d: 2-49 [2, 192, 17, 17] --
│ │ └─Conv2d: 3-86 [2, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-87 [2, 192, 17, 17] 384
│ │ └─ReLU: 3-88 [2, 192, 17, 17] --
├─InceptionF6: 1-16 [2, 768, 17, 17] --
│ └─Sequential: 2-50 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-89 [2, 192, 17, 17] 147,840
│ │ └─BasicConv2d: 3-90 [2, 192, 17, 17] 258,432
│ │ └─BasicConv2d: 3-91 [2, 192, 17, 17] 258,432
│ │ └─BasicConv2d: 3-92 [2, 192, 17, 17] 258,432
│ │ └─BasicConv2d: 3-93 [2, 192, 17, 17] 258,432
│ └─Sequential: 2-51 [2, 192, 17, 17] --
│ │ └─BasicConv2d: 3-94 [2, 192, 17, 17] 147,840
│ │ └─BasicConv2d: 3-95 [2, 192, 17, 17] 258,432
│ │ └─BasicConv2d: 3-96 [2, 192, 17, 17] 258,432
│ └─Sequential: 2-52 [2, 192, 17, 17] --
│ │ └─MaxPool2d: 3-97 [2, 768, 17, 17] --
│ │ └─BasicConv2d: 3-98 [2, 192, 17, 17] 147,840
│ └─BasicConv2d: 2-53 [2, 192, 17, 17] --
│ │ └─Conv2d: 3-99 [2, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-100 [2, 192, 17, 17] 384
│ │ └─ReLU: 3-101 [2, 192, 17, 17] --
├─Inception_ReduceB: 1-17 [2, 1280, 8, 8] --
│ └─Sequential: 2-54 [2, 192, 8, 8] --
│ │ └─BasicConv2d: 3-102 [2, 192, 17, 17] 147,840
│ │ └─BasicConv2d: 3-103 [2, 192, 17, 17] 258,432
│ │ └─BasicConv2d: 3-104 [2, 192, 17, 17] 258,432
│ │ └─BasicConv2d: 3-105 [2, 192, 8, 8] 332,160
│ └─Sequential: 2-55 [2, 320, 8, 8] --
│ │ └─BasicConv2d: 3-106 [2, 192, 17, 17] 147,840
│ │ └─BasicConv2d: 3-107 [2, 320, 8, 8] 553,600
│ └─MaxPool2d: 2-56 [2, 768, 8, 8] --
├─InceptionF7: 1-18 [2, 2048, 8, 8] --
│ └─Sequential: 2-57 [2, 384, 8, 8] --
│ │ └─BasicConv2d: 3-108 [2, 448, 8, 8] 574,336
│ │ └─BasicConv2d: 3-109 [2, 384, 8, 8] 1,549,056
│ └─BasicConv2d: 2-58 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-110 [2, 384, 8, 8] 491,520
│ │ └─BatchNorm2d: 3-111 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-112 [2, 384, 8, 8] --
│ └─BasicConv2d: 2-59 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-113 [2, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-114 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-115 [2, 384, 8, 8] --
│ └─BasicConv2d: 2-60 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-116 [2, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-117 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-118 [2, 384, 8, 8] --
│ └─BasicConv2d: 2-61 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-119 [2, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-120 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-121 [2, 384, 8, 8] --
│ └─BasicConv2d: 2-62 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-122 [2, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-123 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-124 [2, 384, 8, 8] --
│ └─Sequential: 2-63 [2, 192, 8, 8] --
│ │ └─MaxPool2d: 3-125 [2, 1280, 8, 8] --
│ │ └─BasicConv2d: 3-126 [2, 192, 8, 8] 246,144
│ └─BasicConv2d: 2-64 [2, 320, 8, 8] --
│ │ └─Conv2d: 3-127 [2, 320, 8, 8] 409,600
│ │ └─BatchNorm2d: 3-128 [2, 320, 8, 8] 640
│ │ └─ReLU: 3-129 [2, 320, 8, 8] --
├─InceptionF7: 1-19 [2, 2048, 8, 8] --
│ └─Sequential: 2-65 [2, 384, 8, 8] --
│ │ └─BasicConv2d: 3-130 [2, 448, 8, 8] 918,400
│ │ └─BasicConv2d: 3-131 [2, 384, 8, 8] 1,549,056
│ └─BasicConv2d: 2-66 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-132 [2, 384, 8, 8] 786,432
│ │ └─BatchNorm2d: 3-133 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-134 [2, 384, 8, 8] --
│ └─BasicConv2d: 2-67 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-135 [2, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-136 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-137 [2, 384, 8, 8] --
│ └─BasicConv2d: 2-68 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-138 [2, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-139 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-140 [2, 384, 8, 8] --
│ └─BasicConv2d: 2-69 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-141 [2, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-142 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-143 [2, 384, 8, 8] --
│ └─BasicConv2d: 2-70 [2, 384, 8, 8] --
│ │ └─Conv2d: 3-144 [2, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-145 [2, 384, 8, 8] 768
│ │ └─ReLU: 3-146 [2, 384, 8, 8] --
│ └─Sequential: 2-71 [2, 192, 8, 8] --
│ │ └─MaxPool2d: 3-147 [2, 2048, 8, 8] --
│ │ └─BasicConv2d: 3-148 [2, 192, 8, 8] 393,600
│ └─BasicConv2d: 2-72 [2, 320, 8, 8] --
│ │ └─Conv2d: 3-149 [2, 320, 8, 8] 655,360
│ │ └─BatchNorm2d: 3-150 [2, 320, 8, 8] 640
│ │ └─ReLU: 3-151 [2, 320, 8, 8] --
├─AdaptiveAvgPool2d: 1-20 [2, 2048, 1, 1] --
├─Dropout: 1-21 [2, 2048] --
├─Linear: 1-22 [2, 1000] 2,049,000
==========================================================================================
Total params: 30,355,632
Trainable params: 30,355,632
Non-trainable params: 0
Total mult-adds (G): 12.71
==========================================================================================
Input size (MB): 2.15
Forward/backward pass size (MB): 291.32
Params size (MB): 103.82
Estimated Total Size (MB): 397.28
==========================================================================================