DenseNet 을 이해하고 Pytorch로 구현할 수 있다.
Dense block
Figure 1은 Dense Block을 나타낸다. growth rate는 4이다. 선들이 복잡하게 그려져 있는데 하나씩 살펴보자.
Input에 x개의 채널를 가진 feature map이 있다고 가정. (위 그림에서 빨간색)
BN-ReLU-Conv를 거치면서 4 개의 feature map을 뽑음. (위 그림에서 초록색)
개의 feature map이 BN-ReLU-Conv를 통과해서 4개의 feature map을 뽑음.
-> 빨간색(x) + 초록색(4) -> BN-ReLU-Conv -> 보라색(4)
개의 feature map이 BN-ReLU-Conv를 통과해서 4개의 feature map을 뽑음.
-> 빨간색(x) + 초록색(4) + 보라색(4) -> BN-ReLU-Conv -> 노란색(4)
개의 feature map이 BN-ReLU-Conv를 통과해서 4개의 feature map을 뽑음.
-> 빨간색(x) + 초록색(4) + 보라색(4) + 노란색(4) -> BN-ReLU-Conv -> 주황색(4)
점점 Dense Block의 layer가 growth rate만큼 등차수열 형태로 증가하는 것을 알 수 있다. ( )
그림에 잘 표현해놓았는데 Denseblock에서는 이전 feature map을 채널축으로 concat 하기 때문에 feature map의 크기 (size) 가 변하지 않는다. (채널 수만 growth rate 만큼 증가)
구조상 층이 깊어질 수록 채널수는 점점 늘어나는데 각 out_channels
는 growth rate로 고정되어 있기 때문에 점점 bottleneck이 심해짐. 따라서 중간에 1x1 Conv
를 추가하여 이 현상을 어느정도 완화시킴. 또한 파라미터 수를 절약시키는 효과도 볼 수 있다.
(in -> 1x1 Conv 4k -> 3x3 Conv k
)
Dense Block이 깊어지고, 또 이 block이 여러 개가 존재할 때 모델의 파라미터 수는 어마어마 하게 커질 것이다.(feature map의 크기는 변하지 않으면서 채널 수는 계속 증가할거니깐) 그래서 Transition layer가 DenseBlock 중간중간에 존재한다.
CNN 구조상 언젠가는 size를 줄여야 하고 이 역할은 transition layer가 수행한다.
Conv layer는 채널수를 반으로 줄이는 역할을 한다.
Pooling layer는 feature map의 가로, 세로를 각각 반으로 줄이는 역할을 한다.
ResNet과의 차이점
DenseNet의 DenseBlock은 를 채널축 방향으로 쌓기 때문에 이전 정보를 그대로 보존할 수 있다.
ResNet은 를 더하기(plus) 연산 때문에 DenseBlock에 비해 정보가 뭉개지게 된다.
ResNet 이후에 나온 연구인 full pre-activation을 사용했다.
Conv-Batch-ReLU 였던 순서가 BN-ReLU-Conv로 바뀌었다.
import torch
from torch import nn
from torchinfo import summary
class Bottleneck(nn.Module):
def __init__(self, in_channels, k):
super().__init__()
self.residual = nn.Sequential(
nn.BatchNorm2d(in_channels),
nn.ReLU(),
nn.Conv2d(in_channels, 4 * k, 1, bias = False),
nn.BatchNorm2d(4 * k),
nn.ReLU(),
nn.Conv2d(4 * k, k, 3, padding = 1, bias = False),
)
def forward(self, x):
return torch.cat([self.residual(x), x], dim = 1)
class Transition(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.transition = nn.Sequential(
nn.BatchNorm2d(in_channels), # Dense block end: Conv
nn.ReLU(),
nn.Conv2d(in_channels, int(in_channels / 2), 1, bias = False), # Reduce channels
nn.AvgPool2d(2), # Reduce feature map size
)
def forward(self, x):
return self.transition(x)
class DenseNet(nn.Module):
def __init__(self, block_list, growth_rate, n_classes = 1000):
super().__init__()
assert len(block_list) == 4
self.k = growth_rate
self.conv1 = nn.Sequential(
nn.Conv2d(3, 2 * self.k, 7, stride = 2, padding = 3, bias = False),
nn.BatchNorm2d(2 * self.k),
nn.ReLU(),
)
self.maxpool = nn.MaxPool2d(3, stride = 2, padding = 1)
self.dense_channels = 2 * self.k
dense_blocks = []
dense_blocks.append(self.make_dense_block(block_list[0]))
dense_blocks.append(self.make_dense_block(block_list[1]))
dense_blocks.append(self.make_dense_block(block_list[2]))
dense_blocks.append(self.make_dense_block(block_list[3], last_stage = True))
self.dense_blocks = nn.Sequential(*dense_blocks)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(self.dense_channels, n_classes)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.dense_blocks(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def make_dense_block(self, num_blocks, last_stage = False):
layers = []
for _ in range(num_blocks):
layer = Bottleneck(self.dense_channels, self.k)
layers.append(layer)
self.dense_channels += self.k
if last_stage:
layers.append(nn.BatchNorm2d(self.dense_channels))
layers.append(nn.ReLU())
else:
layers.append(Transition(self.dense_channels))
assert self.dense_channels % 2 == 0
self.dense_channels //= 2
return nn.Sequential(*layers)
def DenseNet121():
return DenseNet(block_list = [6, 12, 24, 16], growth_rate = 32)
def DenseNet169():
return DenseNet(block_list = [6, 12, 32, 32], growth_rate = 32)
def DenseNet201():
return DenseNet(block_list = [6, 12, 48, 32], growth_rate = 32)
def DenseNet264():
return DenseNet(block_list = [6, 12, 64, 48], growth_rate = 32)
model = DenseNet264()
summary(model, input_size = (2, 3, 224, 224), device = "cpu")
#### OUTPUT ####
===============================================================================================
Layer (type:depth-idx) Output Shape Param #
===============================================================================================
DenseNet [2, 1000] --
├─Sequential: 1-1 [2, 64, 112, 112] --
│ └─Conv2d: 2-1 [2, 64, 112, 112] 9,408
│ └─BatchNorm2d: 2-2 [2, 64, 112, 112] 128
│ └─ReLU: 2-3 [2, 64, 112, 112] --
├─MaxPool2d: 1-2 [2, 64, 56, 56] --
├─Sequential: 1-3 [2, 2688, 7, 7] --
│ └─Sequential: 2-4 [2, 128, 28, 28] --
│ │ └─Bottleneck: 3-1 [2, 96, 56, 56] 45,440
│ │ └─Bottleneck: 3-2 [2, 128, 56, 56] 49,600
│ │ └─Bottleneck: 3-3 [2, 160, 56, 56] 53,760
│ │ └─Bottleneck: 3-4 [2, 192, 56, 56] 57,920
│ │ └─Bottleneck: 3-5 [2, 224, 56, 56] 62,080
│ │ └─Bottleneck: 3-6 [2, 256, 56, 56] 66,240
│ │ └─Transition: 3-7 [2, 128, 28, 28] 33,280
│ └─Sequential: 2-5 [2, 256, 14, 14] --
│ │ └─Bottleneck: 3-8 [2, 160, 28, 28] 53,760
│ │ └─Bottleneck: 3-9 [2, 192, 28, 28] 57,920
│ │ └─Bottleneck: 3-10 [2, 224, 28, 28] 62,080
│ │ └─Bottleneck: 3-11 [2, 256, 28, 28] 66,240
│ │ └─Bottleneck: 3-12 [2, 288, 28, 28] 70,400
│ │ └─Bottleneck: 3-13 [2, 320, 28, 28] 74,560
│ │ └─Bottleneck: 3-14 [2, 352, 28, 28] 78,720
│ │ └─Bottleneck: 3-15 [2, 384, 28, 28] 82,880
│ │ └─Bottleneck: 3-16 [2, 416, 28, 28] 87,040
│ │ └─Bottleneck: 3-17 [2, 448, 28, 28] 91,200
│ │ └─Bottleneck: 3-18 [2, 480, 28, 28] 95,360
│ │ └─Bottleneck: 3-19 [2, 512, 28, 28] 99,520
│ │ └─Transition: 3-20 [2, 256, 14, 14] 132,096
│ └─Sequential: 2-6 [2, 1152, 7, 7] --
│ │ └─Bottleneck: 3-21 [2, 288, 14, 14] 70,400
│ │ └─Bottleneck: 3-22 [2, 320, 14, 14] 74,560
│ │ └─Bottleneck: 3-23 [2, 352, 14, 14] 78,720
│ │ └─Bottleneck: 3-24 [2, 384, 14, 14] 82,880
│ │ └─Bottleneck: 3-25 [2, 416, 14, 14] 87,040
│ │ └─Bottleneck: 3-26 [2, 448, 14, 14] 91,200
│ │ └─Bottleneck: 3-27 [2, 480, 14, 14] 95,360
│ │ └─Bottleneck: 3-28 [2, 512, 14, 14] 99,520
│ │ └─Bottleneck: 3-29 [2, 544, 14, 14] 103,680
│ │ └─Bottleneck: 3-30 [2, 576, 14, 14] 107,840
│ │ └─Bottleneck: 3-31 [2, 608, 14, 14] 112,000
│ │ └─Bottleneck: 3-32 [2, 640, 14, 14] 116,160
│ │ └─Bottleneck: 3-33 [2, 672, 14, 14] 120,320
│ │ └─Bottleneck: 3-34 [2, 704, 14, 14] 124,480
│ │ └─Bottleneck: 3-35 [2, 736, 14, 14] 128,640
│ │ └─Bottleneck: 3-36 [2, 768, 14, 14] 132,800
│ │ └─Bottleneck: 3-37 [2, 800, 14, 14] 136,960
│ │ └─Bottleneck: 3-38 [2, 832, 14, 14] 141,120
│ │ └─Bottleneck: 3-39 [2, 864, 14, 14] 145,280
│ │ └─Bottleneck: 3-40 [2, 896, 14, 14] 149,440
│ │ └─Bottleneck: 3-41 [2, 928, 14, 14] 153,600
│ │ └─Bottleneck: 3-42 [2, 960, 14, 14] 157,760
│ │ └─Bottleneck: 3-43 [2, 992, 14, 14] 161,920
│ │ └─Bottleneck: 3-44 [2, 1024, 14, 14] 166,080
│ │ └─Bottleneck: 3-45 [2, 1056, 14, 14] 170,240
│ │ └─Bottleneck: 3-46 [2, 1088, 14, 14] 174,400
│ │ └─Bottleneck: 3-47 [2, 1120, 14, 14] 178,560
│ │ └─Bottleneck: 3-48 [2, 1152, 14, 14] 182,720
│ │ └─Bottleneck: 3-49 [2, 1184, 14, 14] 186,880
│ │ └─Bottleneck: 3-50 [2, 1216, 14, 14] 191,040
│ │ └─Bottleneck: 3-51 [2, 1248, 14, 14] 195,200
│ │ └─Bottleneck: 3-52 [2, 1280, 14, 14] 199,360
│ │ └─Bottleneck: 3-53 [2, 1312, 14, 14] 203,520
│ │ └─Bottleneck: 3-54 [2, 1344, 14, 14] 207,680
│ │ └─Bottleneck: 3-55 [2, 1376, 14, 14] 211,840
│ │ └─Bottleneck: 3-56 [2, 1408, 14, 14] 216,000
│ │ └─Bottleneck: 3-57 [2, 1440, 14, 14] 220,160
│ │ └─Bottleneck: 3-58 [2, 1472, 14, 14] 224,320
│ │ └─Bottleneck: 3-59 [2, 1504, 14, 14] 228,480
│ │ └─Bottleneck: 3-60 [2, 1536, 14, 14] 232,640
│ │ └─Bottleneck: 3-61 [2, 1568, 14, 14] 236,800
│ │ └─Bottleneck: 3-62 [2, 1600, 14, 14] 240,960
│ │ └─Bottleneck: 3-63 [2, 1632, 14, 14] 245,120
│ │ └─Bottleneck: 3-64 [2, 1664, 14, 14] 249,280
│ │ └─Bottleneck: 3-65 [2, 1696, 14, 14] 253,440
│ │ └─Bottleneck: 3-66 [2, 1728, 14, 14] 257,600
│ │ └─Bottleneck: 3-67 [2, 1760, 14, 14] 261,760
│ │ └─Bottleneck: 3-68 [2, 1792, 14, 14] 265,920
│ │ └─Bottleneck: 3-69 [2, 1824, 14, 14] 270,080
│ │ └─Bottleneck: 3-70 [2, 1856, 14, 14] 274,240
│ │ └─Bottleneck: 3-71 [2, 1888, 14, 14] 278,400
│ │ └─Bottleneck: 3-72 [2, 1920, 14, 14] 282,560
│ │ └─Bottleneck: 3-73 [2, 1952, 14, 14] 286,720
│ │ └─Bottleneck: 3-74 [2, 1984, 14, 14] 290,880
│ │ └─Bottleneck: 3-75 [2, 2016, 14, 14] 295,040
│ │ └─Bottleneck: 3-76 [2, 2048, 14, 14] 299,200
│ │ └─Bottleneck: 3-77 [2, 2080, 14, 14] 303,360
│ │ └─Bottleneck: 3-78 [2, 2112, 14, 14] 307,520
│ │ └─Bottleneck: 3-79 [2, 2144, 14, 14] 311,680
│ │ └─Bottleneck: 3-80 [2, 2176, 14, 14] 315,840
│ │ └─Bottleneck: 3-81 [2, 2208, 14, 14] 320,000
│ │ └─Bottleneck: 3-82 [2, 2240, 14, 14] 324,160
│ │ └─Bottleneck: 3-83 [2, 2272, 14, 14] 328,320
│ │ └─Bottleneck: 3-84 [2, 2304, 14, 14] 332,480
│ │ └─Transition: 3-85 [2, 1152, 7, 7] 2,658,816
│ └─Sequential: 2-7 [2, 2688, 7, 7] --
│ │ └─Bottleneck: 3-86 [2, 1184, 7, 7] 186,880
│ │ └─Bottleneck: 3-87 [2, 1216, 7, 7] 191,040
│ │ └─Bottleneck: 3-88 [2, 1248, 7, 7] 195,200
│ │ └─Bottleneck: 3-89 [2, 1280, 7, 7] 199,360
│ │ └─Bottleneck: 3-90 [2, 1312, 7, 7] 203,520
│ │ └─Bottleneck: 3-91 [2, 1344, 7, 7] 207,680
│ │ └─Bottleneck: 3-92 [2, 1376, 7, 7] 211,840
│ │ └─Bottleneck: 3-93 [2, 1408, 7, 7] 216,000
│ │ └─Bottleneck: 3-94 [2, 1440, 7, 7] 220,160
│ │ └─Bottleneck: 3-95 [2, 1472, 7, 7] 224,320
│ │ └─Bottleneck: 3-96 [2, 1504, 7, 7] 228,480
│ │ └─Bottleneck: 3-97 [2, 1536, 7, 7] 232,640
│ │ └─Bottleneck: 3-98 [2, 1568, 7, 7] 236,800
│ │ └─Bottleneck: 3-99 [2, 1600, 7, 7] 240,960
│ │ └─Bottleneck: 3-100 [2, 1632, 7, 7] 245,120
│ │ └─Bottleneck: 3-101 [2, 1664, 7, 7] 249,280
│ │ └─Bottleneck: 3-102 [2, 1696, 7, 7] 253,440
│ │ └─Bottleneck: 3-103 [2, 1728, 7, 7] 257,600
│ │ └─Bottleneck: 3-104 [2, 1760, 7, 7] 261,760
│ │ └─Bottleneck: 3-105 [2, 1792, 7, 7] 265,920
│ │ └─Bottleneck: 3-106 [2, 1824, 7, 7] 270,080
│ │ └─Bottleneck: 3-107 [2, 1856, 7, 7] 274,240
│ │ └─Bottleneck: 3-108 [2, 1888, 7, 7] 278,400
│ │ └─Bottleneck: 3-109 [2, 1920, 7, 7] 282,560
│ │ └─Bottleneck: 3-110 [2, 1952, 7, 7] 286,720
│ │ └─Bottleneck: 3-111 [2, 1984, 7, 7] 290,880
│ │ └─Bottleneck: 3-112 [2, 2016, 7, 7] 295,040
│ │ └─Bottleneck: 3-113 [2, 2048, 7, 7] 299,200
│ │ └─Bottleneck: 3-114 [2, 2080, 7, 7] 303,360
│ │ └─Bottleneck: 3-115 [2, 2112, 7, 7] 307,520
│ │ └─Bottleneck: 3-116 [2, 2144, 7, 7] 311,680
│ │ └─Bottleneck: 3-117 [2, 2176, 7, 7] 315,840
│ │ └─Bottleneck: 3-118 [2, 2208, 7, 7] 320,000
│ │ └─Bottleneck: 3-119 [2, 2240, 7, 7] 324,160
│ │ └─Bottleneck: 3-120 [2, 2272, 7, 7] 328,320
│ │ └─Bottleneck: 3-121 [2, 2304, 7, 7] 332,480
│ │ └─Bottleneck: 3-122 [2, 2336, 7, 7] 336,640
│ │ └─Bottleneck: 3-123 [2, 2368, 7, 7] 340,800
│ │ └─Bottleneck: 3-124 [2, 2400, 7, 7] 344,960
│ │ └─Bottleneck: 3-125 [2, 2432, 7, 7] 349,120
│ │ └─Bottleneck: 3-126 [2, 2464, 7, 7] 353,280
│ │ └─Bottleneck: 3-127 [2, 2496, 7, 7] 357,440
│ │ └─Bottleneck: 3-128 [2, 2528, 7, 7] 361,600
│ │ └─Bottleneck: 3-129 [2, 2560, 7, 7] 365,760
│ │ └─Bottleneck: 3-130 [2, 2592, 7, 7] 369,920
│ │ └─Bottleneck: 3-131 [2, 2624, 7, 7] 374,080
│ │ └─Bottleneck: 3-132 [2, 2656, 7, 7] 378,240
│ │ └─Bottleneck: 3-133 [2, 2688, 7, 7] 382,400
│ │ └─BatchNorm2d: 3-134 [2, 2688, 7, 7] 5,376
│ │ └─ReLU: 3-135 [2, 2688, 7, 7] --
├─AdaptiveAvgPool2d: 1-4 [2, 2688, 1, 1] --
├─Linear: 1-5 [2, 1000] 2,689,000
===============================================================================================
Total params: 33,337,704
Trainable params: 33,337,704
Non-trainable params: 0
Total mult-adds (G): 11.50
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 680.70
Params size (MB): 133.35
Estimated Total Size (MB): 815.26
===============================================================================================