GAN(Generative Adversarial Network)은 복잡한 데이터 생성을 가능하게 하는 강력한 딥러닝 기법으로, Ian Goodfellow와 동료들이 2014년에 제안한 모델이다(Goodfellow et al., 2014). 두 가지 주요 신경망 구조, 즉 생성자(Generator)와 판별자(Discriminator)로 구성되어 있다. 생성자는 실제와 구분하기 어려운 새로운 데이터를 생성하는 목적을 가지고, 반면 판별자는 주어진 데이터가 실제 데이터인지 아니면 생성자가 만든 가짜 데이터인지를 판별하는 역할을 수행한다. 이 두 네트워크는 서로 대립하는 관계에 있으며, 이를 통해 모델 전체의 성능을 점진적으로 향상시킨다.
import torch
import torch.nn as nn
import pandas
import matplotlib.pyplot as plt
import random
import numpy
GAN 학습을 위해서는 두 종류의 데이터가 필요하다:
def generate_real():
real_data = torch.FloatTensor(
[random.uniform(0.8, 1.0),
random.uniform(0.0, 0.2),
random.uniform(0.8, 1.0),
random.uniform(0.0, 0.2)])
return real_data
def generate_random(size):
random_data = torch.rand(size)
return random_data
실제 데이터 생성 함수(generate_real): 실제 데이터셋을 모방하여 생성된 데이터. 이 함수는 특정 패턴을 가진 데이터를 생성하여 판별자가 실제 데이터와 생성된 데이터를 구분하는 능력을 학습할 수 있도록 한다.
무작위 데이터 생성 함수(generate_random): 생성자가 실제 데이터와 유사한 데이터를 생성하기 위한 초기 단계로, 무작위 데이터를 생성한다.
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(4, 3),
nn.Sigmoid(),
nn.Linear(3, 1),
nn.Sigmoid()
)
self.loss_function = nn.MSELoss()
self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
self.counter = 0
self.progress = []
판별자(Discriminator): 실제 데이터와 생성자가 생성한 데이터를 구분하는 역할을 한다. 일반적으로 심층 신경망으로 구성되며, 학습 과정에서 점차 생성된 데이터와 실제 데이터를 구분하는 능력을 향상시킨다.
이를 위해 선형 레이어와 시그모이드 활성화 함수를 사용하여 구성된 신경망을 정의한다. 손실 함수로는 평균 제곱 오차(MSE)를, 최적화 알고리즘으로는 SGD(Stochastic Gradient Descent)를 사용한다.
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(1, 3),
nn.Sigmoid(),
nn.Linear(3, 4),
nn.Sigmoid()
)
self.optimiser = torch.optim.SGD(self.parameters(), lr=0.01)
self.counter = 0
self.progress = []
생성자(Generator): 무작위 데이터로부터 시작하여 실제 데이터와 유사한 데이터를 생성하는 목표를 가진다. 생성자의 학습 목표는 판별자를 속이는 것이며, 이를 통해 점차 실제와 구분하기 어려운 데이터를 생성할 수 있게 된다. 생성자 또한 선형 레이어와 시그모이드 활성화 함수로 구성된 신경망을 사용한다.
GAN의 학습 과정은 다음과 같은 단계를 포함한다:
%%time
# create Discriminator and Generator
D = Discriminator()
G = Generator()
image_list = []
# train Discriminator and Generator
for i in range(10000):
# train discriminator on true
D.train(generate_real(), torch.FloatTensor([1.0]))
# train discriminator on false
# use detach() so gradients in G are not calculated
D.train(G.forward(torch.FloatTensor([0.5])).detach(), torch.FloatTensor([0.0]))
# train generator
G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))
# add image to list every 1000
if (i % 1000 == 0):
image_list.append( G.forward(torch.FloatTensor([0.5])).detach().numpy() )
pass
판별자 학습: 실제 데이터와 생성된 데이터를 구분할 수 있도록 학습한다.
생성자 학습: 판별자를 속일 수 있는 더 정교한 데이터를 생성하기 위해 학습한다.
이러한 학습 과정은 생성자와 판별자 사이의 지속적인 경쟁을 통해 이루어진다. 초기에는 생성된 데이터와 실제 데이터 사이에 큰 차이가 있을 수 있으나, 반복적인 학습을 통해 생성자는 점점 실제 데이터와 유사한 데이터를 생성하게 된다.
소개한 GAN을 잘 학습하면, 출력은 0.5로 0과 1사이의 중간값이다.
MSE의 이상적인 값은 0.25이다.
다음 post는 MNIST이미지를 통해 학습하는 과정을 소개하겠다.