
오랜만에 돌아온 모델 제작 시간이다.
굉장히 심심해서 U-Net 모델 형태를 코드로 제작해보는 시간을 가졌다.
(sigmoid, softmax 버전으로 2개)
U-Net은 이미지 내 각 픽셀이 어떤 클래스에 속하는지를 예측하는 모델이다. 예를 들어, 흑백 위성 이미지에서 도로가 있는 픽셀과 배경 픽셀을 분리하는 작업에 사용된다.
U-Net은 크게 인코더(Encoder) 경로와 디코더(Decoder) 경로로 구성되며, 스킵 연결(Skip Connections)이 존재한다.
| 이름 | 값 예시 |
|---|---|
| 입력 크기 | (1, 572, 572) |
| 인코더 깊이 | 4~5단계 |
| 필터 수 | 64, 128, 256, 512… |
| 커널 크기 | 3x3 |
| 업샘플링 방법 | ConvTranspose2d |
| 출력 채널 수 | 클래스 수 (예: 1) |
# U-Net(sigmoid)
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
def __init__(self, in_channels=3, out_channels=1):
super(UNet, self).__init__()
self.encoder1 = self.conv_block(in_channels, 64)
self.pool1 = nn.MaxPool2d(2)
self.encoder2 = self.conv_block(64, 128)
self.pool2 = nn.MaxPool2d(2)
self.encoder3 = self.conv_block(128, 256)
self.pool3 = nn.MaxPool2d(2)
self.encoder4 = self.conv_block(256, 512)
self.pool4 = nn.MaxPool2d(2)
self.bottleneck = self.conv_block(512, 1024)
self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, 2)
self.decoder4 = self.conv_block(1024, 512)
self.upconv3 = nn.ConvTranspose2d(512, 256, 2, 2)
self.decoder3 = self.conv_block(512, 256)
self.upconv2 = nn.ConvTranspose2d(256, 128, 2, 2)
self.decoder2 = self.conv_block(256, 128)
self.upconv1 = nn.ConvTranspose2d(128, 64, 2, 2)
self.decoder1 = self.conv_block(128, 64)
self.conv_last = nn.Conv2d(64, out_channels, 1)
def conv_block(self, in_channels, out_channels):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, 3, padding=1),
nn.ReLU(inplace=True)
)
def forward(self, x):
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
enc4 = self.encoder4(self.pool3(enc3))
bottleneck = self.bottleneck(self.pool4(enc4))
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
return torch.sigmoid(self.conv_last(dec1))
# U-Net(softmax)
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNetMultiClass(nn.Module):
def __init__(self, in_channels=3, num_classes=4):
super(UNetMultiClass, self).__init__()
def conv_block(in_ch, out_ch):
return nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
self.enc1 = conv_block(in_channels, 64)
self.enc2 = conv_block(64, 128)
self.enc3 = conv_block(128, 256)
self.enc4 = conv_block(256, 512)
self.pool = nn.MaxPool2d(2)
self.bottleneck = conv_block(512, 1024)
self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
self.dec4 = conv_block(1024, 512)
self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
self.dec3 = conv_block(512, 256)
self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
self.dec2 = conv_block(256, 128)
self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
self.dec1 = conv_block(128, 64)
self.conv_last = nn.Conv2d(64, num_classes, 1)
def forward(self, x):
enc1 = self.enc1(x)
enc2 = self.enc2(self.pool(enc1))
enc3 = self.enc3(self.pool(enc2))
enc4 = self.enc4(self.pool(enc3))
bottleneck = self.bottleneck(self.pool(enc4))
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, enc4), dim=1)
dec4 = self.dec4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, enc3), dim=1)
dec3 = self.dec3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, enc2), dim=1)
dec2 = self.dec2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.dec1(dec1)
out = self.conv_last(dec1) # shape: [B, num_classes, H, W]
return out # CrossEntropyLoss 사용 시 softmax는 생략
model = UNetMultiClass(in_channels=3, num_classes=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
다음 시간에는 이 두 모델에 직접 데이터를 넣고 진행하는 과정을 적어보도록 하겠다.