[Pytorch] DenseNet 구현

도룩·2023년 12월 10일
0

목적

DenseNet 을 이해하고 Pytorch로 구현할 수 있다.

Architecture

  • Network architecture

특징

  • Dense block
    Figure 1은 Dense Block을 나타낸다. growth rate는 4이다. 선들이 복잡하게 그려져 있는데 하나씩 살펴보자.

    1. Input에 x개의 채널를 가진 feature map이 있다고 가정. (위 그림에서 빨간색)

    2. BN-ReLU-Conv를 거치면서 4 개의 feature map을 뽑음. (위 그림에서 초록색)

    3. x+4x + 4 개의 feature map이 BN-ReLU-Conv를 통과해서 4개의 feature map을 뽑음.
      -> 빨간색(x) + 초록색(4) -> BN-ReLU-Conv -> 보라색(4)

    4. x+8x + 8 개의 feature map이 BN-ReLU-Conv를 통과해서 4개의 feature map을 뽑음.
      -> 빨간색(x) + 초록색(4) + 보라색(4) -> BN-ReLU-Conv -> 노란색(4)

    5. x+12x + 12 개의 feature map이 BN-ReLU-Conv를 통과해서 4개의 feature map을 뽑음.
      -> 빨간색(x) + 초록색(4) + 보라색(4) + 노란색(4) -> BN-ReLU-Conv -> 주황색(4)

      \\

    • 점점 Dense Block의 layer가 growth rate만큼 등차수열 형태로 증가하는 것을 알 수 있다. (xx \rightarrow x+4x+4 \rightarrow x+8x+8 \rightarrow x+12x+12 \rightarrow \cdots)

    • 그림에 잘 표현해놓았는데 Denseblock에서는 이전 feature map을 채널축으로 concat 하기 때문에 feature map의 크기 (size) 가 변하지 않는다. (채널 수만 growth rate 만큼 증가)

    • 구조상 층이 깊어질 수록 채널수는 점점 늘어나는데 각 out_channels는 growth rate로 고정되어 있기 때문에 점점 bottleneck이 심해짐. 따라서 중간에 1x1 Conv를 추가하여 이 현상을 어느정도 완화시킴. 또한 파라미터 수를 절약시키는 효과도 볼 수 있다.
      (in -> 1x1 Conv 4k -> 3x3 Conv k)

    • Dense Block이 깊어지고, 또 이 block이 여러 개가 존재할 때 모델의 파라미터 수는 어마어마 하게 커질 것이다.(feature map의 크기는 변하지 않으면서 채널 수는 계속 증가할거니깐) \rightarrow 그래서 Transition layer가 DenseBlock 중간중간에 존재한다.

    • CNN 구조상 언젠가는 size를 줄여야 하고 이 역할은 transition layer가 수행한다.

  • Transition layer

    Transition layer는 그림에서 보는 것 처럼 Conv-Pooling layer로 구성되어 있다.
    1. Conv layer는 채널수를 반으로 줄이는 역할을 한다.

    2. Pooling layer는 feature map의 가로, 세로를 각각 반으로 줄이는 역할을 한다.

      \\

    • 실제 코드에서는 BN-ReLU가 먼저 등장하고 이후에 Conv-Pooling으로 구성되어 있다.
      그 이유는 DenseBlock을 구성하는 layer가 BN-ReLU-Conv 순이기 때문에 DenseBlock의 마지막 layer는 Conv 이다. 따라서 DenseBlock의 마지막 Conv후 Transition layer의 Conv를 하기 전에 BN, ReLU를 먼저 통과 시켜준다.
      -> Transition layer: BN-ReLU-Conv-Pooling
  • ResNet과의 차이점

    1. DenseNet의 DenseBlock은 xx를 채널축 방향으로 쌓기 때문에 이전 정보를 그대로 보존할 수 있다.
      \rightarrow ResNet은 xx를 더하기(plus) 연산 때문에 DenseBlock에 비해 정보가 뭉개지게 된다.

    2. ResNet 이후에 나온 연구인 full pre-activation을 사용했다.
      \rightarrow Conv-Batch-ReLU 였던 순서가 BN-ReLU-Conv로 바뀌었다.

  • 결과

    DenseNet이 ResNet보다 더 적은 파라미터로 더 우수한 성능. (왼쪽)
    파라미터 수가 적기 때문에 inference 속도도 ResNet보다 더 빠르다. (오른쪽)

Code

환경

  • python 3.8.16
  • pytorch 2.1.0
  • torchinfo 1.8.0

구현

import torch
from torch import nn
from torchinfo import summary
class Bottleneck(nn.Module):
    def __init__(self, in_channels, k):
        super().__init__()

        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, 4 * k, 1, bias = False),
            nn.BatchNorm2d(4 * k),
            nn.ReLU(),
            nn.Conv2d(4 * k, k, 3, padding = 1, bias = False),
        )
        
    def forward(self, x):
        return torch.cat([self.residual(x), x], dim = 1)
    
class Transition(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.transition = nn.Sequential(
            nn.BatchNorm2d(in_channels), # Dense block end: Conv
            nn.ReLU(),
            nn.Conv2d(in_channels, int(in_channels / 2), 1, bias = False), # Reduce channels
            nn.AvgPool2d(2),    # Reduce feature map size
        )
    
    def forward(self, x):
        return self.transition(x)

class DenseNet(nn.Module):
    def __init__(self, block_list, growth_rate, n_classes = 1000):
        super().__init__()

        assert len(block_list) == 4

        self.k = growth_rate

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 2 * self.k, 7, stride = 2, padding = 3, bias = False),
            nn.BatchNorm2d(2 * self.k),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(3, stride = 2, padding = 1)

        self.dense_channels = 2 * self.k
        dense_blocks = []
        dense_blocks.append(self.make_dense_block(block_list[0]))
        dense_blocks.append(self.make_dense_block(block_list[1]))
        dense_blocks.append(self.make_dense_block(block_list[2]))
        dense_blocks.append(self.make_dense_block(block_list[3], last_stage = True))
        self.dense_blocks = nn.Sequential(*dense_blocks)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.dense_channels, n_classes)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.dense_blocks(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
    
    def make_dense_block(self, num_blocks, last_stage = False):
        layers = []
        for _ in range(num_blocks):
            layer = Bottleneck(self.dense_channels, self.k)
            layers.append(layer)
            self.dense_channels += self.k

        if last_stage:
            layers.append(nn.BatchNorm2d(self.dense_channels))
            layers.append(nn.ReLU())
        else:
            layers.append(Transition(self.dense_channels))
            assert self.dense_channels % 2 == 0
            self.dense_channels //= 2
            
        return nn.Sequential(*layers)
def DenseNet121():
    return DenseNet(block_list = [6, 12, 24, 16], growth_rate = 32)

def DenseNet169():
    return DenseNet(block_list = [6, 12, 32, 32], growth_rate = 32)

def DenseNet201():
    return DenseNet(block_list = [6, 12, 48, 32], growth_rate = 32)

def DenseNet264():
    return DenseNet(block_list = [6, 12, 64, 48], growth_rate = 32)
model = DenseNet264()
summary(model, input_size = (2, 3, 224, 224), device = "cpu")
#### OUTPUT ####
===============================================================================================
Layer (type:depth-idx)                        Output Shape              Param #
===============================================================================================
DenseNet                                      [2, 1000]                 --
├─Sequential: 1-1                             [2, 64, 112, 112]         --
│    └─Conv2d: 2-1                            [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                       [2, 64, 112, 112]         128
│    └─ReLU: 2-3                              [2, 64, 112, 112]         --
├─MaxPool2d: 1-2                              [2, 64, 56, 56]           --
├─Sequential: 1-3                             [2, 2688, 7, 7]           --
│    └─Sequential: 2-4                        [2, 128, 28, 28]          --
│    │    └─Bottleneck: 3-1                   [2, 96, 56, 56]           45,440
│    │    └─Bottleneck: 3-2                   [2, 128, 56, 56]          49,600
│    │    └─Bottleneck: 3-3                   [2, 160, 56, 56]          53,760
│    │    └─Bottleneck: 3-4                   [2, 192, 56, 56]          57,920
│    │    └─Bottleneck: 3-5                   [2, 224, 56, 56]          62,080
│    │    └─Bottleneck: 3-6                   [2, 256, 56, 56]          66,240
│    │    └─Transition: 3-7                   [2, 128, 28, 28]          33,280
│    └─Sequential: 2-5                        [2, 256, 14, 14]          --
│    │    └─Bottleneck: 3-8                   [2, 160, 28, 28]          53,760
│    │    └─Bottleneck: 3-9                   [2, 192, 28, 28]          57,920
│    │    └─Bottleneck: 3-10                  [2, 224, 28, 28]          62,080
│    │    └─Bottleneck: 3-11                  [2, 256, 28, 28]          66,240
│    │    └─Bottleneck: 3-12                  [2, 288, 28, 28]          70,400
│    │    └─Bottleneck: 3-13                  [2, 320, 28, 28]          74,560
│    │    └─Bottleneck: 3-14                  [2, 352, 28, 28]          78,720
│    │    └─Bottleneck: 3-15                  [2, 384, 28, 28]          82,880
│    │    └─Bottleneck: 3-16                  [2, 416, 28, 28]          87,040
│    │    └─Bottleneck: 3-17                  [2, 448, 28, 28]          91,200
│    │    └─Bottleneck: 3-18                  [2, 480, 28, 28]          95,360
│    │    └─Bottleneck: 3-19                  [2, 512, 28, 28]          99,520
│    │    └─Transition: 3-20                  [2, 256, 14, 14]          132,096
│    └─Sequential: 2-6                        [2, 1152, 7, 7]           --
│    │    └─Bottleneck: 3-21                  [2, 288, 14, 14]          70,400
│    │    └─Bottleneck: 3-22                  [2, 320, 14, 14]          74,560
│    │    └─Bottleneck: 3-23                  [2, 352, 14, 14]          78,720
│    │    └─Bottleneck: 3-24                  [2, 384, 14, 14]          82,880
│    │    └─Bottleneck: 3-25                  [2, 416, 14, 14]          87,040
│    │    └─Bottleneck: 3-26                  [2, 448, 14, 14]          91,200
│    │    └─Bottleneck: 3-27                  [2, 480, 14, 14]          95,360
│    │    └─Bottleneck: 3-28                  [2, 512, 14, 14]          99,520
│    │    └─Bottleneck: 3-29                  [2, 544, 14, 14]          103,680
│    │    └─Bottleneck: 3-30                  [2, 576, 14, 14]          107,840
│    │    └─Bottleneck: 3-31                  [2, 608, 14, 14]          112,000
│    │    └─Bottleneck: 3-32                  [2, 640, 14, 14]          116,160
│    │    └─Bottleneck: 3-33                  [2, 672, 14, 14]          120,320
│    │    └─Bottleneck: 3-34                  [2, 704, 14, 14]          124,480
│    │    └─Bottleneck: 3-35                  [2, 736, 14, 14]          128,640
│    │    └─Bottleneck: 3-36                  [2, 768, 14, 14]          132,800
│    │    └─Bottleneck: 3-37                  [2, 800, 14, 14]          136,960
│    │    └─Bottleneck: 3-38                  [2, 832, 14, 14]          141,120
│    │    └─Bottleneck: 3-39                  [2, 864, 14, 14]          145,280
│    │    └─Bottleneck: 3-40                  [2, 896, 14, 14]          149,440
│    │    └─Bottleneck: 3-41                  [2, 928, 14, 14]          153,600
│    │    └─Bottleneck: 3-42                  [2, 960, 14, 14]          157,760
│    │    └─Bottleneck: 3-43                  [2, 992, 14, 14]          161,920
│    │    └─Bottleneck: 3-44                  [2, 1024, 14, 14]         166,080
│    │    └─Bottleneck: 3-45                  [2, 1056, 14, 14]         170,240
│    │    └─Bottleneck: 3-46                  [2, 1088, 14, 14]         174,400
│    │    └─Bottleneck: 3-47                  [2, 1120, 14, 14]         178,560
│    │    └─Bottleneck: 3-48                  [2, 1152, 14, 14]         182,720
│    │    └─Bottleneck: 3-49                  [2, 1184, 14, 14]         186,880
│    │    └─Bottleneck: 3-50                  [2, 1216, 14, 14]         191,040
│    │    └─Bottleneck: 3-51                  [2, 1248, 14, 14]         195,200
│    │    └─Bottleneck: 3-52                  [2, 1280, 14, 14]         199,360
│    │    └─Bottleneck: 3-53                  [2, 1312, 14, 14]         203,520
│    │    └─Bottleneck: 3-54                  [2, 1344, 14, 14]         207,680
│    │    └─Bottleneck: 3-55                  [2, 1376, 14, 14]         211,840
│    │    └─Bottleneck: 3-56                  [2, 1408, 14, 14]         216,000
│    │    └─Bottleneck: 3-57                  [2, 1440, 14, 14]         220,160
│    │    └─Bottleneck: 3-58                  [2, 1472, 14, 14]         224,320
│    │    └─Bottleneck: 3-59                  [2, 1504, 14, 14]         228,480
│    │    └─Bottleneck: 3-60                  [2, 1536, 14, 14]         232,640
│    │    └─Bottleneck: 3-61                  [2, 1568, 14, 14]         236,800
│    │    └─Bottleneck: 3-62                  [2, 1600, 14, 14]         240,960
│    │    └─Bottleneck: 3-63                  [2, 1632, 14, 14]         245,120
│    │    └─Bottleneck: 3-64                  [2, 1664, 14, 14]         249,280
│    │    └─Bottleneck: 3-65                  [2, 1696, 14, 14]         253,440
│    │    └─Bottleneck: 3-66                  [2, 1728, 14, 14]         257,600
│    │    └─Bottleneck: 3-67                  [2, 1760, 14, 14]         261,760
│    │    └─Bottleneck: 3-68                  [2, 1792, 14, 14]         265,920
│    │    └─Bottleneck: 3-69                  [2, 1824, 14, 14]         270,080
│    │    └─Bottleneck: 3-70                  [2, 1856, 14, 14]         274,240
│    │    └─Bottleneck: 3-71                  [2, 1888, 14, 14]         278,400
│    │    └─Bottleneck: 3-72                  [2, 1920, 14, 14]         282,560
│    │    └─Bottleneck: 3-73                  [2, 1952, 14, 14]         286,720
│    │    └─Bottleneck: 3-74                  [2, 1984, 14, 14]         290,880
│    │    └─Bottleneck: 3-75                  [2, 2016, 14, 14]         295,040
│    │    └─Bottleneck: 3-76                  [2, 2048, 14, 14]         299,200
│    │    └─Bottleneck: 3-77                  [2, 2080, 14, 14]         303,360
│    │    └─Bottleneck: 3-78                  [2, 2112, 14, 14]         307,520
│    │    └─Bottleneck: 3-79                  [2, 2144, 14, 14]         311,680
│    │    └─Bottleneck: 3-80                  [2, 2176, 14, 14]         315,840
│    │    └─Bottleneck: 3-81                  [2, 2208, 14, 14]         320,000
│    │    └─Bottleneck: 3-82                  [2, 2240, 14, 14]         324,160
│    │    └─Bottleneck: 3-83                  [2, 2272, 14, 14]         328,320
│    │    └─Bottleneck: 3-84                  [2, 2304, 14, 14]         332,480
│    │    └─Transition: 3-85                  [2, 1152, 7, 7]           2,658,816
│    └─Sequential: 2-7                        [2, 2688, 7, 7]           --
│    │    └─Bottleneck: 3-86                  [2, 1184, 7, 7]           186,880
│    │    └─Bottleneck: 3-87                  [2, 1216, 7, 7]           191,040
│    │    └─Bottleneck: 3-88                  [2, 1248, 7, 7]           195,200
│    │    └─Bottleneck: 3-89                  [2, 1280, 7, 7]           199,360
│    │    └─Bottleneck: 3-90                  [2, 1312, 7, 7]           203,520
│    │    └─Bottleneck: 3-91                  [2, 1344, 7, 7]           207,680
│    │    └─Bottleneck: 3-92                  [2, 1376, 7, 7]           211,840
│    │    └─Bottleneck: 3-93                  [2, 1408, 7, 7]           216,000
│    │    └─Bottleneck: 3-94                  [2, 1440, 7, 7]           220,160
│    │    └─Bottleneck: 3-95                  [2, 1472, 7, 7]           224,320
│    │    └─Bottleneck: 3-96                  [2, 1504, 7, 7]           228,480
│    │    └─Bottleneck: 3-97                  [2, 1536, 7, 7]           232,640
│    │    └─Bottleneck: 3-98                  [2, 1568, 7, 7]           236,800
│    │    └─Bottleneck: 3-99                  [2, 1600, 7, 7]           240,960
│    │    └─Bottleneck: 3-100                 [2, 1632, 7, 7]           245,120
│    │    └─Bottleneck: 3-101                 [2, 1664, 7, 7]           249,280
│    │    └─Bottleneck: 3-102                 [2, 1696, 7, 7]           253,440
│    │    └─Bottleneck: 3-103                 [2, 1728, 7, 7]           257,600
│    │    └─Bottleneck: 3-104                 [2, 1760, 7, 7]           261,760
│    │    └─Bottleneck: 3-105                 [2, 1792, 7, 7]           265,920
│    │    └─Bottleneck: 3-106                 [2, 1824, 7, 7]           270,080
│    │    └─Bottleneck: 3-107                 [2, 1856, 7, 7]           274,240
│    │    └─Bottleneck: 3-108                 [2, 1888, 7, 7]           278,400
│    │    └─Bottleneck: 3-109                 [2, 1920, 7, 7]           282,560
│    │    └─Bottleneck: 3-110                 [2, 1952, 7, 7]           286,720
│    │    └─Bottleneck: 3-111                 [2, 1984, 7, 7]           290,880
│    │    └─Bottleneck: 3-112                 [2, 2016, 7, 7]           295,040
│    │    └─Bottleneck: 3-113                 [2, 2048, 7, 7]           299,200
│    │    └─Bottleneck: 3-114                 [2, 2080, 7, 7]           303,360
│    │    └─Bottleneck: 3-115                 [2, 2112, 7, 7]           307,520
│    │    └─Bottleneck: 3-116                 [2, 2144, 7, 7]           311,680
│    │    └─Bottleneck: 3-117                 [2, 2176, 7, 7]           315,840
│    │    └─Bottleneck: 3-118                 [2, 2208, 7, 7]           320,000
│    │    └─Bottleneck: 3-119                 [2, 2240, 7, 7]           324,160
│    │    └─Bottleneck: 3-120                 [2, 2272, 7, 7]           328,320
│    │    └─Bottleneck: 3-121                 [2, 2304, 7, 7]           332,480
│    │    └─Bottleneck: 3-122                 [2, 2336, 7, 7]           336,640
│    │    └─Bottleneck: 3-123                 [2, 2368, 7, 7]           340,800
│    │    └─Bottleneck: 3-124                 [2, 2400, 7, 7]           344,960
│    │    └─Bottleneck: 3-125                 [2, 2432, 7, 7]           349,120
│    │    └─Bottleneck: 3-126                 [2, 2464, 7, 7]           353,280
│    │    └─Bottleneck: 3-127                 [2, 2496, 7, 7]           357,440
│    │    └─Bottleneck: 3-128                 [2, 2528, 7, 7]           361,600
│    │    └─Bottleneck: 3-129                 [2, 2560, 7, 7]           365,760
│    │    └─Bottleneck: 3-130                 [2, 2592, 7, 7]           369,920
│    │    └─Bottleneck: 3-131                 [2, 2624, 7, 7]           374,080
│    │    └─Bottleneck: 3-132                 [2, 2656, 7, 7]           378,240
│    │    └─Bottleneck: 3-133                 [2, 2688, 7, 7]           382,400
│    │    └─BatchNorm2d: 3-134                [2, 2688, 7, 7]           5,376
│    │    └─ReLU: 3-135                       [2, 2688, 7, 7]           --
├─AdaptiveAvgPool2d: 1-4                      [2, 2688, 1, 1]           --
├─Linear: 1-5                                 [2, 1000]                 2,689,000
===============================================================================================
Total params: 33,337,704
Trainable params: 33,337,704
Non-trainable params: 0
Total mult-adds (G): 11.50
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 680.70
Params size (MB): 133.35
Estimated Total Size (MB): 815.26
===============================================================================================

0개의 댓글

관련 채용 정보