[Pytorch] CSPNet 구현

도룩·2024년 3월 8일
1

목적

CSPNet을 이해하고 Pytorch로 구현할 수 있다.

Architecture

  • Network architecture
    CSP(Cross Stage Partial) Net은 기존의 ResNet, DenseNet 등 여러 backbone Network에 도입할 수 있다. 논문에서 첨부한 Figure 2를 보면서 확인해보자.
    Figure 2에서는 DenseNet과 CSPDenseNet를 보여준다.
    \\
    DenseNet은 DenseBlock과 Transition Layer로 이루어져 있다. DenseBlock에서 여러 DenseLayer를 거치면서 growth rate 만큼 feature map의 채널수가 점점 늘어나고, Transition Layer에서는 feature map의 크기와 채널 수를 조절한다.
    \\
    CSPDenseNet에서는 각 DenseBlock으로 들어가기 직전의 feature map의 채널 수를 기준으로 반으로 나눠서 두 개의 feature map으로 만든다. 하나는 DenseBlock으로 들어갈 feature map이고, 나머지 하나는 Transition Layer에서 concat할 feature map이다.
    \\
    Feature map의 shape을 (N, C, H, W) 라고 하고, DenseBlock으로 들어갈 feature map을 CSP_features_01, Transition Layer에서 concat할 feature map을 CSP_features_02라고 한다면 각각의 feature map 사이즈는 아래와 같을 것이다.
    \rightarrow CSP_features_01.shape = (N, C//2, H, W),
    \rightarrow CSP_features_02.shape = (N, C - C//2, H, W)
    \\
    CSP_features_01은 DenseBlock의 DenseLayer를 여러 개 통과하면서 (N, C//2 + growth_rate * ?, H, W)가 될 것이고, DenseBlock을 통과하면 CSP_features_02와 채널축 방향으로 concat한 뒤 Transition Layer를 통과할 것이다.
    이 과정을 반복하는 것이 CSPNet의 아이디어 이다.
    \\
    \\

특징 및 결과

  • 특징
    \\
    Cross Stage Partial Network

    Forward propagation 먼저 살펴보자.
    \\
    DenseNet의 forward propagation는 layer가 깊어짐에 따라 xix_i가 누적되어 convolution 연산을 수행한다.
    \\
    CSPDenseNet은 조금 다르다. 먼저 x0x_0을 채널축 기준으로 나누어(split) 2개의 feature map (x0x_0^{'}x0x_0^{''})으로 만든다. 수식을 보면 x0x_0^{''}는 Dense layer들을 통과하고 (xkx_k) transition layer를 거쳐 xTx_T가 된다. 이후 또 Dense layer를 통과하지 않은 x0x_0^{'}과 다시 채널 축 방향으로 concat한 뒤 또 다른 transition layer를 통과하여 xUx_U를 생성한다.
    \\
    이제 backward propagation 부분을 살펴보자.
    \\
    DenseNet의 backward propagation을 보면 많은 양의 gradient 정보가 다른 dense layer의 weight를 업데이트 하는 과정에서 재사용 되는 것을 볼 수 있다. 이는 서로 다른 dense layer가 copied gradient 정보를 반복적으로 학습한다는 것을 의미한다.
    \\
    CSPDenseNet에서는 gradient flow 상 g0g_0^{'}gTg_T를 이용해 wUw_U를 업데이트 하는 과정에서 볼 수 있듯이 서로 다른 쪽에 속하는 gradient를 이용함으로써 기존의 DenseNet과는 조금 다르다.
    \\
    물론 x0x_0^{'}가 통과하는 dense layer의 weight를 업데이트 할 때는 dense layer 특성상 gradient가 재사용 되는 부분이 분명 존재한다. 하지만 기존 feature map에서 반(half)의 채널 수를 갖는 feature map만 dense block에 통과시킴으로써 재사용 되는 gradient의 정보량은 DenseNet보다 적을 것이다.
    \\
    정리하자면, CSPDenseNet은 과도하게 재사용 되는 gradient flow를 중간에 잘라내고, 새로운 gradient flow를 추가함으로써 DenseNet의 특징인 feature map을 재사용한다는 특징은 최대한 보존하면서 과도하게 중복된 gradient 정보를 줄여 연산량을 줄이면서도 성능을 보존하거나 향상시킬 수 있었다.
    \\
    \\
    Different kind of feature fusion strategies

    CSPDenseNet은 Fusion First와 Fusion Last를 모두 접목한 (b)를 적용하였다.
    \\
    \\
    Applying CSPNet to ResNe(X)t

    ResNe(X)t에서도 채널 수가 커져 모델이 커지는 것을 우려해 50 layers 이상부터는 residual block으로 bottleneck을 사용했는데 CSPNet을 도입한다면 residual block에 들어가기 전 채널 수를 반으로 줄인 feature map이 통과하기 때문에 굳이 bottleneck 구조의 residual block을 사용할 필요가 없다.
    \\
    결과 1 - Image Classification
    CSPNet을 도입한 모델이 더 작은 BFLOPS를 보이면서 ACC가 유지되거나 약간 상승한 것을 볼 수 있다.
    \\
    \\
    결과 2 - Object detection

    마찬가지로 CSPNet을 도입한 모델이 더 작은 BFLOPS를 보였으며, AP도 더 높은 경향을 보였다.
    (특히 AP가 높을수록 다른 모델과 차별화된 성능을 보였음.)
    \\
    \\
    결과요약
    CSPNet은 기존의 네트워크 구조를 크게 변경하지 않으면서 모델의 연산량(BFLOPS)을 줄였으며 성능을 유지하거나 향상시켰다.
    Object detection 모델에서도 FPS와 mAP 측면에서 큰 효과를 볼 수 있었다.
    \\
    \\

Code

이전에 구현한 DenseNet에 CSPNet을 추가하여 CSPDenseNet을 구현하였다.
\\

환경

  • python 3.8.16
  • pytorch 2.1.0
  • torchinfo 1.8.0

구현

import torch
from torch import nn
from torchinfo import summary
class DenseLayer(nn.Module):
    def __init__(self, in_channels, k):
        super().__init__()

        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels, 4 * k, 1, bias = False),
            nn.BatchNorm2d(4 * k),
            nn.ReLU(inplace = True),
            nn.Conv2d(4 * k, k, 3, padding = 1, bias = False),
        )
    
    def forward(self, x):
        return torch.concat([self.residual(x), x], dim = 1)

class Transition(nn.Module):
    def __init__(self, in_channels, csp_transition = False):
        super().__init__()

        transition_layers = [
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace = True),
                nn.Conv2d(in_channels, in_channels // 2, 1, bias = False),
        ]

        if csp_transition is not True:
            transition_layers.append(nn.AvgPool2d(2))
        
        self.transition = nn.Sequential(*transition_layers)
    
    def forward(self, x):
        return self.transition(x)

class CSPDenseBlock(nn.Module):
    def __init__(self, in_channels, num_blocks, k, last_stage = False):
        super().__init__()

        self.in_channels = in_channels
        csp_channels_01 = in_channels // 2
        csp_channels_02 = in_channels - csp_channels_01

        layers = []
        for _ in range(num_blocks):
            layers.append(DenseLayer(csp_channels_01, k))
            csp_channels_01 += k
        layers.append(Transition(csp_channels_01, csp_transition = True))
        csp_channels_01 //= 2
        self.dense_block = nn.Sequential(*layers)

        self.last = nn.Sequential(nn.BatchNorm2d(csp_channels_01 + csp_channels_02), nn.ReLU(inplace = True)) if last_stage else Transition(csp_channels_01 + csp_channels_02)
        self.channels = csp_channels_01 + csp_channels_02 if last_stage else (csp_channels_01 + csp_channels_02) // 2
    
    def forward(self, x):
        if self.in_channels % 2:
            csp_x_01 = x[:, self.in_channels // 2 + 1:, ...]
            csp_x_02 = x[:, :self.in_channels // 2 + 1, ...]
        else:
            csp_x_01 = x[:, self.in_channels // 2:, ...]
            csp_x_02 = x[:, :self.in_channels // 2, ...]

        csp_x_01 = self.dense_block(csp_x_01)
        csp_x = torch.cat([csp_x_01, csp_x_02], dim = 1)

        return self.last(csp_x)


class CSPDenseNet(nn.Module):
    def __init__(self, block_list, growth_rate, n_classes = 1000):
        super().__init__()

        assert len(block_list) == 4
        self.k = growth_rate

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 2 * self.k, 7, stride = 2, padding = 3, bias = False),
            nn.BatchNorm2d(2 * self.k),
            nn.ReLU(inplace = True),
        )
        self.maxpool = nn.MaxPool2d(3, stride = 2, padding = 1)

        self.dense_block_01 = CSPDenseBlock(2 * self.k, block_list[0], self.k)
        self.dense_block_02 = CSPDenseBlock(self.dense_block_01.channels, block_list[1], self.k)
        self.dense_block_03 = CSPDenseBlock(self.dense_block_02.channels, block_list[2], self.k)
        self.dense_block_04 = CSPDenseBlock(self.dense_block_03.channels, block_list[3], self.k, last_stage = True)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self.dense_block_04.channels, n_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.dense_block_01(x)
        x = self.dense_block_02(x)
        x = self.dense_block_03(x)
        x = self.dense_block_04(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
model = CSPDenseNet264()
summary(model, input_size = (2, 3, 224, 224), device = "cpu")
#### OUTPUT ####
===============================================================================================
Layer (type:depth-idx)                        Output Shape              Param #
===============================================================================================
CSPDenseNet                                   [2, 1000]                 --
├─Sequential: 1-1                             [2, 64, 112, 112]         --
│    └─Conv2d: 2-1                            [2, 64, 112, 112]         9,408
│    └─BatchNorm2d: 2-2                       [2, 64, 112, 112]         128
│    └─ReLU: 2-3                              [2, 64, 112, 112]         --
├─MaxPool2d: 1-2                              [2, 64, 56, 56]           --
├─CSPDenseBlock: 1-3                          [2, 72, 28, 28]           --
│    └─Sequential: 2-4                        [2, 112, 56, 56]          --
│    │    └─DenseLayer: 3-1                   [2, 64, 56, 56]           41,280
│    │    └─DenseLayer: 3-2                   [2, 96, 56, 56]           45,440
│    │    └─DenseLayer: 3-3                   [2, 128, 56, 56]          49,600
│    │    └─DenseLayer: 3-4                   [2, 160, 56, 56]          53,760
│    │    └─DenseLayer: 3-5                   [2, 192, 56, 56]          57,920
│    │    └─DenseLayer: 3-6                   [2, 224, 56, 56]          62,080
│    │    └─Transition: 3-7                   [2, 112, 56, 56]          25,536
│    └─Transition: 2-5                        [2, 72, 28, 28]           --
│    │    └─Sequential: 3-8                   [2, 72, 28, 28]           10,656
├─CSPDenseBlock: 1-4                          [2, 123, 14, 14]          --
│    └─Sequential: 2-6                        [2, 210, 28, 28]          --
│    │    └─DenseLayer: 3-9                   [2, 68, 28, 28]           41,800
│    │    └─DenseLayer: 3-10                  [2, 100, 28, 28]          45,960
│    │    └─DenseLayer: 3-11                  [2, 132, 28, 28]          50,120
│    │    └─DenseLayer: 3-12                  [2, 164, 28, 28]          54,280
│    │    └─DenseLayer: 3-13                  [2, 196, 28, 28]          58,440
│    │    └─DenseLayer: 3-14                  [2, 228, 28, 28]          62,600
│    │    └─DenseLayer: 3-15                  [2, 260, 28, 28]          66,760
│    │    └─DenseLayer: 3-16                  [2, 292, 28, 28]          70,920
│    │    └─DenseLayer: 3-17                  [2, 324, 28, 28]          75,080
│    │    └─DenseLayer: 3-18                  [2, 356, 28, 28]          79,240
│    │    └─DenseLayer: 3-19                  [2, 388, 28, 28]          83,400
│    │    └─DenseLayer: 3-20                  [2, 420, 28, 28]          87,560
│    │    └─Transition: 3-21                  [2, 210, 28, 28]          89,040
│    └─Transition: 2-7                        [2, 123, 14, 14]          --
│    │    └─Sequential: 3-22                  [2, 123, 14, 14]          30,750
├─CSPDenseBlock: 1-5                          [2, 558, 7, 7]            --
│    └─Sequential: 2-8                        [2, 1054, 14, 14]         --
│    │    └─DenseLayer: 3-23                  [2, 93, 14, 14]           45,050
│    │    └─DenseLayer: 3-24                  [2, 125, 14, 14]          49,210
│    │    └─DenseLayer: 3-25                  [2, 157, 14, 14]          53,370
│    │    └─DenseLayer: 3-26                  [2, 189, 14, 14]          57,530
│    │    └─DenseLayer: 3-27                  [2, 221, 14, 14]          61,690
│    │    └─DenseLayer: 3-28                  [2, 253, 14, 14]          65,850
│    │    └─DenseLayer: 3-29                  [2, 285, 14, 14]          70,010
│    │    └─DenseLayer: 3-30                  [2, 317, 14, 14]          74,170
│    │    └─DenseLayer: 3-31                  [2, 349, 14, 14]          78,330
│    │    └─DenseLayer: 3-32                  [2, 381, 14, 14]          82,490
│    │    └─DenseLayer: 3-33                  [2, 413, 14, 14]          86,650
│    │    └─DenseLayer: 3-34                  [2, 445, 14, 14]          90,810
│    │    └─DenseLayer: 3-35                  [2, 477, 14, 14]          94,970
│    │    └─DenseLayer: 3-36                  [2, 509, 14, 14]          99,130
│    │    └─DenseLayer: 3-37                  [2, 541, 14, 14]          103,290
│    │    └─DenseLayer: 3-38                  [2, 573, 14, 14]          107,450
│    │    └─DenseLayer: 3-39                  [2, 605, 14, 14]          111,610
│    │    └─DenseLayer: 3-40                  [2, 637, 14, 14]          115,770
│    │    └─DenseLayer: 3-41                  [2, 669, 14, 14]          119,930
│    │    └─DenseLayer: 3-42                  [2, 701, 14, 14]          124,090
│    │    └─DenseLayer: 3-43                  [2, 733, 14, 14]          128,250
│    │    └─DenseLayer: 3-44                  [2, 765, 14, 14]          132,410
│    │    └─DenseLayer: 3-45                  [2, 797, 14, 14]          136,570
│    │    └─DenseLayer: 3-46                  [2, 829, 14, 14]          140,730
│    │    └─DenseLayer: 3-47                  [2, 861, 14, 14]          144,890
│    │    └─DenseLayer: 3-48                  [2, 893, 14, 14]          149,050
│    │    └─DenseLayer: 3-49                  [2, 925, 14, 14]          153,210
│    │    └─DenseLayer: 3-50                  [2, 957, 14, 14]          157,370
│    │    └─DenseLayer: 3-51                  [2, 989, 14, 14]          161,530
│    │    └─DenseLayer: 3-52                  [2, 1021, 14, 14]         165,690
│    │    └─DenseLayer: 3-53                  [2, 1053, 14, 14]         169,850
│    │    └─DenseLayer: 3-54                  [2, 1085, 14, 14]         174,010
│    │    └─DenseLayer: 3-55                  [2, 1117, 14, 14]         178,170
│    │    └─DenseLayer: 3-56                  [2, 1149, 14, 14]         182,330
│    │    └─DenseLayer: 3-57                  [2, 1181, 14, 14]         186,490
│    │    └─DenseLayer: 3-58                  [2, 1213, 14, 14]         190,650
│    │    └─DenseLayer: 3-59                  [2, 1245, 14, 14]         194,810
│    │    └─DenseLayer: 3-60                  [2, 1277, 14, 14]         198,970
│    │    └─DenseLayer: 3-61                  [2, 1309, 14, 14]         203,130
│    │    └─DenseLayer: 3-62                  [2, 1341, 14, 14]         207,290
│    │    └─DenseLayer: 3-63                  [2, 1373, 14, 14]         211,450
│    │    └─DenseLayer: 3-64                  [2, 1405, 14, 14]         215,610
│    │    └─DenseLayer: 3-65                  [2, 1437, 14, 14]         219,770
│    │    └─DenseLayer: 3-66                  [2, 1469, 14, 14]         223,930
│    │    └─DenseLayer: 3-67                  [2, 1501, 14, 14]         228,090
│    │    └─DenseLayer: 3-68                  [2, 1533, 14, 14]         232,250
│    │    └─DenseLayer: 3-69                  [2, 1565, 14, 14]         236,410
│    │    └─DenseLayer: 3-70                  [2, 1597, 14, 14]         240,570
│    │    └─DenseLayer: 3-71                  [2, 1629, 14, 14]         244,730
│    │    └─DenseLayer: 3-72                  [2, 1661, 14, 14]         248,890
│    │    └─DenseLayer: 3-73                  [2, 1693, 14, 14]         253,050
│    │    └─DenseLayer: 3-74                  [2, 1725, 14, 14]         257,210
│    │    └─DenseLayer: 3-75                  [2, 1757, 14, 14]         261,370
│    │    └─DenseLayer: 3-76                  [2, 1789, 14, 14]         265,530
│    │    └─DenseLayer: 3-77                  [2, 1821, 14, 14]         269,690
│    │    └─DenseLayer: 3-78                  [2, 1853, 14, 14]         273,850
│    │    └─DenseLayer: 3-79                  [2, 1885, 14, 14]         278,010
│    │    └─DenseLayer: 3-80                  [2, 1917, 14, 14]         282,170
│    │    └─DenseLayer: 3-81                  [2, 1949, 14, 14]         286,330
│    │    └─DenseLayer: 3-82                  [2, 1981, 14, 14]         290,490
│    │    └─DenseLayer: 3-83                  [2, 2013, 14, 14]         294,650
│    │    └─DenseLayer: 3-84                  [2, 2045, 14, 14]         298,810
│    │    └─DenseLayer: 3-85                  [2, 2077, 14, 14]         302,970
│    │    └─DenseLayer: 3-86                  [2, 2109, 14, 14]         307,130
│    │    └─Transition: 3-87                  [2, 1054, 14, 14]         2,227,104
│    └─Transition: 2-9                        [2, 558, 7, 7]            --
│    │    └─Sequential: 3-88                  [2, 558, 7, 7]            624,960
├─CSPDenseBlock: 1-6                          [2, 1186, 7, 7]           --
│    └─Sequential: 2-10                       [2, 907, 7, 7]            --
│    │    └─DenseLayer: 3-89                  [2, 311, 7, 7]            73,390
│    │    └─DenseLayer: 3-90                  [2, 343, 7, 7]            77,550
│    │    └─DenseLayer: 3-91                  [2, 375, 7, 7]            81,710
│    │    └─DenseLayer: 3-92                  [2, 407, 7, 7]            85,870
│    │    └─DenseLayer: 3-93                  [2, 439, 7, 7]            90,030
│    │    └─DenseLayer: 3-94                  [2, 471, 7, 7]            94,190
│    │    └─DenseLayer: 3-95                  [2, 503, 7, 7]            98,350
│    │    └─DenseLayer: 3-96                  [2, 535, 7, 7]            102,510
│    │    └─DenseLayer: 3-97                  [2, 567, 7, 7]            106,670
│    │    └─DenseLayer: 3-98                  [2, 599, 7, 7]            110,830
│    │    └─DenseLayer: 3-99                  [2, 631, 7, 7]            114,990
│    │    └─DenseLayer: 3-100                 [2, 663, 7, 7]            119,150
│    │    └─DenseLayer: 3-101                 [2, 695, 7, 7]            123,310
│    │    └─DenseLayer: 3-102                 [2, 727, 7, 7]            127,470
│    │    └─DenseLayer: 3-103                 [2, 759, 7, 7]            131,630
│    │    └─DenseLayer: 3-104                 [2, 791, 7, 7]            135,790
│    │    └─DenseLayer: 3-105                 [2, 823, 7, 7]            139,950
│    │    └─DenseLayer: 3-106                 [2, 855, 7, 7]            144,110
│    │    └─DenseLayer: 3-107                 [2, 887, 7, 7]            148,270
│    │    └─DenseLayer: 3-108                 [2, 919, 7, 7]            152,430
│    │    └─DenseLayer: 3-109                 [2, 951, 7, 7]            156,590
│    │    └─DenseLayer: 3-110                 [2, 983, 7, 7]            160,750
│    │    └─DenseLayer: 3-111                 [2, 1015, 7, 7]           164,910
│    │    └─DenseLayer: 3-112                 [2, 1047, 7, 7]           169,070
│    │    └─DenseLayer: 3-113                 [2, 1079, 7, 7]           173,230
│    │    └─DenseLayer: 3-114                 [2, 1111, 7, 7]           177,390
│    │    └─DenseLayer: 3-115                 [2, 1143, 7, 7]           181,550
│    │    └─DenseLayer: 3-116                 [2, 1175, 7, 7]           185,710
│    │    └─DenseLayer: 3-117                 [2, 1207, 7, 7]           189,870
│    │    └─DenseLayer: 3-118                 [2, 1239, 7, 7]           194,030
│    │    └─DenseLayer: 3-119                 [2, 1271, 7, 7]           198,190
│    │    └─DenseLayer: 3-120                 [2, 1303, 7, 7]           202,350
│    │    └─DenseLayer: 3-121                 [2, 1335, 7, 7]           206,510
│    │    └─DenseLayer: 3-122                 [2, 1367, 7, 7]           210,670
│    │    └─DenseLayer: 3-123                 [2, 1399, 7, 7]           214,830
│    │    └─DenseLayer: 3-124                 [2, 1431, 7, 7]           218,990
│    │    └─DenseLayer: 3-125                 [2, 1463, 7, 7]           223,150
│    │    └─DenseLayer: 3-126                 [2, 1495, 7, 7]           227,310
│    │    └─DenseLayer: 3-127                 [2, 1527, 7, 7]           231,470
│    │    └─DenseLayer: 3-128                 [2, 1559, 7, 7]           235,630
│    │    └─DenseLayer: 3-129                 [2, 1591, 7, 7]           239,790
│    │    └─DenseLayer: 3-130                 [2, 1623, 7, 7]           243,950
│    │    └─DenseLayer: 3-131                 [2, 1655, 7, 7]           248,110
│    │    └─DenseLayer: 3-132                 [2, 1687, 7, 7]           252,270
│    │    └─DenseLayer: 3-133                 [2, 1719, 7, 7]           256,430
│    │    └─DenseLayer: 3-134                 [2, 1751, 7, 7]           260,590
│    │    └─DenseLayer: 3-135                 [2, 1783, 7, 7]           264,750
│    │    └─DenseLayer: 3-136                 [2, 1815, 7, 7]           268,910
│    │    └─Transition: 3-137                 [2, 907, 7, 7]            1,649,835
│    └─Sequential: 2-11                       [2, 1186, 7, 7]           --
│    │    └─BatchNorm2d: 3-138                [2, 1186, 7, 7]           2,372
│    │    └─ReLU: 3-139                       [2, 1186, 7, 7]           --
├─AdaptiveAvgPool2d: 1-7                      [2, 1186, 1, 1]           --
├─Linear: 1-8                                 [2, 1000]                 1,187,000
===============================================================================================
Total params: 26,427,989
Trainable params: 26,427,989
Non-trainable params: 0
Total mult-adds (G): 10.21
===============================================================================================
Input size (MB): 1.20
Forward/backward pass size (MB): 601.85
Params size (MB): 105.71
Estimated Total Size (MB): 708.76
===============================================================================================

0개의 댓글

관련 채용 정보