In this post we will implement a GAN (Generative Adversarial Network) in PyTorch.
We will use the MNIST dataset and generate, starting from random noise, samples that closely resemble the MNIST training data.
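For reference, the underlying training objective is the standard GAN minimax game between the generator G and the discriminator D:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

where G maps a noise vector z to a fake sample and D outputs the probability that its input is real. The code below approximates this objective with binary cross-entropy losses.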
import random
import torch
import torch.nn as nn
import torch.utils.data
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import os.path as osp
%matplotlib inline
Import the required packages.
PATH = "./MNIST_dataset"
batch_size = 256
z_size = 100
epochs = 500
learning_rate = 0.001
# Beta1 hyperparameter (for Adam)
beta1 = 0.5
real_label = 1
fake_label = 0
Next, set the hyperparameters and other variables we will need.
train_dataset = dset.MNIST(root=PATH,
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)

data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)
Download the MNIST dataset to PATH and load it with a DataLoader.
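As a quick optional sanity check, one batch from the loader can be inspected; with batch_size=256 and ToTensor, the images come out as 256 x 1 x 28 x 28:

# Optional: inspect a single batch from the DataLoader.
images, labels = next(iter(data_loader))
print(images.shape)  # torch.Size([256, 1, 28, 28])
print(labels.shape)  # torch.Size([256])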
Now we implement the Generator and the Discriminator.
The Generator takes a latent vector z (size 100) as input; a ReLU activation is applied after the hidden layer and a sigmoid activation after the output layer.
The Discriminator takes either real or fake data as input; again a ReLU activation follows the hidden layer and a sigmoid activation follows the output layer.
The weights of both the Generator and the Discriminator are initialized with Xavier initialization.
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.g_fc1 = nn.Linear(in_features=100, out_features=128, bias=True)
        self.relu = nn.ReLU(inplace=True)  # inplace=True saves a little memory, but modifies its input tensor in place
        self.g_fc2 = nn.Linear(in_features=128, out_features=784, bias=True)
        self.sigmoid = nn.Sigmoid()

        # Initialize weight parameters (Xavier initialization)
        nn.init.xavier_uniform_(self.g_fc1.weight, gain=1.0)
        nn.init.xavier_uniform_(self.g_fc2.weight, gain=1.0)

    def forward(self, x):
        x = self.g_fc1(x)
        x = self.relu(x)
        x = self.g_fc2(x)
        x = self.sigmoid(x)
        return x
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.d_fc1 = nn.Linear(in_features=784, out_features=128, bias=True)
        self.relu = nn.ReLU(inplace=True)  # inplace=True saves a little memory, but modifies its input tensor in place
        self.d_fc2 = nn.Linear(in_features=128, out_features=1, bias=True)
        self.sigmoid = nn.Sigmoid()

        # Initialize weight parameters (Xavier initialization)
        nn.init.xavier_uniform_(self.d_fc1.weight, gain=1.0)
        nn.init.xavier_uniform_(self.d_fc2.weight, gain=1.0)

    def forward(self, x):
        x = self.d_fc1(x)
        x = self.relu(x)
        x = self.d_fc2(x)
        x = self.sigmoid(x)
        return x
model_G = Generator()
model_D = Discriminator()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device.type)
model_G.to(device)
model_D.to(device)
print(model_G)
print(model_D)
Device: cuda
------------------------------------------------------------------------------------------------------------
Generator(
  (g_fc1): Linear(in_features=100, out_features=128, bias=True)
  (relu): ReLU(inplace=True)
  (g_fc2): Linear(in_features=128, out_features=784, bias=True)
  (sigmoid): Sigmoid()
)
Discriminator(
  (d_fc1): Linear(in_features=784, out_features=128, bias=True)
  (relu): ReLU(inplace=True)
  (d_fc2): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)
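Since the printout above only shows the layer shapes, a quick optional check of the trainable parameter counts confirms how small these two MLPs are:

# Optional: count trainable parameters of each model.
n_params_G = sum(p.numel() for p in model_G.parameters() if p.requires_grad)
n_params_D = sum(p.numel() for p in model_D.parameters() if p.requires_grad)
print(n_params_G, n_params_D)  # 114064 (Generator), 100609 (Discriminator)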
criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(model_G.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizer_D = torch.optim.Adam(model_D.parameters(), lr=learning_rate, betas=(beta1, 0.999))
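With BCELoss and the real/fake label tensors defined further below, each discriminator step in the training loop minimizes

$$L_D = -\big[\log D(x) + \log\big(1 - D(G(z))\big)\big],$$

and each generator step minimizes the non-saturating loss

$$L_G = -\log D\big(G(z)\big),$$

which is the standard way the GAN objective is implemented in practice.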
First, implement a helper function to visualize the generated samples.
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig
Then build the training loop, with a periodic evaluation on a fixed noise batch, and run training.
# label for real & fake data
# batch_size 단위로 처리하기 위해 torch.full 기능 사용
label_real = torch.full((batch_size,), real_label, device=device, dtype=torch.float)
label_fake = torch.full((batch_size,), fake_label, device=device, dtype=torch.float)
fixed_noise = torch.randn(batch_size, z_size, device=device, dtype=torch.float)
for epoch in range(epochs):
model_G.train()
model_D.train()
for i, data in enumerate(data_loader):
data = data[0].to(device)
data = data.view(batch_size, -1)
# 노이즈 -> 가짜 데이터 생성(반복생성을 통해 오버피팅 방지)
noise = torch.randn(batch_size, z_size, device=device, dtype=torch.float)
fake_data = model_G(noise)
# Discriminator 학습
model_D.zero_grad()
output_real = model_D(data).view(-1)
Loss_D_real = criterion(output_real, label_real)
Loss_D_real.backward()
output_fake = model_D(fake_data.detach()).view(-1)
Loss_D_fake = criterion(output_fake, label_fake)
Loss_D_fake.backward()
Loss_D = Loss_D_real + Loss_D_fake
optimizer_D.step()
# Generator 학습
model_G.zero_grad()
output = model_D(fake_data).view(-1)
Loss_G = criterion(output, label_real)
Loss_G.backward()
optimizer_G.step()
# Output training stats
if (epoch+1) % 10 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
% ((epoch+1), epochs, i+1, len(data_loader), Loss_D.item(), Loss_G.item()))
if (epoch) % 50 == 0:
model_G.eval()
model_D.eval()
output = model_G(fixed_noise).detach().cpu().numpy()
fig = plot(output[:16])
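The loop above only builds the figures and relies on the inline backend to render them. To also keep the intermediate samples on disk, a small helper along these lines could be used (the function name and the output directory are illustrative, not part of the original code):

def save_samples(fig, epoch, out_dir="output"):
    # Hypothetical helper: save a figure returned by plot() and free its memory.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(osp.join(out_dir, "epoch_%03d.png" % epoch), bbox_inches="tight")
    plt.close(fig)

It could be called as save_samples(fig, epoch) right after fig = plot(output[:16]).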
The results are as follows.
[10/500][234/234] Loss_D: 0.4954 Loss_G: 2.0873
[20/500][234/234] Loss_D: 0.1655 Loss_G: 4.0228
[30/500][234/234] Loss_D: 0.1412 Loss_G: 4.1984
[40/500][234/234] Loss_D: 0.0683 Loss_G: 6.2528
[50/500][234/234] Loss_D: 1.0737 Loss_G: 2.0238
[60/500][234/234] Loss_D: 0.9343 Loss_G: 2.6887
[70/500][234/234] Loss_D: 0.8270 Loss_G: 1.4034
[80/500][234/234] Loss_D: 0.7454 Loss_G: 1.7409
[90/500][234/234] Loss_D: 0.7409 Loss_G: 2.0006
[100/500][234/234] Loss_D: 0.7229 Loss_G: 1.9061
[110/500][234/234] Loss_D: 0.7785 Loss_G: 1.9361
[120/500][234/234] Loss_D: 0.8553 Loss_G: 1.6024
[130/500][234/234] Loss_D: 0.7067 Loss_G: 1.6570
[140/500][234/234] Loss_D: 0.7467 Loss_G: 1.7937
[150/500][234/234] Loss_D: 0.8024 Loss_G: 2.6241
[160/500][234/234] Loss_D: 0.7929 Loss_G: 1.6627
[170/500][234/234] Loss_D: 1.0752 Loss_G: 1.1041
[180/500][234/234] Loss_D: 0.8622 Loss_G: 2.0550
[190/500][234/234] Loss_D: 0.6993 Loss_G: 1.9360
[200/500][234/234] Loss_D: 0.8738 Loss_G: 1.4923
[210/500][234/234] Loss_D: 0.8167 Loss_G: 1.7741
[220/500][234/234] Loss_D: 0.7779 Loss_G: 2.4709
[230/500][234/234] Loss_D: 0.7358 Loss_G: 2.4565
[240/500][234/234] Loss_D: 0.7585 Loss_G: 2.1113
[250/500][234/234] Loss_D: 0.6703 Loss_G: 2.0536
[260/500][234/234] Loss_D: 0.7598 Loss_G: 1.5473
[270/500][234/234] Loss_D: 0.6970 Loss_G: 1.9866
[280/500][234/234] Loss_D: 0.6878 Loss_G: 1.9665
[290/500][234/234] Loss_D: 0.7064 Loss_G: 2.1383
[300/500][234/234] Loss_D: 0.7916 Loss_G: 1.6122
[310/500][234/234] Loss_D: 0.6098 Loss_G: 1.9801
[320/500][234/234] Loss_D: 0.8324 Loss_G: 1.8967
[330/500][234/234] Loss_D: 0.6279 Loss_G: 2.1740
[340/500][234/234] Loss_D: 0.6652 Loss_G: 1.8633
[350/500][234/234] Loss_D: 0.5618 Loss_G: 2.4228
[360/500][234/234] Loss_D: 0.6021 Loss_G: 2.1645
[370/500][234/234] Loss_D: 0.6099 Loss_G: 2.3311
[380/500][234/234] Loss_D: 0.5985 Loss_G: 3.0751
[390/500][234/234] Loss_D: 0.6489 Loss_G: 2.2253
[400/500][234/234] Loss_D: 0.5930 Loss_G: 1.9898
[410/500][234/234] Loss_D: 0.5837 Loss_G: 2.2513
[420/500][234/234] Loss_D: 0.6508 Loss_G: 2.4088
[430/500][234/234] Loss_D: 0.6256 Loss_G: 2.0771
[440/500][234/234] Loss_D: 0.6489 Loss_G: 2.4406
[450/500][234/234] Loss_D: 0.6801 Loss_G: 2.5073
[460/500][234/234] Loss_D: 0.6494 Loss_G: 3.0150
[470/500][234/234] Loss_D: 0.5228 Loss_G: 2.6532
[480/500][234/234] Loss_D: 0.6141 Loss_G: 3.2598
[490/500][234/234] Loss_D: 0.6196 Loss_G: 1.9002
[500/500][234/234] Loss_D: 0.5650 Loss_G: 2.8008

Although the results are not perfect, the model was able to generate data that resembles the MNIST dataset, and the samples become noticeably more refined as training progresses. To generate sharper samples, one could increase the capacity of the GAN or simply train for longer.
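As one possible direction for increasing capacity (a sketch only; the layer sizes are illustrative and were not trained in this post), the generator could be deepened with an extra hidden layer and LeakyReLU activations:

class DeeperGenerator(nn.Module):
    # Illustrative variant of the Generator above with one extra hidden layer.
    def __init__(self, z_dim=100):
        super(DeeperGenerator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 784),
            nn.Sigmoid(),
        )

    def forward(self, z):
        return self.net(z)

The same kind of change could be applied to the Discriminator.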