U-Net

박영욱·2025년 4월 27일
post-thumbnail

이번 글에서는 Pytorch-UNet 프로젝트의 unet_model.py 파일을 분석합니다.
unet_parts.py에서 만든 다양한 모듈들을 조립하여 전체 U-Net 모델을 완성하는 부분입니다.

🔹 U-Net이란?

입력 이미지를 점점 압축(Down)했다가
다시 키우면서(Up) 세밀한 정보를 복원하는
"U자 형태"의 세그멘테이션 신경망입니다.
구조는 인코더 → 디코더 순서로 진행됩니다.

""" Full assembly of the parts to form the complete network """

from .unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

🛠️ 클래스 구성

  • 앞 블로그 parts 부분 참고.

2. 인코더(Downsampling)

self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
  • 입력 이미지를 점점 줄이면서 특징을 풍성하게 뽑습니다

3. 가장 깊은 Bottleneck

factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
  • bilinear 업샘플링이면 채널 수를 조정해줍니다.

4. 디코더(Upsampling)

self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
  • 줄였던 크기를 다시 복구하면서 skip connection으로 과거 정보를 합칩니다.

5. 출력층

self.outc = OutConv(64, n_classes)
  • 최종적으로 1x1 Convolution을 통해 채널 수를 클래스 수로 맞춥니다.

🔥 Forward 함수 (데이터 흐름)

def forward(self, x):
    x1 = self.inc(x)
    x2 = self.down1(x1)
    x3 = self.down2(x2)
    x4 = self.down3(x3)
    x5 = self.down4(x4)
    x = self.up1(x5, x4)
    x = self.up2(x, x3)
    x = self.up3(x, x2)
    x = self.up4(x, x1)
    logits = self.outc(x)
    return logits
  • 입력이 inc → down1 → down2 → down3 → down4로 압축되고,
  • up1 → up2 → up3 → up4 → outc로 복원됩니다.

📈 전체 구조 흐름

Input (이미지)
 ↓
DoubleConv (입력 특징 추출)
 ↓
Down 1 (크기 줄이기 + 특징 뽑기)
 ↓
Down 2
 ↓
Down 3
 ↓
Down 4 (가장 깊은 곳)
 ↓
Up 1 (크기 키우기 + skip 연결)
 ↓
Up 2
 ↓
Up 3
 ↓
Up 4
 ↓
OutConv (최종 결과 뽑기)
 ↓
Output (Segmentation Map)

U자형 구조가 만들어진다.

profile
Medical AI

0개의 댓글