UNET3+ 논문 리뷰

김태훈·2023년 6월 23일
0

본 페이지에서는 UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation논문에 대해서 말하고자 합니다.


1. UNET 3+의 특징

기존의 UNET구조는 encoder-decoder 구조를 가진 medical image segmentation 모델이다.

UNET++는 skip connection 구조를 재구성하여 성능을 향상시켰지만 full scale의 측면에서의 정보를 얻지는 못한다.

UNET과 UNET++ 둘다 skip connection 구조를 통해 의미론적 차이를 줄일 수 있었지만 아직은 부족하다.

UNET 3+는 full scale skip connection을 통해 저수준의 정보와 고수준의 정보를 모두 혼합했지만 더 적은 수의 파라미터를 요구한다.

deep supervision을 사용해 높은 정확도를 얻었다.

추가적으로 hybrid loss와 classification guided module을 사용해 성능을 올렸다.

2. 구조

UNET3+의 구조는 Fig1의 C와 같다.

이를 식으로 표현하면 아래와 같다.

여기서 C는 Conv operation을 말하는 것이고 H는 Conv->BN->ReLU를 의미하고 D,U는 down-up sample을 나타낸다

이때 Decoder 블록의 X3을 만드는 과정을 시각적으로 표현하면 다음과 같다

같은 수준의 Encoder 블록인 X3는 3x3 conv를 통해 64개의 채널의 수로 직접적으로 연결되고

높은 수준의 Encoder 블록인 X1,X2는 각각 4,2의 scale로 MaxPooling을 한 후 3x3 Conv 연산을 통해 64개의 채널로 만들어 연결하고

낮은 수준의 Decoder 블록인 X4,X5는 각각 2,4 의 scale로 Upsample을 한 후 3x3 Conv 연산을 통해 64개의 채널로 만들어 연결한다

이후 모든 블록을 Concatenate연산을 한 후 3x3 Conv를 통해 320(64x5)개의 채널을 만든 후 BN ReLU를 적용한다.

이러한 구조 덕분에 파라미터 수가 줄게 되었다.

2.1 Intra connection

저 수준의 정보는 경계정보와 같은 spatial information을 얻어내고

높은 수준의 정보는 위치정보를 구체화 한다.

이러한 저 수준의 정보와 높은 수준의 정보를 모두 활용해 더 확실한 정보를 얻기위해 intra skip connection을 만들었다.

2.2 Classification Guided Module(CGM)

Classification 과정 처럼 물체가 있는지 없는지 확인하기 위해서 UNET3+는 마지막 Encoder 블럭에 Classification Guided Module을 추가하였다.

각각의 과정은 다음과 같은데

DropOut을 0.5의 비율로 한다.

1x1 conv를 통해 두개의 채널로 줄여준다(물체가 있다 없다를 판단하기 위해)

AdaptiveMaxPool을 적용한다(keras에서는 GlobalMaxPool을 적용)

이후 Sigmoid를 각각 적용한다.(왜 SoftMax가 아닌지 이해가 안됨)

Argmax를 통해 0또는 1로 변환한다.

이를 각 deep supervision의 결과에 곱해준다.

이를 통해 물체가 없다고 판단하면 모두 0이 되어 False positive일 확률이 줄어든다.

2.3 Deep supervision

UNET++와는 다르게 Deep Supervison을 모든 Decoder 블럭에 대해서 적용하고

각각의 Decoder 블럭은 3x3 conv를 적용한 후 bilinear upsample 이후에 sigmoid 함수를 적용한다.

2.4 Hybrid loss

애매한 경계에 높은 가중치를 주기위해 유사도를 확인하는 MS-SSIM을 Loss함수로 두었다.

Focal Loss,MS-SSIM loss,IoU Loss를 같이 사용하면서 pixel,patch,map level의 세가지를 확인했다.

3. 코드구현

Keras

import keras

def conv(input,features,kernel_size=3,strides = 1,padding='same',is_relu=True,is_bn=False):
  x= keras.layers.Conv2D(features,kernel_size,strides,padding)(input) 
  if is_bn:
    x = keras.layers.BatchNormalization()(x)
  if is_relu:
    x = keras.activations.relu(x)
  return x


def unet3(n_levels,DSV=True, initial_features=64, n_blocks=2, kernel_size=3, pooling_size=2, in_channels=1, out_channels=1):
  inputs = keras.layers.Input(shape=(400, 400, in_channels))
  x = inputs
  
  #인코더부분
  skips = {}
  output = []
  for level in range(n_levels):
    if level != 0 :
      x = keras.layers.MaxPool2D(pooling_size,pooling_size)(x)
    for _ in range(n_blocks):
      x = conv(x,initial_features * 2 ** level,3,1,'same')
    skips[level] = x 

    if level == n_levels-1:
      output.append(x)

  #디코더 부분
  for level in reversed(range(n_levels-1)): 
    list_concat = []

    #위에 레벨+ 동일레벨
    for i in range(level+1):
      x = skips[i]
      x = keras.layers.MaxPool2D(pooling_size**(level-i),pooling_size**(level-i))(x)
      x = conv(x,initial_features,3,1,'same')
      list_concat.append(x)

    #아래 레벨
    for i in range(level+1,n_levels):
      x = output[n_levels-i-1]
      x = keras.layers.UpSampling2D(pooling_size**(i-level), interpolation='bilinear')(x)
      x = conv(x,initial_features,3,1,'same')
      list_concat.append(x)

    #concat 부분
    x = keras.layers.concatenate(list_concat)
    for i in range(2): 
      x = conv(x,initial_features * n_levels,3,1,'same',is_bn = True)
    output.append(x)

  # 출력부분
  result = []
  if DSV:
    for i in range(n_levels):
      output[i] = conv(output[i],out_channels,3,1,'same',False,False)
      output[i] = keras.layers.UpSampling2D(2**(n_levels-i-1), interpolation='bilinear')(output[i])
      result.append(output[i])
  else:
    result.append(conv(output[n_levels-1],out_channels,3,1,'same',False,False))

  for i in range(len(result)):
    if out_channels == 1:
      result[i] = keras.activations.sigmoid(result[i])
    else:
      result[i] = keras.activations.softmax(result[i])


  #이름 설정
  output_name=f'UNET3-L{n_levels}-F{initial_features}'
  if DSV:
    output_name+='-DSV'

  return keras.Model(inputs=[inputs], outputs=result, name=output_name)


model = unet3(5,True)

PyTorch

import torch
import torch.nn as nn
from torchsummary import summary

class UNet3(nn.Module):
    def __init__(self,DSV=True):
        super(UNet3,self).__init__()

        self.DSV = DSV
        def cbr(in_channels, out_channels, kernel_size = 3, stride = 1):
            layers = nn.Sequential(
                nn.Conv2d(in_channels,out_channels,kernel_size,stride,1),
                #nn.BatchNorm2d(out_channels),
                nn.ReLU(),

                nn.Conv2d(out_channels,out_channels,kernel_size,stride,1),
                #nn.BatchNorm2d(out_channels),
                nn.ReLU()
            )
            return layers
        
        self.h = nn.Sequential(
            nn.Conv2d(320,320,3,1,1),
            nn.BatchNorm2d(320),
            nn.ReLU(),
            
            nn.Conv2d(320,320,3,1,1),
            nn.BatchNorm2d(320),
            nn.ReLU()
        )

        
        self.enc0 = cbr(1,64)
        self.enc1 = cbr(64,128)
        self.enc2 = cbr(128,256)
        self.enc3 = cbr(256,512)
        self.enc4 = cbr(512,1024)

        self.pool = nn.MaxPool2d(2,2)
        
        self.enc0_dec0 = nn.Sequential(
            nn.Conv2d(64,64,3,1,1)
        )
        self.enc0_dec1 = nn.Sequential(
            nn.MaxPool2d(2,2),
            nn.Conv2d(64,64,3,1,1)
        )
        self.enc0_dec2 = nn.Sequential(
            nn.MaxPool2d(4,4),
            nn.Conv2d(64,64,3,1,1)
        )
        self.enc0_dec3 = nn.Sequential(
            nn.MaxPool2d(8,8),
            nn.Conv2d(64,64,3,1,1)
        )

        self.enc1_dec1 = nn.Sequential(
            nn.Conv2d(128,64,3,1,1)
        )
        self.enc1_dec2 = nn.Sequential(
            nn.MaxPool2d(2,2),
            nn.Conv2d(128,64,3,1,1)
        )
        self.enc1_dec3 = nn.Sequential(
            nn.MaxPool2d(4,4),
            nn.Conv2d(128,64,3,1,1)
        )

        self.enc2_dec2 = nn.Sequential(
            nn.Conv2d(256,64,3,1,1)
        )
        self.enc2_dec3 = nn.Sequential(
            nn.MaxPool2d(2,2),
            nn.Conv2d(256,64,3,1,1)
        )

        self.enc3_dec3 = nn.Sequential(
            nn.Conv2d(512,64,3,1,1)
        )


        self.dec_up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(320,64,3,1,1)
        )

        self.dec_up4 = nn.Sequential(
            nn.Upsample(scale_factor=4, mode='bilinear'),
            nn.Conv2d(320,64,3,1,1)
        )

        self.dec_up8 = nn.Sequential(
            nn.Upsample(scale_factor=8, mode='bilinear'),
            nn.Conv2d(320,64,3,1,1)
        )

        self.enc4_up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(1024,64,3,1,1)
        )

        self.enc4_up4 = nn.Sequential(
            nn.Upsample(scale_factor=4, mode='bilinear'),
            nn.Conv2d(1024,64,3,1,1)
        )

        self.enc4_up8 = nn.Sequential(
            nn.Upsample(scale_factor=8, mode='bilinear'),
            nn.Conv2d(1024,64,3,1,1)
        )

        self.enc4_up16 = nn.Sequential(
            nn.Upsample(scale_factor=16, mode='bilinear'),
            nn.Conv2d(1024,64,3,1,1)
        )
        
        self.result1 = nn.Sequential(
            nn.Conv2d(320,1,3,1,1),
            nn.Sigmoid()
        )
        self.result2 = nn.Sequential(
            nn.Conv2d(1024,1,3,1,1),
            nn.Sigmoid()
        )


        
      
    def forward(self,x):
        
        enc0 = self.enc0(x)
        pool1 = self.pool(enc0)

        enc1 = self.enc1(pool1)
        pool2 = self.pool(enc1)

        enc2 = self.enc2(pool2)
        pool3 = self.pool(enc2)

        enc3 = self.enc3(pool3)
        pool4 = self.pool(enc3)

        enc4 = self.enc4(pool4)

        dec3 = self.h(torch.cat([self.enc0_dec3(enc0),self.enc1_dec3(enc1),self.enc2_dec3(enc2),self.enc3_dec3(enc3),self.enc4_up2(enc4)],1))
        dec2 = self.h(torch.cat([self.enc0_dec2(enc0),self.enc1_dec2(enc1),self.enc2_dec2(enc2),self.dec_up2(dec3),self.enc4_up4(enc4)],1))
        dec1 = self.h(torch.cat([self.enc0_dec1(enc0),self.enc1_dec1(enc1),self.dec_up2(dec2),self.dec_up4(dec3),self.enc4_up8(enc4)],1))
        dec0 = self.h(torch.cat([self.enc0_dec0(enc0),self.dec_up2(dec1),self.dec_up4(dec2),self.dec_up8(dec3),self.enc4_up16(enc4)],1))

        out = []
        if self.DSV:
            out.append(self.result1(dec0))
            out.append(self.result1(dec1))
            out.append(self.result1(dec2))
            out.append(self.result1(dec3))
            out.append(self.result2(enc4))
        else:
            out.append(self.result1(dec0))

        return out 

summary(UNet3(),(1,400,400))

profile
👋 인공지능을 통해 다음 세대가 더 나은 삶을 살도록

0개의 댓글