생성적 적대 신경망(GAN)은 재미있는 아이디어. 얼굴 변환, 생성, 음성 변조, 그림 스타일 변환, 사진 복원 등 다양한 도메인에 응용이 되는 중 !
G: 생성적 Generative, 이미지를 만들어내는
A: 적대 Adversarial, 적대적인
N: 신경망 Network, 뉴럴 네트워크!
GAN은 크게 generator와 discriminator로 구성되어 있음 (범인과 경찰)
Generative 모델(범인): 실제 이미지의 분포 (data distribution)을 알기 위해 노력
Discriminator 모델(경찰): 현재 자기가 보고 있는 샘플이 real인지 fake인지 구별!
결과적으로 P(generator)=P(real)가 되면 구별 못하는 상황이 옴 (D(x) =1/2 가 되는것!)
(도둑이 만든 지폐가 실제와 같아지면 구별 못함 -> 반반 확률로 구분 -> D(x) = 1/2)
GAN은 수식적으로 이해하는 것이 더 직관적임. GAN의 가치함수 V는 아래와 같음 (결국 minmax 문제)
두 가지 경우를 가정해보자 (범인이 완벽한 위조 지폐를 만들기 / 경찰이 완벽하게 구별하기)
(나만 직관적인 ㅠ)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import Variable
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import imageio
import numpy as np
from matplotlib import pyplot as plt
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,),(0.5,))
])
to_image = transforms.ToPILImage()
trainset = FashionMNIST(root='./data/', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=100, shuffle=True)
# Class 종류는 다음과 같음
# 'T-Shirt','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot'
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.n_features = 128
self.n_out = 784
self.linear = nn.Sequential(
nn.Linear(self.n_features, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, self.n_out),
nn.Tanh()
)
def forward(self, x):
x = self.linear(x)
x = x.view(-1, 1, 28, 28)# 이미지 한개의 차원 (1 x 28 x 28)
return x
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.n_in = 784
self.n_out = 1
self.linear = nn.Sequential(
nn.Linear(self.n_in, 1024),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Dropout(0.3),
nn.Linear(256, self.n_out),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(-1, 784)
x = self.linear(x)
return x
generator = Generator().to(device)
discriminator = Discriminator().to(device)
pretrained = False
if pretrained == True:
discriminator.load_state_dict(torch.load('./models/fmnist_disc.pth'))
generator.load_state_dict(torch.load('./models/fmnist_gner.pth'))
g_optim = optim.Adam(generator.parameters(), lr=2e-4)
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4)
g_losses = []
d_losses = []
images = []
criterion = nn.BCELoss() # binary classification
def noise(n, n_features=128):
return Variable(torch.randn(n, n_features)).to(device)
def label_ones(size):
data = Variable(torch.ones(size, 1))
return data.to(device)
def label_zeros(size):
data = Variable(torch.zeros(size, 1))
return data.to(device)
def train_discriminator(optimizer, real_data, fake_data):
n = real_data.size(0)
optimizer.zero_grad()
prediction_real = discriminator(real_data) # real data 얼마나 잘 구분하는지
d_loss = criterion(prediction_real, label_ones(n)) # real data 구분
prediction_fake = discriminator(fake_data) # fake data 얼마나 잘 구분하는지
g_loss = criterion(prediction_fake, label_zeros(n))
loss = d_loss + g_loss # real/fake data loss 줄여야 함
loss.backward()
optimizer.step()
return loss.item()
def train_generator(optimizer, fake_data):
n = fake_data.size(0)
optimizer.zero_grad()
prediction = discriminator(fake_data) # 구별기가 fake data를 구별한 결과
loss = criterion(prediction, label_ones(n)) # 그 결과가 작아야 함 ! (생성기 입장)
loss.backward()
optimizer.step()
return loss.item()
num_epochs = 201
test_noise = noise(64)
l = len(trainloader)
for epoch in range(num_epochs):
g_loss = 0.0
d_loss = 0.0
for data in trainloader:
imgs, _ = data
n = len(imgs)
fake_data = generator(noise(n)).detach() # tensor 추적 중단 (뒤에서 또 써야됨)
real_data = imgs.to(device)
d_loss += train_discriminator(d_optim, real_data, fake_data)
fake_data = generator(noise(n))
g_loss += train_generator(g_optim, fake_data)
img = generator(test_noise).cpu().detach()
img = make_grid(img)
images.append(img)
g_losses.append(g_loss/l)
d_losses.append(d_loss/l)
if epoch % 10 == 0:
print('Epoch {}: g_loss: {:.3f} d_loss: {:.3f}\r'.format(epoch, g_loss/l, d_loss/l))
torch.save(discriminator.state_dict(), './models/fmnist_disc.pth')
torch.save(generator.state_dict(), './models/fmnist_gner.pth')
print('Training Finished')