[ML/CV]U-Net: Convolutional Networks for Biomedical Image Segmentation

셍셍정보통·2023년 7월 8일

ReviewPaper

목록 보기
1/2

논문 아이디어


저차원 뿐만 아니라 고차원 정보도 이용하여 이미지의 특징을 추출함과 동시에 정확한 위치 파악도 가능하게 하자 → Skip Connection : 인코딩 단계의 각 레이어에서 얻은 특징을 디코딩 단계의 각 레이어야 합치는(concatenation) 방법

주요 특징

  • Encoder(Contracting Path) 와 Decoder (Expanding Path) 두 가지 경로로 구성
    • encoder-decoder 기반 모델
      • 보통 인코딩 단계에서는 입력 이미지의 특징을 포착할 수 있도록 채널의 수를 늘리면서 차원을 축소
        • 차원축소를 거치며 이미지 객체에 대한 자세한 위치 정보 잃게 됨
      • 디코딩 단계에서는 저차원으로 인코딩된 정보만 이용하여 채널의 수를 줄이거 차원을 늘려 고차원 이미지 복원
  • skip connection을 이용하여 더욱 정확한 예측 가능하게 함

장점

  • 픽셀 정확도의 의미론적 세분화 제공
  • 계산 속도 빠르다
  • 구조의 이해가 쉽다

취약점

U-Net스타일 구조의 유일한 취약점이라고 하면 깊은 모델에서의 중간 단계의 레이어에서 학습이 느려질 수 있기 때문에, 대략적인 Feature에 대한 정보가 유실될 수 있다는 점

U-Net Structure 정리



신경망 구조를 스킵 연결을 평행하게 두고 가운데를 기준으로 좌우가 대칭이되도록 레이어를 배치하여 U 형태 → U-Net

  • U-Net은 인코더 또는 축소 경로(Contracting Path)와 디코더 또는 확장경로(Expanding path)로 구성되며 두 구조는 서로 대칭
  • 인코더와 디코더를 연결하는 부분 브릿지(bridge)
  • 모두 3*3 컨볼루션을 사용

Encoder (Contracting Path)

  • Encoder은 전형적인 Convolutional Neural Network(CNN)를 기반으로 함
  • 여러 개의 Convolution Layer와 Max Pooling Layer 로 구성되어 이미지에서 feature를 추출하며 Spatial Information(공간 정보)를 줄임
  • Encoder 경로를 따라 내려갈수록 공간 차원은 줄어들고, 대신 feature의 복잡성과 수는 늘어남

  • 맨아래 브릿지는 두개의 파란색 박스로만 구성되어 있으므로 1개의 ConvBlock 레이어로 표현할 수 있음

ConvBlock

""" Conv Block """
class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(ConvBlock, self).__init__()

        self.conv1 = Conv2D(n_filters, 3, padding='same')
        self.conv2 = Conv2D(n_filters, 3, padding='same')

        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()

        self.activation = Activation('relu')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)

        return x

EncoderBlock

  • 출력 2개
    • U-Net의 디코더로 복사하기 위한 연결선
    • 2*2 max pooling 으로 다운 샘플링(down sampling)하여 인코더의 다음 단계로 내보냄
""" Encoder Block """
class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(EncoderBlock, self).__init__()

        self.conv_blk = ConvBlock(n_filters)
        self.pool = MaxPooling2D((2,2))

    def call(self, inputs):
        x = self.conv_blk(inputs)
        p = self.pool(x)
        return x, p

Decoder(Expanding Path)

  • Encoder의 반대 방향으로 작동
  • Upsampling(or Transposed Convolution)을 이용하여 Spatial Information(공간 정보)을 증가시키면서 동시에 feature 의 복잡성을 줄임
  • Decoder 경로를 따라 올라가면 원래 이미지 크기 복원

세부구조

  • 파란색 블록은 인코더의 ConvBlock과 동일
  • 녹색 박스는 스킵 연결을 통해 인코더에 있는 맵 복사
  • 노란색 박스 디코더의 하위 단계에서 전치 컨볼루션(transposed convolution)을 통해서 맵의 차원을 두배로 늘리면서 채널의 수를 반으로 줄임
  • 두개의 맵을 서로 합쳐(concatenation) 저차원 이미지 정보뿐만 아니라 고차원 정보도 이용

DecoderBlock

위 그림에서 단계별 한 뭉탱이 의미

""" Decoder Block """
class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(DecoderBlock, self).__init__()

        self.up = Conv2DTranspose(n_filters, (2,2), strides=2, padding='same')
        self.conv_blk = ConvBlock(n_filters)

    def call(self, inputs, skip):
        x = self.up(inputs)
        x = Concatenate()([x, skip])
        x = self.conv_blk(x)

        return x

Skip Connection

  • U-Net의 중요한 특징 Encoder - Decoder 사이의 Skip Connection
  • Encoder의 각 Layer 에서 추출된 feature map 을 동일한 레빌의 Decoder Layer에 직접 연결
    • fine-grained details(작은 디테일) 이 Decoder로 전달되어 보다 정확한 segmentation 가능해짐
📌 **U-Net은 이미지의 Contextual 정보와 local 정보를 모두 고려하여 세부적인 예측을 가능하게 하는 아키텍처**

Baseline Code

Chatgpt 버전

Keras를 사용하여 U-Net 구조 구현하는 예시

입력 이미지 256*256 / 1개 채널(→ 그레이스케일)

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, UpSampling2D, concatenate

def create_model(img_height, img_width, num_channels):
    inputs = Input((img_height, img_width, num_channels))
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv3)

    up4 = concatenate([UpSampling2D(size=(2, 2))(conv3), conv2], axis=3)
    conv4 = Conv2D(128, (3, 3), activation='relu', padding='same')(up4)
    conv4 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv4)

    up5 = concatenate([UpSampling2D(size=(2, 2))(conv4), conv1], axis=3)
    conv5 = Conv2D(64, (3, 3), activation='relu', padding='same')(up5)
    conv5 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv5)

    conv6 = Conv2D(1, (1, 1), activation='sigmoid')(conv5)

    model = Model(inputs=[inputs], outputs=[conv6])

    return model

model = create_model(256, 256, 1)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

블로그 버전

""" U-Net Model """
class UNET(tf.keras.Model):
    def __init__(self, n_classes):
        super(UNET, self).__init__()

        # Encoder
        self.e1 = EncoderBlock(64)
        self.e2 = EncoderBlock(128)
        self.e3 = EncoderBlock(256)
        self.e4 = EncoderBlock(512)

        # Bridge
        self.b = ConvBlock(1024)

        # Decoder
        self.d1 = DecoderBlock(512)
        self.d2 = DecoderBlock(256)
        self.d3 = DecoderBlock(128)
        self.d4 = DecoderBlock(64)

        # Outputs
        if n_classes == 1:
            activation = 'sigmoid'
        else:
            activation = 'softmax'

        self.outputs = Conv2D(n_classes, 1, padding='same', activation=activation)

    def call(self, inputs):
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.b(p4)

        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        outputs = self.outputs(d4)

        return outputs
# U-Net model
# coded by st.watermelon

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose
from tensorflow.keras.layers import Activation, BatchNormalization, Concatenate

""" Conv Block """
class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(ConvBlock, self).__init__()

        self.conv1 = Conv2D(n_filters, 3, padding='same')
        self.conv2 = Conv2D(n_filters, 3, padding='same')

        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()

        self.activation = Activation('relu')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activation(x)

        return x

""" Encoder Block """
class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(EncoderBlock, self).__init__()

        self.conv_blk = ConvBlock(n_filters)
        self.pool = MaxPooling2D((2,2))

    def call(self, inputs):
        x = self.conv_blk(inputs)
        p = self.pool(x)
        return x, p

""" Decoder Block """
class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, n_filters):
        super(DecoderBlock, self).__init__()

        self.up = Conv2DTranspose(n_filters, (2,2), strides=2, padding='same')
        self.conv_blk = ConvBlock(n_filters)

    def call(self, inputs, skip):
        x = self.up(inputs)
        x = Concatenate()([x, skip])
        x = self.conv_blk(x)

        return x

""" U-Net Model """
class UNET(tf.keras.Model):
    def __init__(self, n_classes):
        super(UNET, self).__init__()

        # Encoder
        self.e1 = EncoderBlock(64)
        self.e2 = EncoderBlock(128)
        self.e3 = EncoderBlock(256)
        self.e4 = EncoderBlock(512)

        # Bridge
        self.b = ConvBlock(1024)

        # Decoder
        self.d1 = DecoderBlock(512)
        self.d2 = DecoderBlock(256)
        self.d3 = DecoderBlock(128)
        self.d4 = DecoderBlock(64)

        # Outputs
        if n_classes == 1:
            activation = 'sigmoid'
        else:
            activation = 'softmax'

        self.outputs = Conv2D(n_classes, 1, padding='same', activation=activation)

    def call(self, inputs):
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        b = self.b(p4)

        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        outputs = self.outputs(d4)

        return outputs

0개의 댓글