이 페이지는 Image Segmentaion에 관련된 UNet(2015)에 관한 내용입니다.
이전의 모델들은 모델의 크기에 비해 학습 데이터가 적어서 성능 향상은 이루어지지만 어느정도 제한이 있었습니다.
기존의 Segmenation 방식은 다음과 같은 두 가지 취약점이 있습니다.
patch별로 연산을 하니까 연산 속도가 매우 느리며 서로 중복된 영역을 가지는 patch가 많아 중복된 예측 결과 발생
localization accuracy와 use of context 사이에 존재하는 trade-off 발생
patch의 크기가 크면 더 많은 max-pooling layer를 요구하여 localization accuracy가 떨어지고 patch의 크기가 작으면 litte context만 알 수 있다고 합니다.
UNET 은 다음과 같은 두 구조로 나뉘는데
위 이미지에서 절반을 기준으로 왼쪽을 Contracting path(수축) 오른쪽을 Expanding path(팽창)라고 합니다.
contracting path는 우리가 아는 전형적인 CNN과 같습니다.
contracting path는 3x3 사이즈의 kernel을 가진 CNN(3x3 CNN)을 2번 적용하며 이 과정에서 활성화 함수로 ReLU(rectified linear unit)를 사용합니다.
그리고 stride = 2의 2x2 max pooling을 적용해 너비와 높이를 반으로 줄여버립니다(downsampling).
contracting path와 좌우로 대칭되는 구조를 가집니다. 3x3 CNN을 2번 적용한 뒤 feature map의 너비와 높이를 2배로 늘리는(upsampling) 과정을 여러번 반복하는데 upsampling을 할 때마다 적용하는 CNN에 있는 kerel의 개수를 반으로 줄입니다.
이 때, Kernel의 개수를 반으로 줄인 CNN에 적용하기 전에 반대쪽 contracting path에서 같은 층에 있는 feature map과 합칩니다.
논문에서는 concatenation라고 표현했습니다.
이러한 구조 때문에 context 정보의 손실을 보완할 수 있게 됩니다.
이후 이어지는 Convolution layer는 이를 바탕으로 더 정확한 출력을 낼 수 있도록 학습됩니다.
이를 반복해 다양한 레이어의 출력을 이용하는 것은 위치정보와 context 정보 사용을 동시에 가능하게 했습니다.
import torch.nn as nn
import torch
from torchsummary import summary
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
def CBR(in_channels, out_channels, kernel_size = 3, stride = 1):
layers = nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size,stride,1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
return layers
self.enc1_1 = CBR(1, 64)
self.enc1_2 = CBR(64, 64)
self.pool1 = nn.MaxPool2d(2)
self.enc2_1 = CBR(64, 128)
self.enc2_2 = CBR(128, 128)
self.pool2 = nn.MaxPool2d(2)
self.enc3_1 = CBR(128, 256)
self.enc3_2 = CBR(256, 256)
self.pool3 = nn.MaxPool2d(2)
self.enc4_1 = CBR(256, 512)
self.enc4_2 = CBR(512, 512)
self.pool4 = nn.MaxPool2d(2)
self.enc5_1 = CBR(512, 1024)
self.enc5_2 = CBR(1024, 1024)
self.unpool4 = nn.ConvTranspose2d(1024, 512, 2, 2)
self.dec4_2 = CBR(1024, 512)
self.dec4_1 = CBR(512,512)
self.unpool3 = nn.ConvTranspose2d(512, 256, 2, 2)
self.dec3_2 = CBR(512, 256)
self.dec3_1 = CBR(256, 256)
self.unpool2 = nn.ConvTranspose2d(256, 128, 2, 2)
self.dec2_2 = CBR(256, 128)
self.dec2_1 = CBR(128, 128)
self.unpool1 = nn.ConvTranspose2d(128, 64, 2, 2)
self.dec1_2 = CBR(128, 64)
self.dec1_1 = CBR(64, 64)
self.result = nn.Conv2d(64,2,1,1,1)
def forward(self, x):
enc1_1 = self.enc1_1(x)
enc1_2 = self.enc1_2(enc1_1)
pool1 = self.pool1(enc1_2)
enc2_1 = self.enc2_1(pool1)
enc2_2 = self.enc2_2(enc2_1)
pool2 = self.pool2(enc2_2)
enc3_1 = self.enc3_1(pool2)
enc3_2 = self.enc3_2(enc3_1)
pool3 = self.pool3(enc3_2)
enc4_1 = self.enc4_1(pool3)
enc4_2 = self.enc4_2(enc4_1)
pool4 = self.pool4(enc4_2)
enc5_1 = self.enc5_1(pool4)
enc5_2 = self.enc5_2(enc5_1)
unpool4 = self.unpool4(enc5_2)
dec4_2 = self.dec4_2(torch.cat((unpool4, enc4_2), 1))
dec4_1 = self.dec4_1(dec4_2)
unpool3 = self.unpool3(dec4_1)
dec3_2 = self.dec3_2(torch.cat((unpool3, enc3_2), 1))
dec3_1 = self.dec3_1(dec3_2)
unpool2 = self.unpool2(dec3_1)
dec2_2 = self.dec2_2(torch.cat((unpool2, enc2_2), 1))
dec2_1 = self.dec2_1(dec2_2)
unpool1 = self.unpool1(dec2_1)
dec1_2 = self.dec1_2(torch.cat((unpool1, enc1_2), 1))
dec1_1 = self.dec1_1(dec1_2)
out = self.result(dec1_1)
return out
model = UNet().cuda()
summary(model, (1,400,400))