U-Net(sigmoid, softmax)

seongyun·2025년 5월 12일

Neural Network

목록 보기
5/8

오랜만에 돌아온 모델 제작 시간이다.
굉장히 심심해서 U-Net 모델 형태를 코드로 제작해보는 시간을 가졌다.
(sigmoid, softmax 버전으로 2개)

U-Net이란?

U-Net은 이미지 내 각 픽셀이 어떤 클래스에 속하는지를 예측하는 모델이다. 예를 들어, 흑백 위성 이미지에서 도로가 있는 픽셀과 배경 픽셀을 분리하는 작업에 사용된다.

네트워크 구조 (U자 형태)

U-Net은 크게 인코더(Encoder) 경로와 디코더(Decoder) 경로로 구성되며, 스킵 연결(Skip Connections)이 존재한다.

(1) 인코더 – Contracting Path

  • 목적: 이미지의 특징(feature)을 추출하고 점점 공간 해상도를 줄여 나갑니다.
  • 구성:
    -Conv 3x3 + ReLU
    -Conv 3x3 + ReLU
    -MaxPooling 2x2
    -이 과정을 반복하면서, 이미지의 너비·높이는 줄어들고 채널 수는 늘어납니다.
  • 예: 572×572×1 → 284×284×64 → … → 28×28×1024

(2) 디코더 – Expanding Path

  • 목적: 인코더에서 추출한 특징을 사용하여 원래 크기의 세그멘테이션 맵을 재구성합니다.
  • 구성:
    -Up-convolution (Transpose Conv 또는 Upsampling)
    -Skip Connection (같은 레벨의 인코더 출력과 concat)
    -Conv 3x3 + ReLU
    -Conv 3x3 + ReLU
  • 예: 28×28×1024 → 56×56×512 → … → 388×388×64

(3) 스킵 연결

  • 인코더의 feature map을 디코더로 그대로 연결(concat)하여 공간적 정보 손실을 보완합니다.
  • 이는 정확한 경계선을 복원하는 데 매우 중요합니다.

마지막 출력층

  • Sigmoid를 사용하면 이진 분할 (예: 도로 vs 비도로)
  • Softmax를 사용하면 다중 클래스 분할 (예: 건물, 도로, 식생 등)
  • 출력 채널 수는 분할하려는 클래스 수에 따라 결정됩니다.

손실 함수 (Loss Function)

  • Binary segmentation → Binary Cross Entropy (BCE)
  • Multi-class segmentation → Categorical Cross Entropy (Softmax)
  • Intersection over Union (IoU), Dice Loss 등도 자주 사용

U-Net의 핵심 장점

  • 적은 수의 학습 이미지로도 좋은 성능을 냄 (데이터 증강 필수)
  • 경계선이 명확한 객체의 분할에 강함
  • End-to-End 학습 가능
  • Skip connection 덕분에 작은 객체도 잘 잡아냄

U-Net의 주요 파라미터 예시

이름값 예시
입력 크기(1, 572, 572)
인코더 깊이4~5단계
필터 수64, 128, 256, 512…
커널 크기3x3
업샘플링 방법ConvTranspose2d
출력 채널 수클래스 수 (예: 1)

U-Net 응용 분야

  • 의료 영상 (CT, MRI) 내 장기/종양 분할
  • 위성/항공 이미지 내 도로, 강, 건물 탐지
  • 농업 (작물 및 병변 식별)
  • 로보틱스 시각 처리
  • 자율주행 (도로 차선, 보행자 분할)

데이터를 넣지 않은 코드 2개

# U-Net(sigmoid)
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()
        self.encoder1 = self.conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = self.conv_block(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = self.conv_block(512, 1024)

        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.decoder4 = self.conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.decoder3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.decoder2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.decoder1 = self.conv_block(128, 64)

        self.conv_last = nn.Conv2d(64, out_channels, 1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        return torch.sigmoid(self.conv_last(dec1))
# U-Net(softmax)
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetMultiClass(nn.Module):
    def __init__(self, in_channels=3, num_classes=4):
        super(UNetMultiClass, self).__init__()
        
        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True)
            )
        
        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)
        self.enc4 = conv_block(256, 512)
        
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(512, 1024)
        
        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = conv_block(128, 64)
        
        self.conv_last = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        
        bottleneck = self.bottleneck(self.pool(enc4))
        
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.dec4(dec4)
        
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.dec3(dec3)
        
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.dec2(dec2)
        
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.dec1(dec1)
        
        out = self.conv_last(dec1)  # shape: [B, num_classes, H, W]
        return out  # CrossEntropyLoss 사용 시 softmax는 생략

model = UNetMultiClass(in_channels=3, num_classes=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

다음 시간에는 이 두 모델에 직접 데이터를 넣고 진행하는 과정을 적어보도록 하겠다.

0개의 댓글