CSPNet을 이해하고 PyTorch로 구현할 수 있다.
DenseBlock에서는 DenseLayer를 하나 통과할 때마다 growth rate 만큼 feature map의 채널 수가 점점 늘어나고, Transition Layer에서는 feature map의 크기와 채널 수를 조절한다.

CSPNet은 입력 feature map을 채널 축 방향으로 두 부분으로 나눈다. DenseBlock을 통과할 feature map을 `CSP_features_01`, Transition Layer에서 concat할 feature map을 `CSP_features_02`라고 한다면 각각의 feature map 사이즈는 아래와 같을 것이다.

- `CSP_features_01.shape = (N, C//2, H, W)`
- `CSP_features_02.shape = (N, C - C//2, H, W)`
`CSP_features_01`은 DenseBlock의 DenseLayer를 여러 개(L개라고 하자) 통과하면서 `(N, C//2 + growth_rate * L, H, W)`가 될 것이고, DenseBlock을 통과하면 `CSP_features_02`와 채널 축 방향으로 concat한 뒤 Transition Layer를 통과할 것이다. (아래 구현에서는 concat하기 전에 DenseBlock 출력에 pooling이 없는 1x1 Transition을 한 번 더 적용하여 채널 수를 절반으로 줄인다.)

이전에 구현한 DenseNet에 CSPNet을 추가하여 CSPDenseNet을 구현하였다.
import torch
from torch import nn
from torchinfo import summary
class DenseLayer(nn.Module):
    """DenseNet bottleneck layer: BN-ReLU-Conv1x1-BN-ReLU-Conv3x3.

    Computes ``k`` new feature maps from the input and stacks them in front
    of the input along the channel axis, so the output has
    ``in_channels + k`` channels and the same spatial size (the 3x3 conv
    uses padding 1 and stride 1).

    Args:
        in_channels: number of channels of the input feature map.
        k: growth rate, i.e. how many channels this layer adds.
    """

    def __init__(self, in_channels, k):
        super().__init__()
        bottleneck = 4 * k  # 1x1 conv widens to 4k channels before the 3x3 conv
        # The attribute name and layer order determine the state_dict keys
        # (``residual.<idx>.*``), so both are kept fixed.
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels, bottleneck, 1, bias = False),
            nn.BatchNorm2d(bottleneck),
            nn.ReLU(inplace = True),
            nn.Conv2d(bottleneck, k, 3, padding = 1, bias = False),
        )

    def forward(self, x):
        new_features = self.residual(x)
        # New features first, then the untouched input (dense connectivity).
        return torch.cat((new_features, x), dim = 1)
class Transition(nn.Module):
    """Transition layer: BN-ReLU-Conv1x1 that halves the channel count.

    Args:
        in_channels: number of input channels; the output has
            ``in_channels // 2`` channels.
        csp_transition: when truthy, build the variant used inside the CSP
            dense path, which omits pooling so H and W are preserved.
            Otherwise a 2x2 ``AvgPool2d`` is appended, halving H and W.
    """

    def __init__(self, in_channels, csp_transition = False):
        super().__init__()
        transition_layers = [
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels, in_channels // 2, 1, bias = False),
        ]
        # Truthiness test rather than ``is not True``: with the identity check,
        # a truthy flag such as ``1`` or ``np.bool_(True)`` would silently add
        # pooling and change the output resolution.
        if not csp_transition:
            transition_layers.append(nn.AvgPool2d(2))
        self.transition = nn.Sequential(*transition_layers)

    def forward(self, x):
        return self.transition(x)
class CSPDenseBlock(nn.Module):
    """DenseBlock wrapped in a Cross Stage Partial (CSP) channel split.

    The input channels are split in two. The trailing ``in_channels // 2``
    channels run through ``num_blocks`` DenseLayers followed by a non-pooling
    Transition, which halves their channel count. The leading
    ``in_channels - in_channels // 2`` channels bypass the dense path. The two
    parts are concatenated on the channel axis (dense part first) and passed
    through ``self.last``.

    Args:
        in_channels: number of channels of the input feature map.
        num_blocks: number of DenseLayers in the dense path.
        k: growth rate of each DenseLayer.
        last_stage: if True, ``self.last`` is BN-ReLU (no downsampling);
            otherwise it is a pooling Transition that halves channels, H and W.

    Attributes:
        channels: number of output channels, used to size the next stage.
    """

    def __init__(self, in_channels, num_blocks, k, last_stage = False):
        super().__init__()
        self.in_channels = in_channels
        dense_ch = in_channels // 2         # channels entering the dense path
        bypass_ch = in_channels - dense_ch  # channels skipping it

        # DenseLayer i sees the dense-path input plus the i * k channels added so far.
        layers = [DenseLayer(dense_ch + i * k, k) for i in range(num_blocks)]
        grown_ch = dense_ch + num_blocks * k
        layers.append(Transition(grown_ch, csp_transition = True))
        self.dense_block = nn.Sequential(*layers)

        merged_ch = grown_ch // 2 + bypass_ch
        if last_stage:
            self.last = nn.Sequential(nn.BatchNorm2d(merged_ch), nn.ReLU(inplace = True))
            self.channels = merged_ch
        else:
            self.last = Transition(merged_ch)
            self.channels = merged_ch // 2

    def forward(self, x):
        # With an odd channel count the bypass part receives the extra channel.
        split_at = self.in_channels - self.in_channels // 2
        bypass = x[:, :split_at]
        dense = self.dense_block(x[:, split_at:])
        return self.last(torch.cat([dense, bypass], dim = 1))
class CSPDenseNet(nn.Module):
    """DenseNet-style classifier whose four dense stages are CSPDenseBlocks.

    Stem: 7x7 stride-2 conv producing ``2 * growth_rate`` channels, BN-ReLU,
    then a 3x3 stride-2 max-pool. Four CSPDenseBlocks follow (the first three
    end in a pooling Transition; the last ends in BN-ReLU), then global
    average pooling and a linear classifier.

    Args:
        block_list: number of DenseLayers in each of the four stages.
        growth_rate: DenseNet growth rate ``k``.
        n_classes: number of output logits.
        in_channels: number of channels of the input image (default 3).

    Raises:
        ValueError: if ``block_list`` does not have exactly four entries.
    """

    def __init__(self, block_list, growth_rate, n_classes = 1000, in_channels = 3):
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if len(block_list) != 4:
            raise ValueError(
                f"block_list must have exactly 4 entries, got {len(block_list)}"
            )
        self.k = growth_rate
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 2 * self.k, 7, stride = 2, padding = 3, bias = False),
            nn.BatchNorm2d(2 * self.k),
            nn.ReLU(inplace = True),
        )
        self.maxpool = nn.MaxPool2d(3, stride = 2, padding = 1)
        # Each stage is sized from the previous stage's reported output channels.
        self.dense_block_01 = CSPDenseBlock(2 * self.k, block_list[0], self.k)
        self.dense_block_02 = CSPDenseBlock(self.dense_block_01.channels, block_list[1], self.k)
        self.dense_block_03 = CSPDenseBlock(self.dense_block_02.channels, block_list[2], self.k)
        self.dense_block_04 = CSPDenseBlock(self.dense_block_03.channels, block_list[3], self.k, last_stage = True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.dense_block_04.channels, n_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.dense_block_01(x)
        x = self.dense_block_02(x)
        x = self.dense_block_03(x)
        x = self.dense_block_04(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
def CSPDenseNet264(n_classes = 1000):
    """Build a CSPDenseNet with the DenseNet-264 layout.

    Four stages of 6, 12, 64 and 48 DenseLayers with growth rate 32; these
    are the layer counts and stem width shown in the summary output.

    Args:
        n_classes: number of output logits.
    """
    return CSPDenseNet([6, 12, 64, 48], 32, n_classes = n_classes)


model = CSPDenseNet264()
summary(model, input_size = (2, 3, 224, 224), device = "cpu")
#### OUTPUT ####
===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
CSPDenseNet [2, 1000] --
├─Sequential: 1-1 [2, 64, 112, 112] --
│ └─Conv2d: 2-1 [2, 64, 112, 112] 9,408
│ └─BatchNorm2d: 2-2 [2, 64, 112, 112] 128
│ └─ReLU: 2-3 [2, 64, 112, 112] --
├─MaxPool2d: 1-2 [2, 64, 56, 56] --
├─CSPDenseBlock: 1-3 [2, 72, 28, 28] --
│ └─Sequential: 2-4 [2, 112, 56, 56] --
│ │ └─DenseLayer: 3-1 [2, 64, 56, 56] 41,280
│ │ └─DenseLayer: 3-2 [2, 96, 56, 56] 45,440
│ │ └─DenseLayer: 3-3 [2, 128, 56, 56] 49,600
│ │ └─DenseLayer: 3-4 [2, 160, 56, 56] 53,760
│ │ └─DenseLayer: 3-5 [2, 192, 56, 56] 57,920
│ │ └─DenseLayer: 3-6 [2, 224, 56, 56] 62,080
│ │ └─Transition: 3-7 [2, 112, 56, 56] 25,536
│ └─Transition: 2-5 [2, 72, 28, 28] --
│ │ └─Sequential: 3-8 [2, 72, 28, 28] 10,656
├─CSPDenseBlock: 1-4 [2, 123, 14, 14] --
│ └─Sequential: 2-6 [2, 210, 28, 28] --
│ │ └─DenseLayer: 3-9 [2, 68, 28, 28] 41,800
│ │ └─DenseLayer: 3-10 [2, 100, 28, 28] 45,960
│ │ └─DenseLayer: 3-11 [2, 132, 28, 28] 50,120
│ │ └─DenseLayer: 3-12 [2, 164, 28, 28] 54,280
│ │ └─DenseLayer: 3-13 [2, 196, 28, 28] 58,440
│ │ └─DenseLayer: 3-14 [2, 228, 28, 28] 62,600
│ │ └─DenseLayer: 3-15 [2, 260, 28, 28] 66,760
│ │ └─DenseLayer: 3-16 [2, 292, 28, 28] 70,920
│ │ └─DenseLayer: 3-17 [2, 324, 28, 28] 75,080
│ │ └─DenseLayer: 3-18 [2, 356, 28, 28] 79,240
│ │ └─DenseLayer: 3-19 [2, 388, 28, 28] 83,400
│ │ └─DenseLayer: 3-20 [2, 420, 28, 28] 87,560
│ │ └─Transition: 3-21 [2, 210, 28, 28] 89,040
│ └─Transition: 2-7 [2, 123, 14, 14] --
│ │ └─Sequential: 3-22 [2, 123, 14, 14] 30,750
├─CSPDenseBlock: 1-5 [2, 558, 7, 7] --
│ └─Sequential: 2-8 [2, 1054, 14, 14] --
│ │ └─DenseLayer: 3-23 [2, 93, 14, 14] 45,050
│ │ └─DenseLayer: 3-24 [2, 125, 14, 14] 49,210
│ │ └─DenseLayer: 3-25 [2, 157, 14, 14] 53,370
│ │ └─DenseLayer: 3-26 [2, 189, 14, 14] 57,530
│ │ └─DenseLayer: 3-27 [2, 221, 14, 14] 61,690
│ │ └─DenseLayer: 3-28 [2, 253, 14, 14] 65,850
│ │ └─DenseLayer: 3-29 [2, 285, 14, 14] 70,010
│ │ └─DenseLayer: 3-30 [2, 317, 14, 14] 74,170
│ │ └─DenseLayer: 3-31 [2, 349, 14, 14] 78,330
│ │ └─DenseLayer: 3-32 [2, 381, 14, 14] 82,490
│ │ └─DenseLayer: 3-33 [2, 413, 14, 14] 86,650
│ │ └─DenseLayer: 3-34 [2, 445, 14, 14] 90,810
│ │ └─DenseLayer: 3-35 [2, 477, 14, 14] 94,970
│ │ └─DenseLayer: 3-36 [2, 509, 14, 14] 99,130
│ │ └─DenseLayer: 3-37 [2, 541, 14, 14] 103,290
│ │ └─DenseLayer: 3-38 [2, 573, 14, 14] 107,450
│ │ └─DenseLayer: 3-39 [2, 605, 14, 14] 111,610
│ │ └─DenseLayer: 3-40 [2, 637, 14, 14] 115,770
│ │ └─DenseLayer: 3-41 [2, 669, 14, 14] 119,930
│ │ └─DenseLayer: 3-42 [2, 701, 14, 14] 124,090
│ │ └─DenseLayer: 3-43 [2, 733, 14, 14] 128,250
│ │ └─DenseLayer: 3-44 [2, 765, 14, 14] 132,410
│ │ └─DenseLayer: 3-45 [2, 797, 14, 14] 136,570
│ │ └─DenseLayer: 3-46 [2, 829, 14, 14] 140,730
│ │ └─DenseLayer: 3-47 [2, 861, 14, 14] 144,890
│ │ └─DenseLayer: 3-48 [2, 893, 14, 14] 149,050
│ │ └─DenseLayer: 3-49 [2, 925, 14, 14] 153,210
│ │ └─DenseLayer: 3-50 [2, 957, 14, 14] 157,370
│ │ └─DenseLayer: 3-51 [2, 989, 14, 14] 161,530
│ │ └─DenseLayer: 3-52 [2, 1021, 14, 14] 165,690
│ │ └─DenseLayer: 3-53 [2, 1053, 14, 14] 169,850
│ │ └─DenseLayer: 3-54 [2, 1085, 14, 14] 174,010
│ │ └─DenseLayer: 3-55 [2, 1117, 14, 14] 178,170
│ │ └─DenseLayer: 3-56 [2, 1149, 14, 14] 182,330
│ │ └─DenseLayer: 3-57 [2, 1181, 14, 14] 186,490
│ │ └─DenseLayer: 3-58 [2, 1213, 14, 14] 190,650
│ │ └─DenseLayer: 3-59 [2, 1245, 14, 14] 194,810
│ │ └─DenseLayer: 3-60 [2, 1277, 14, 14] 198,970
│ │ └─DenseLayer: 3-61 [2, 1309, 14, 14] 203,130
│ │ └─DenseLayer: 3-62 [2, 1341, 14, 14] 207,290
│ │ └─DenseLayer: 3-63 [2, 1373, 14, 14] 211,450
│ │ └─DenseLayer: 3-64 [2, 1405, 14, 14] 215,610
│ │ └─DenseLayer: 3-65 [2, 1437, 14, 14] 219,770
│ │ └─DenseLayer: 3-66 [2, 1469, 14, 14] 223,930
│ │ └─DenseLayer: 3-67 [2, 1501, 14, 14] 228,090
│ │ └─DenseLayer: 3-68 [2, 1533, 14, 14] 232,250
│ │ └─DenseLayer: 3-69 [2, 1565, 14, 14] 236,410
│ │ └─DenseLayer: 3-70 [2, 1597, 14, 14] 240,570
│ │ └─DenseLayer: 3-71 [2, 1629, 14, 14] 244,730
│ │ └─DenseLayer: 3-72 [2, 1661, 14, 14] 248,890
│ │ └─DenseLayer: 3-73 [2, 1693, 14, 14] 253,050
│ │ └─DenseLayer: 3-74 [2, 1725, 14, 14] 257,210
│ │ └─DenseLayer: 3-75 [2, 1757, 14, 14] 261,370
│ │ └─DenseLayer: 3-76 [2, 1789, 14, 14] 265,530
│ │ └─DenseLayer: 3-77 [2, 1821, 14, 14] 269,690
│ │ └─DenseLayer: 3-78 [2, 1853, 14, 14] 273,850
│ │ └─DenseLayer: 3-79 [2, 1885, 14, 14] 278,010
│ │ └─DenseLayer: 3-80 [2, 1917, 14, 14] 282,170
│ │ └─DenseLayer: 3-81 [2, 1949, 14, 14] 286,330
│ │ └─DenseLayer: 3-82 [2, 1981, 14, 14] 290,490
│ │ └─DenseLayer: 3-83 [2, 2013, 14, 14] 294,650
│ │ └─DenseLayer: 3-84 [2, 2045, 14, 14] 298,810
│ │ └─DenseLayer: 3-85 [2, 2077, 14, 14] 302,970
│ │ └─DenseLayer: 3-86 [2, 2109, 14, 14] 307,130
│ │ └─Transition: 3-87 [2, 1054, 14, 14] 2,227,104
│ └─Transition: 2-9 [2, 558, 7, 7] --
│ │ └─Sequential: 3-88 [2, 558, 7, 7] 624,960
├─CSPDenseBlock: 1-6 [2, 1186, 7, 7] --
│ └─Sequential: 2-10 [2, 907, 7, 7] --
│ │ └─DenseLayer: 3-89 [2, 311, 7, 7] 73,390
│ │ └─DenseLayer: 3-90 [2, 343, 7, 7] 77,550
│ │ └─DenseLayer: 3-91 [2, 375, 7, 7] 81,710
│ │ └─DenseLayer: 3-92 [2, 407, 7, 7] 85,870
│ │ └─DenseLayer: 3-93 [2, 439, 7, 7] 90,030
│ │ └─DenseLayer: 3-94 [2, 471, 7, 7] 94,190
│ │ └─DenseLayer: 3-95 [2, 503, 7, 7] 98,350
│ │ └─DenseLayer: 3-96 [2, 535, 7, 7] 102,510
│ │ └─DenseLayer: 3-97 [2, 567, 7, 7] 106,670
│ │ └─DenseLayer: 3-98 [2, 599, 7, 7] 110,830
│ │ └─DenseLayer: 3-99 [2, 631, 7, 7] 114,990
│ │ └─DenseLayer: 3-100 [2, 663, 7, 7] 119,150
│ │ └─DenseLayer: 3-101 [2, 695, 7, 7] 123,310
│ │ └─DenseLayer: 3-102 [2, 727, 7, 7] 127,470
│ │ └─DenseLayer: 3-103 [2, 759, 7, 7] 131,630
│ │ └─DenseLayer: 3-104 [2, 791, 7, 7] 135,790
│ │ └─DenseLayer: 3-105 [2, 823, 7, 7] 139,950
│ │ └─DenseLayer: 3-106 [2, 855, 7, 7] 144,110
│ │ └─DenseLayer: 3-107 [2, 887, 7, 7] 148,270
│ │ └─DenseLayer: 3-108 [2, 919, 7, 7] 152,430
│ │ └─DenseLayer: 3-109 [2, 951, 7, 7] 156,590
│ │ └─DenseLayer: 3-110 [2, 983, 7, 7] 160,750
│ │ └─DenseLayer: 3-111 [2, 1015, 7, 7] 164,910
│ │ └─DenseLayer: 3-112 [2, 1047, 7, 7] 169,070
│ │ └─DenseLayer: 3-113 [2, 1079, 7, 7] 173,230
│ │ └─DenseLayer: 3-114 [2, 1111, 7, 7] 177,390
│ │ └─DenseLayer: 3-115 [2, 1143, 7, 7] 181,550
│ │ └─DenseLayer: 3-116 [2, 1175, 7, 7] 185,710
│ │ └─DenseLayer: 3-117 [2, 1207, 7, 7] 189,870
│ │ └─DenseLayer: 3-118 [2, 1239, 7, 7] 194,030
│ │ └─DenseLayer: 3-119 [2, 1271, 7, 7] 198,190
│ │ └─DenseLayer: 3-120 [2, 1303, 7, 7] 202,350
│ │ └─DenseLayer: 3-121 [2, 1335, 7, 7] 206,510
│ │ └─DenseLayer: 3-122 [2, 1367, 7, 7] 210,670
│ │ └─DenseLayer: 3-123 [2, 1399, 7, 7] 214,830
│ │ └─DenseLayer: 3-124 [2, 1431, 7, 7] 218,990
│ │ └─DenseLayer: 3-125 [2, 1463, 7, 7] 223,150
│ │ └─DenseLayer: 3-126 [2, 1495, 7, 7] 227,310
│ │ └─DenseLayer: 3-127 [2, 1527, 7, 7] 231,470
│ │ └─DenseLayer: 3-128 [2, 1559, 7, 7] 235,630
│ │ └─DenseLayer: 3-129 [2, 1591, 7, 7] 239,790
│ │ └─DenseLayer: 3-130 [2, 1623, 7, 7] 243,950
│ │ └─DenseLayer: 3-131 [2, 1655, 7, 7] 248,110
│ │ └─DenseLayer: 3-132 [2, 1687, 7, 7] 252,270
│ │ └─DenseLayer: 3-133 [2, 1719, 7, 7] 256,430
│ │ └─DenseLayer: 3-134 [2, 1751, 7, 7] 260,590
│ │ └─DenseLayer: 3-135 [2, 1783, 7, 7] 264,750
│ │ └─DenseLayer: 3-136 [2, 1815, 7, 7] 268,910
│ │ └─Transition: 3-137 [2, 907, 7, 7] 1,649,835
│ └─Sequential: 2-11 [2, 1186, 7, 7] --
│ │ └─BatchNorm2d: 3-138 [2, 1186, 7, 7] 2,372
│ │ └─ReLU: 3-139 [2, 1186, 7, 7] --
├─AdaptiveAvgPool2d: 1-7 [2, 1186, 1, 1] --
├─Linear: 1-8 [2, 1000] 1,187,000
===============================================================================================
Total params: 26,427,989
Trainable params: 26,427,989
Non-trainable params: 0
Total mult-adds (G): 10.21
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 601.85
Params size (MB): 105.71
Estimated Total Size (MB): 708.76
===============================================================================================