SE-Net 을 이해하고 Pytorch로 구현할 수 있다.
SE block (Squeeze-and-Excitation block)
SE-Net은 새로운 모델이라기 보다는 기존의 InceptionNet, ResNet 등의 모델에 SE-Block을 추가한 네트워크이다. (해당 논문에서는 SE block을 적용한 ResNet50을 SE-ResNet50 이라고 명명하였다.)
SE는 Squeeze and Excitation의 약자이다. Squeeze
는 pooling을 통해 feature map을 1 x 1
로 변환하는 부분이고, Excitation
은 FC layer
와 Sigmoid
를 통과시켜 어떤 feature map을 더 비중있게 볼지에 대한 가중치 값 (0 ~ 1)을 얻는다.
Conv를 통해 여러 개의 feature map들을 뽑는데, SE block의 역할은 이 여러 개의 feature map중 어떤 것을 좀 더 비중을 두면 좋을지를 결정하는 block이다. 순서는 다음과 같다.
Conv
를 통과한 feature map의 shape이 N x C x W x H
라고 해보자. Avgpool
을 통과시켜 N x C x 1 x 1
로 변환. 그리고 차원을 축소해 N x C
로 변환한다.FC-ReLU-FC-Sigmoid
를 통과시켜 N x C
는 각 C
에 대해 0~1 사이의 값으로 채워지게된다.N x C x 1 x 1
로 형태로 변환 후 feature map에 곱한다. (각각의 0 ~ 1 값을 feature map에 곱함)Reduction ratio
Excitation layers는 FC-ReLU-FC-Sigmoid
로 구성되어 있다.
FC-Sigmoid
로 구성하지 않고 왜 이렇게 구성해놓았을까? 이는 파라미터 수를 조절하기 위함이다.
(CNN에서 흔히 사용되는 기법인 3x3 Conv
를 바로 여러 개 쌓는 것 보다 중간에 1x1 Conv
를 섞어 쌓으면서 파라미터 수를 조절하려는 의도와 비슷)
즉, 첫 번째 FC
의 out_channels
이자 두 번째 FC의 in_channels
은 몇 으로 두어야 할까를 결정 해주는 것이 reduction ratio
이다. in_channels
을 reduction ratio
로 나눈 값의 몫으로 결정한다. reduction ratio
의 기본 값은 16
이다.
SE block
대신에 Conv 1x1
을 사용해도 되는거 아니야? 뭐가 다른거야?Conv
연산은 weighted sum이다. 예를 들어 input으로 들어오는 shape이 N x 3 x W x H
, output shape이 N x 5 x W x H
라고 해보자. Conv
에서는 3개
의 채널을 가진 필터 5개
를 이용해 5개의 feature map을 만들어낸다. 이 연산에서 각각의 필터는 3 개의 채널을 가지기 때문에 각 채널마다 element-wise
곱을 한 결과를 합해서 하나의 feature map을 만든다. 반면 SE block
연산은 각각의 feature map마다 0~1 사이의 스칼라 곱을 해주는 연산이다. 이는 단순히 각각의 feature map에 가중치를 부여해주는 weighting 이라고 표현할 수 있다. 따라서 output shape도 Conv
와는 다르게 input shape과 동일하다.Conv 1x1
의 input과 output의 shape을 동일하게 나오도록 하고 하면 되지 않나? SE block
을 사용하면 좋은가?SE block
을 적용한 것한 모델(SE
)과 SE block
대신 Conv 1x1
을 적용한 모델(NoSqueeze
)을 비교해보았다. (파라미터 수는 동일하게 맞춤) 결과를 보면 속도(GFLOPs
)도 SE
가 더 빠르고 성능도 SE
가 더 뛰어나다.SE block
을 추가함으로써 성능이 좋아졌다는 주장은 단순히 파라미터수가 늘어나는 만큼 성능이 좋아진게 아닐까?SE block
을 추가함으로써 적은 파라미터 수 증가로도 뛰어난 성능을 낼 수 있게 했다. (효율적인 파라미터 수 증가)ResNeXt101
과 SE-ResNeXt50
의 결과를 비교해보자.SE-ResNet50
이 ResNeXt101
보다 더 좋다.GFLOPs
): SE-ResNet50
이 훨씬 빠르다.SE-ResNet50
이 훨씬 더 적다.SE block
의 exitation
의 Activation과 Squeeze
대한 실험SE block
의 Squeeze
를 Avgpooling
, Maxpooling
로 실험해봤는데 Avgpooling
이 가장 좋았다.SE block
의 Excitation
의 Activation을 Sigmoid
, ReLU
, Tanh
로 실험해봤는데 Sigmoid
가 가장 좋았다.SE block
을 Residual block에서 어떻게 연결시킬지에 대한 실험SE block
을 어디에 연결할 것인지 실험해봤는데 SE-PRE
, starndard SE
가 유사하게 성능이 좋았다고 한다. 이 논문에서는 starndard SE
를 채택했다.SE block
을 ResNet의 어떤 stage의 bottleneck에 적용시킬 것인가에 대한 실험SE block
을 ResNet의 특정 stage의 bottleneck에 적용시키는 것 보다 모든 stage의 bottleneck에 적용시키는 것(SE_All)이 가장 좋았다.SE block
의 excitation
의 reduction ratio을 결정하는 실험import torch
from torch import nn
from torchinfo import summary
class SE_Block(nn.Module):
def __init__(self, in_channels, reduction_ratio = 16):
super().__init__()
self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
self.excitation = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction_ratio),
nn.ReLU(inplace = True),
nn.Linear(in_channels // reduction_ratio, in_channels),
nn.Sigmoid()
)
def forward(self, x):
se = self.squeeze(x)
se = se.reshape(x.shape[0], x.shape[1]) # 개x채x1x1 -> 개x채
se = self.excitation(se)
se = se.unsqueeze(dim = 2).unsqueeze(dim = 3) # 개x채 -> 개x채x1x1
out = se * x
return out
class SE_Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_channels, inner_channels, stride = 1, projection = None):
super().__init__()
self.residual = nn.Sequential(
nn.Conv2d(in_channels, inner_channels, 1, stride = stride, bias = False),
nn.BatchNorm2d(inner_channels),
nn.ReLU(inplace = True),
nn.Conv2d(inner_channels, inner_channels, 3, padding = 1, bias = False),
nn.BatchNorm2d(inner_channels),
nn.ReLU(inplace = True),
nn.Conv2d(inner_channels, inner_channels * self.expansion, 1, bias = False),
nn.BatchNorm2d(inner_channels * self.expansion)
)
self.se_block = SE_Block(inner_channels * self.expansion)
self.projection = projection
self.relu = nn.ReLU(inplace = True)
def forward(self, x):
residual = self.residual(x)
residual = self.se_block(residual)
if self.projection is not None:
skip_connection = self.projection(x)
else:
skip_connection = x
out = self.relu(residual + skip_connection)
return out
class SE_ResNet(nn.Module):
def __init__(self, block, num_blocks_list, n_classes = 1000):
super().__init__()
assert len(num_blocks_list) == 4
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size = 7, stride = 2, padding = 3, bias = False),
nn.BatchNorm2d(64),
nn.ReLU(inplace = True),
)
self.maxpool = nn.MaxPool2d(3, stride = 2, padding = 1)
self.in_channels = 64
self.stage1 = self.make_stage(block, 64, num_blocks_list[0], stride = 1)
self.stage2 = self.make_stage(block, 128, num_blocks_list[1], stride = 2)
self.stage3 = self.make_stage(block, 256, num_blocks_list[2], stride = 2)
self.stage4 = self.make_stage(block, 512, num_blocks_list[3], stride = 2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, n_classes)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def make_stage(self, block, inner_channels, num_blocks, stride = 1):
if stride != 1 or self.in_channels != inner_channels * block.expansion:
projection = nn.Sequential(
nn.Conv2d(self.in_channels, inner_channels * block.expansion, 1, stride = stride, bias = False),
nn.BatchNorm2d(inner_channels * block.expansion),
)
else:
projection = None
layers = []
for idx in range(num_blocks):
if idx == 0:
layers.append(block(self.in_channels, inner_channels, stride, projection))
self.in_channels = inner_channels * block.expansion
else:
layers.append(block(self.in_channels, inner_channels))
return nn.Sequential(*layers)
def se_resnet50():
return SE_ResNet(SE_Bottleneck, [3, 4, 6, 3])
def se_resnet101():
return SE_ResNet(SE_Bottleneck, [3, 4, 23, 3])
def se_resnet152():
return SE_ResNet(SE_Bottleneck, [3, 8, 36, 3])
model = se_resnet152()
summary(model, input_size=(2,3,224,224), device='cpu')
#### OUTPUT ####
===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
SE_ResNet [2, 1000] --
├─Sequential: 1-1 [2, 64, 112, 112] --
│ └─Conv2d: 2-1 [2, 64, 112, 112] 9,408
│ └─BatchNorm2d: 2-2 [2, 64, 112, 112] 128
│ └─ReLU: 2-3 [2, 64, 112, 112] --
├─MaxPool2d: 1-2 [2, 64, 56, 56] --
├─Sequential: 1-3 [2, 256, 56, 56] --
│ └─SE_Bottleneck: 2-4 [2, 256, 56, 56] --
│ │ └─Sequential: 3-1 [2, 256, 56, 56] 58,112
│ │ └─SE_Block: 3-2 [2, 256, 56, 56] 8,464
│ │ └─Sequential: 3-3 [2, 256, 56, 56] 16,896
│ │ └─ReLU: 3-4 [2, 256, 56, 56] --
│ └─SE_Bottleneck: 2-5 [2, 256, 56, 56] --
│ │ └─Sequential: 3-5 [2, 256, 56, 56] 70,400
│ │ └─SE_Block: 3-6 [2, 256, 56, 56] 8,464
│ │ └─ReLU: 3-7 [2, 256, 56, 56] --
│ └─SE_Bottleneck: 2-6 [2, 256, 56, 56] --
│ │ └─Sequential: 3-8 [2, 256, 56, 56] 70,400
│ │ └─SE_Block: 3-9 [2, 256, 56, 56] 8,464
│ │ └─ReLU: 3-10 [2, 256, 56, 56] --
├─Sequential: 1-4 [2, 512, 28, 28] --
│ └─SE_Bottleneck: 2-7 [2, 512, 28, 28] --
│ │ └─Sequential: 3-11 [2, 512, 28, 28] 247,296
│ │ └─SE_Block: 3-12 [2, 512, 28, 28] 33,312
│ │ └─Sequential: 3-13 [2, 512, 28, 28] 132,096
│ │ └─ReLU: 3-14 [2, 512, 28, 28] --
│ └─SE_Bottleneck: 2-8 [2, 512, 28, 28] --
│ │ └─Sequential: 3-15 [2, 512, 28, 28] 280,064
│ │ └─SE_Block: 3-16 [2, 512, 28, 28] 33,312
│ │ └─ReLU: 3-17 [2, 512, 28, 28] --
│ └─SE_Bottleneck: 2-9 [2, 512, 28, 28] --
│ │ └─Sequential: 3-18 [2, 512, 28, 28] 280,064
│ │ └─SE_Block: 3-19 [2, 512, 28, 28] 33,312
│ │ └─ReLU: 3-20 [2, 512, 28, 28] --
│ └─SE_Bottleneck: 2-10 [2, 512, 28, 28] --
│ │ └─Sequential: 3-21 [2, 512, 28, 28] 280,064
│ │ └─SE_Block: 3-22 [2, 512, 28, 28] 33,312
│ │ └─ReLU: 3-23 [2, 512, 28, 28] --
│ └─SE_Bottleneck: 2-11 [2, 512, 28, 28] --
│ │ └─Sequential: 3-24 [2, 512, 28, 28] 280,064
│ │ └─SE_Block: 3-25 [2, 512, 28, 28] 33,312
│ │ └─ReLU: 3-26 [2, 512, 28, 28] --
│ └─SE_Bottleneck: 2-12 [2, 512, 28, 28] --
│ │ └─Sequential: 3-27 [2, 512, 28, 28] 280,064
│ │ └─SE_Block: 3-28 [2, 512, 28, 28] 33,312
│ │ └─ReLU: 3-29 [2, 512, 28, 28] --
│ └─SE_Bottleneck: 2-13 [2, 512, 28, 28] --
│ │ └─Sequential: 3-30 [2, 512, 28, 28] 280,064
│ │ └─SE_Block: 3-31 [2, 512, 28, 28] 33,312
│ │ └─ReLU: 3-32 [2, 512, 28, 28] --
│ └─SE_Bottleneck: 2-14 [2, 512, 28, 28] --
│ │ └─Sequential: 3-33 [2, 512, 28, 28] 280,064
│ │ └─SE_Block: 3-34 [2, 512, 28, 28] 33,312
│ │ └─ReLU: 3-35 [2, 512, 28, 28] --
├─Sequential: 1-5 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-15 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-36 [2, 1024, 14, 14] 986,112
│ │ └─SE_Block: 3-37 [2, 1024, 14, 14] 132,160
│ │ └─Sequential: 3-38 [2, 1024, 14, 14] 526,336
│ │ └─ReLU: 3-39 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-16 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-40 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-41 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-42 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-17 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-43 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-44 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-45 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-18 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-46 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-47 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-48 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-19 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-49 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-50 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-51 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-20 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-52 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-53 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-54 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-21 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-55 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-56 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-57 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-22 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-58 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-59 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-60 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-23 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-61 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-62 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-63 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-24 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-64 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-65 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-66 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-25 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-67 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-68 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-69 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-26 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-70 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-71 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-72 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-27 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-73 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-74 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-75 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-28 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-76 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-77 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-78 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-29 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-79 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-80 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-81 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-30 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-82 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-83 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-84 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-31 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-85 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-86 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-87 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-32 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-88 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-89 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-90 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-33 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-91 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-92 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-93 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-34 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-94 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-95 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-96 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-35 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-97 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-98 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-99 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-36 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-100 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-101 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-102 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-37 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-103 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-104 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-105 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-38 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-106 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-107 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-108 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-39 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-109 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-110 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-111 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-40 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-112 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-113 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-114 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-41 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-115 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-116 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-117 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-42 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-118 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-119 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-120 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-43 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-121 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-122 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-123 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-44 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-124 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-125 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-126 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-45 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-127 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-128 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-129 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-46 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-130 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-131 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-132 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-47 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-133 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-134 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-135 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-48 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-136 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-137 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-138 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-49 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-139 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-140 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-141 [2, 1024, 14, 14] --
│ └─SE_Bottleneck: 2-50 [2, 1024, 14, 14] --
│ │ └─Sequential: 3-142 [2, 1024, 14, 14] 1,117,184
│ │ └─SE_Block: 3-143 [2, 1024, 14, 14] 132,160
│ │ └─ReLU: 3-144 [2, 1024, 14, 14] --
├─Sequential: 1-6 [2, 2048, 7, 7] --
│ └─SE_Bottleneck: 2-51 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-145 [2, 2048, 7, 7] 3,938,304
│ │ └─SE_Block: 3-146 [2, 2048, 7, 7] 526,464
│ │ └─Sequential: 3-147 [2, 2048, 7, 7] 2,101,248
│ │ └─ReLU: 3-148 [2, 2048, 7, 7] --
│ └─SE_Bottleneck: 2-52 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-149 [2, 2048, 7, 7] 4,462,592
│ │ └─SE_Block: 3-150 [2, 2048, 7, 7] 526,464
│ │ └─ReLU: 3-151 [2, 2048, 7, 7] --
│ └─SE_Bottleneck: 2-53 [2, 2048, 7, 7] --
│ │ └─Sequential: 3-152 [2, 2048, 7, 7] 4,462,592
│ │ └─SE_Block: 3-153 [2, 2048, 7, 7] 526,464
│ │ └─ReLU: 3-154 [2, 2048, 7, 7] --
├─AdaptiveAvgPool2d: 1-7 [2, 2048, 1, 1] --
├─Linear: 1-8 [2, 1000] 2,049,000
===============================================================================================
Total params: 66,821,848
Trainable params: 66,821,848
Non-trainable params: 0
Total mult-adds (G): 22.58
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 705.70
Params size (MB): 267.29
Estimated Total Size (MB): 974.19
===============================================================================================