UNet(2015) 논문 리뷰

이성원·2023년 7월 12일
0
post-thumbnail

이 페이지는 Image Segmentaion에 관련된 UNet(2015)에 관한 내용입니다.


1. UNet 특징

이전의 모델들은 모델의 크기에 비해 학습 데이터가 적어서 성능 향상은 이루어지지만 어느정도 제한이 있었습니다.

기존의 Segmenation 방식은 다음과 같은 두 가지 취약점이 있습니다.

  • patch별로 연산을 하니까 연산 속도가 매우 느리며 서로 중복된 영역을 가지는 patch가 많아 중복된 예측 결과 발생

  • localization accuracy와 use of context 사이에 존재하는 trade-off 발생

patch의 크기가 크면 더 많은 max-pooling layer를 요구하여 localization accuracy가 떨어지고 patch의 크기가 작으면 litte context만 알 수 있다고 합니다.


2. 구조

UNET 은 다음과 같은 두 구조로 나뉘는데

위 이미지에서 절반을 기준으로 왼쪽을 Contracting path(수축) 오른쪽을 Expanding path(팽창)라고 합니다.

Contracting Path

contracting path는 우리가 아는 전형적인 CNN과 같습니다.

contracting path는 3x3 사이즈의 kernel을 가진 CNN(3x3 CNN)을 2번 적용하며 이 과정에서 활성화 함수로 ReLU(rectified linear unit)를 사용합니다.

그리고 stride = 2의 2x2 max pooling을 적용해 너비와 높이를 반으로 줄여버립니다(downsampling).

Expanding Path

contracting path와 좌우로 대칭되는 구조를 가집니다. 3x3 CNN을 2번 적용한 뒤 feature map의 너비와 높이를 2배로 늘리는(upsampling) 과정을 여러번 반복하는데 upsampling을 할 때마다 적용하는 CNN에 있는 kerel의 개수를 반으로 줄입니다.

이 때, Kernel의 개수를 반으로 줄인 CNN에 적용하기 전에 반대쪽 contracting path에서 같은 층에 있는 feature map과 합칩니다.

논문에서는 concatenation라고 표현했습니다.

이러한 구조 때문에 context 정보의 손실을 보완할 수 있게 됩니다.

이후 이어지는 Convolution layer는 이를 바탕으로 더 정확한 출력을 낼 수 있도록 학습됩니다.

이를 반복해 다양한 레이어의 출력을 이용하는 것은 위치정보와 context 정보 사용을 동시에 가능하게 했습니다.


3. 코드 구현

pytorch

import torch.nn as nn
import torch
from torchsummary import summary

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        def CBR(in_channels, out_channels, kernel_size = 3, stride = 1):
            layers = nn.Sequential(
                nn.Conv2d(in_channels,out_channels,kernel_size,stride,1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )
            return layers

        self.enc1_1 = CBR(1, 64)
        self.enc1_2 = CBR(64, 64)
        self.pool1 = nn.MaxPool2d(2)

        self.enc2_1 = CBR(64, 128)
        self.enc2_2 = CBR(128, 128)
        self.pool2 = nn.MaxPool2d(2)

        self.enc3_1 = CBR(128, 256)
        self.enc3_2 = CBR(256, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.enc4_1 = CBR(256, 512)
        self.enc4_2 = CBR(512, 512)
        self.pool4 = nn.MaxPool2d(2)

        self.enc5_1 = CBR(512, 1024)
        self.enc5_2 = CBR(1024, 1024)

        self.unpool4 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.dec4_2 = CBR(1024, 512)
        self.dec4_1 = CBR(512,512)

        self.unpool3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec3_2 = CBR(512, 256)
        self.dec3_1 = CBR(256, 256)

        self.unpool2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2_2 = CBR(256, 128)
        self.dec2_1 = CBR(128, 128)

        self.unpool1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1_2 = CBR(128, 64)
        self.dec1_1 = CBR(64, 64)

        self.result = nn.Conv2d(64,2,1,1,1)

    def forward(self, x):
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)
        enc5_2 = self.enc5_2(enc5_1)

        unpool4 = self.unpool4(enc5_2)
        dec4_2 = self.dec4_2(torch.cat((unpool4, enc4_2), 1))
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        dec3_2 = self.dec3_2(torch.cat((unpool3, enc3_2), 1))
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        dec2_2 = self.dec2_2(torch.cat((unpool2, enc2_2), 1))
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        dec1_2 = self.dec1_2(torch.cat((unpool1, enc1_2), 1))
        dec1_1 = self.dec1_1(dec1_2)

        out = self.result(dec1_1)
        return out

model = UNet().cuda()

summary(model, (1,400,400))
profile
개발자

0개의 댓글