입력 X ─▶ [Encoder] ─▶ 잠재 표현 Z ─▶ [Decoder] ─▶ 복원된 X'
Latent Space (잠재 공간): 핵심 특징만 담은 벡터 (차원이 작음)
입력 X
↓
[Encoder]
- Linear(784 → 256)
- ReLU
- Linear(256 → 64)
↓
Latent Vector (64차원)
↓
[Decoder]
- Linear(64 → 256)
- ReLU
- Linear(256 → 784)
- Sigmoid
↓
출력 X' (복원된 입력)
복원된 X'와 원래 입력 X의 차이를 줄이는 게 목표.
대표적인 손실 함수: MSE (Mean Squared Error)
"""
[입력 이미지: 28x28]
→ 인코더: Conv + ReLU + MaxPool
→ 잠재 표현: 압축된 Feature Map
→ 디코더: ConvTranspose + ReLU
→ 출력 이미지: 28x28 복원
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# 장치 설정 (GPU 사용 가능 시 GPU 사용)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 하이퍼파라미터
epochs = 5
batch_size = 128
learning_rate = 1e-3
# MNIST 데이터셋 불러오기 (흑백 이미지 28x28)
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# CNN Autoencoder 클래스 정의
class CNNAutoencoder(nn.Module):
def __init__(self):
super(CNNAutoencoder, self).__init__()
# 인코더: Conv → ReLU → MaxPool
self.encoder = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3, padding=1), # 28x28 → 16x28x28
nn.ReLU(),
nn.MaxPool2d(2, 2), # → 16x14x14
nn.Conv2d(16, 8, kernel_size=3, padding=1), # → 8x14x14
nn.ReLU(),
nn.MaxPool2d(2, 2) # → 8x7x7
)
# 디코더: ConvTranspose → ReLU → 복원
self.decoder = nn.Sequential(
nn.ConvTranspose2d(8, 16, kernel_size=2, stride=2), # → 16x14x14
nn.ReLU(),
nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2), # → 1x28x28
nn.Sigmoid() # 픽셀 값 0~1로 맞춤
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
# 모델, 손실 함수, 옵티마이저 정의
model = CNNAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 학습 루프
for epoch in range(epochs):
for data, _ in train_loader:
data = data.to(device)
# 순전파
output = model(data)
loss = criterion(output, data)
# 역전파
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
# 테스트용 이미지 시각화
with torch.no_grad():
sample = next(iter(train_loader))[0][:6].to(device) # 6장 샘플 이미지
reconstructed = model(sample).cpu()
# 결과 시각화
plt.figure(figsize=(9,2))
for i in range(6):
# 원본
plt.subplot(2,6,i+1)
plt.imshow(sample[i].cpu().squeeze(), cmap='gray')
plt.title("Original")
plt.axis('off')
# 복원
plt.subplot(2,6,i+7)
plt.imshow(reconstructed[i].squeeze(), cmap='gray')
plt.title("Reconstructed")
plt.axis('off')
plt.tight_layout()
plt.show()