PyTorch를 이용해 DCGAN을 구현해 본다.
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
#%matplotlib inline
# Pin the random seeds so repeated runs draw the same random values instead of
# fresh ones every time training is started.
MANUAL_SEED = 1
torch.manual_seed(MANUAL_SEED)
random.seed(MANUAL_SEED)
# Training data: images found under DATA_ROOT are resized and center-cropped
# to IMAGE_SIZE, converted to tensors, and normalized per channel with
# mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5. The generator in this script ends
# in Tanh, so real and generated images share the same value range.
dataset = datasets.ImageFolder(
    root=DATA_ROOT,
    transform=transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=N_WORKERS,
)

# Train on the GPU only when CUDA is available AND the configuration requests
# at least one GPU; otherwise fall back to the CPU.
device = torch.device("cuda" if (torch.cuda.is_available() and N_GPU > 0) else "cpu")
# Report the device that was actually selected. Printing
# torch.cuda.is_available() here would claim True even when N_GPU == 0 forces
# the CPU.
print("device: ", device)
# Sanity-check the data: plot up to 64 images from one training batch.
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# The grid is only plotted, so it is built directly from the batch as loaded;
# moving the whole batch to `device` and straight back to the CPU adds nothing.
image = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True).cpu()
# make_grid returns (C, H, W); imshow expects (H, W, C).
image = np.transpose(image, (1, 2, 0))
plt.imshow(image)
plt.show()
def init_weight(m):
    """Initialize a single module in place with this script's DCGAN scheme.

    Meant to be passed to ``Module.apply``, which calls it on every submodule.
    Modules are matched by class name:

    * name contains ``"Conv"``: weight ~ Normal(mean=0.0, std=0.02)
    * name contains ``"BatchNorm"``: weight ~ Normal(mean=1.0, std=0.02),
      bias set to 0

    Any other module is left untouched. A matched module whose ``weight`` or
    ``bias`` is missing or ``None`` (e.g. ``BatchNorm2d(affine=False)``, or a
    container class that merely has "Conv" in its name) is skipped for that
    parameter instead of raising ``AttributeError``.

    Args:
        m: The module to initialize.
    """
    classname = m.__class__.__name__
    if "Conv" in classname:
        if getattr(m, "weight", None) is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if getattr(m, "weight", None) is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, "bias", None) is not None:
            nn.init.constant_(m.bias.data, 0)
class Generator(nn.Module):
    """DCGAN generator: maps latent vectors to images through transposed convs.

    The network is four ``ConvTranspose2d -> BatchNorm2d -> ReLU`` stages
    followed by a final ``ConvTranspose2d -> Tanh``. The first stage uses
    kernel 4 / stride 1 / padding 0 and every later transposed conv uses
    kernel 4 / stride 2 / padding 1, so a ``(N, LATENT_VECTOR_SIZE, 1, 1)``
    input (the noise shape used in this script) grows 1 -> 4 -> 8 -> 16 ->
    32 -> 64 pixels per side. Tanh bounds the output to [-1, 1].

    Args:
        ngpu: Number of GPUs; stored on the instance but not used by the
            layers themselves.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        width = GENERATOR_FEATURE_MAP_SIZE
        # (in_channels, out_channels, stride, padding) of each normalized stage.
        stages = (
            (LATENT_VECTOR_SIZE, width * 8, 1, 0),
            (width * 8, width * 4, 2, 1),
            (width * 4, width * 2, 2, 1),
            (width * 2, width, 2, 1),
        )

        layers = []
        for in_channels, out_channels, stride, padding in stages:
            layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))

        # Output stage: project to the image channel count and squash with Tanh.
        layers.append(
            nn.ConvTranspose2d(width, TRAIN_IMAGE_CHANNEL_SIZE, 4, 2, 1, bias=False)
        )
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, data):
        """Return the images generated from the latent batch ``data``."""
        return self.main(data)
# Build the generator and place it on the selected device.
generator = Generator(N_GPU)
generator = generator.to(device)
# Multi-GPU setup: when running on CUDA with more than one GPU requested,
# wrap the model in DataParallel over GPUs 0 .. N_GPU - 1.
if device.type == 'cuda' and N_GPU > 1:
    generator = nn.DataParallel(generator, device_ids=list(range(N_GPU)))
# Run init_weight on every submodule.
generator.apply(init_weight)
class Discriminator(nn.Module):
    """DCGAN discriminator: scores images with a probability-like value.

    The network starts with ``Conv2d -> LeakyReLU(0.2)``, continues with three
    ``Conv2d -> BatchNorm2d -> LeakyReLU(0.2)`` stages that double the channel
    count each time, and ends with ``Conv2d -> Sigmoid`` producing a single
    channel. The first four convolutions use kernel 4 / stride 2 / padding 1
    (each halves the spatial size); the last uses kernel 4 / stride 1 /
    padding 0.
    NOTE(review): the final conv yields a 1x1 map only for 64x64 inputs --
    confirm IMAGE_SIZE matches.

    Args:
        ngpu: Number of GPUs; stored on the instance but not used by the
            layers themselves.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu

        base = DISCRIMINATOR_FEATURE_MAP_SIZE
        # Channel counts after each of the four down-sampling convolutions.
        widths = (base, base * 2, base * 4, base * 8)

        # First stage has no batch normalization.
        layers = [
            nn.Conv2d(TRAIN_IMAGE_CHANNEL_SIZE, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for in_channels, out_channels in zip(widths, widths[1:]):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Output stage: collapse to one channel and squash into (0, 1).
        layers.extend([
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, data):
        """Return the discriminator scores for the image batch ``data``."""
        return self.main(data)
# Build the discriminator and place it on the selected device.
discriminator = Discriminator(N_GPU)
discriminator = discriminator.to(device)
# Multi-GPU setup: when running on CUDA with more than one GPU requested,
# wrap the model in DataParallel over GPUs 0 .. N_GPU - 1.
if device.type == 'cuda' and N_GPU > 1:
    discriminator = nn.DataParallel(discriminator, device_ids=list(range(N_GPU)))
# Run init_weight on every submodule.
discriminator.apply(init_weight)
# Target values the discriminator should output for real / generated images.
REAL_LABEL, FAKE_LABEL = 1., 0.

# Binary cross-entropy between the discriminator output and those targets.
criterion = nn.BCELoss()

# A fixed batch of 64 latent vectors, reused after every epoch so the
# generator's progress can be visualized on identical inputs.
fixed_noise = torch.randn(64, LATENT_VECTOR_SIZE, 1, 1, device=device)

# One Adam optimizer per network, both with the same hyper-parameters.
adam_kwargs = dict(lr=LR, betas=(BETA1, 0.999))
optimizerD = optim.Adam(discriminator.parameters(), **adam_kwargs)
optimizerG = optim.Adam(generator.parameters(), **adam_kwargs)
# Training: for every batch do one discriminator update followed by one
# generator update; after every epoch plot what the generator produces for
# fixed_noise.
for epoch in range(N_EPOCHS):
    for data in tqdm(dataloader):
        # --- Discriminator update -------------------------------------
        # Real images should be scored as REAL_LABEL.
        discriminator.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), REAL_LABEL, dtype=torch.float, device=device)
        output_real = discriminator(real_cpu).view(-1)
        loss_real_d = criterion(output_real, label)
        loss_real_d.backward()

        # Generated images should be scored as FAKE_LABEL. The generator is
        # run once per iteration; detach() keeps this loss from
        # back-propagating into the generator, while `fake` keeps its graph
        # for the generator update below.
        noise = torch.randn(batch_size, LATENT_VECTOR_SIZE, 1, 1, device=device)
        fake = generator(noise)
        label.fill_(FAKE_LABEL)
        output_fake = discriminator(fake.detach()).view(-1)
        loss_fake_d = criterion(output_fake, label)
        loss_fake_d.backward()
        # Gradients of the real and fake losses have accumulated; step once.
        optimizerD.step()

        # --- Generator update -----------------------------------------
        # The generator is rewarded when the just-updated discriminator
        # scores its images as REAL_LABEL.
        generator.zero_grad()
        output = discriminator(fake).view(-1)
        label.fill_(REAL_LABEL)
        loss_g = criterion(output, label)
        loss_g.backward()
        optimizerG.step()

    # --- Per-epoch preview of generated images ------------------------
    # Nothing is back-propagated here, so skip building the autograd graph.
    with torch.no_grad():
        fake = generator(fixed_noise).cpu()
    # A single figure: creating a second one would leave an empty titled
    # figure behind and draw the image without the title / axis settings.
    plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.title("Generated Images (epoch {}/{})".format(epoch + 1, N_EPOCHS))
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    image = np.transpose(vutils.make_grid(fake[:64], padding=2, normalize=True), (1, 2, 0))
    plt.imshow(image)
    plt.show()