2014년 Ian Goodfellow(.et al)은 Generative Adversarial Networks(줄여서 GAN)라는 생성 모델을 훈련하는 방법을 제시했다.
GAN은 기본적으로 두 개의 다른 신경망(Generator와 Discriminator)간의 적대적인 관계로 대립(Adversarial)하며 서로의 성능을 점차 개선해 나가는 것이다.
GAN을 구현하고 MNIST데이터셋을 통해서 새로운 MNIST를 만들어본다.
GAN 코드구현을 위해 기본적으로 torch 라이브러리를 import 한다.
import torch
import torch.nn as nn
from torch.nn.modules import loss
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
기본적으로 신경망을 만들기 위해 torch.nn
을 부르고,
데이터 셋, 아키텍처 모델, 이미지 변환 기능으로 구성되어 있는 패키지인 torchvision
을 부른다.
class Discriminator(nn.Module):
def __init__(self, in_features):
super().__init__()
self.disc = nn.Sequential(
nn.Linear(in_features, 128),
nn.LeakyReLU(0.1),
nn.Linear(128, 1),
nn.Sigmoid(),
)
def forward(self, x):
return self.disc(x)
Pytorch
forward()
는 모델이 학습데이터x
를 입력 받아서 forward propagation을 진행시키는 함수이다.
반드시forward
라는 이름의 함수여야 한다.
class Generator(nn.Module):
def __init__(self, z_dim, img_dim):
super().__init__()
self.gen = nn.Sequential(
nn.Linear(z_dim, 256),
nn.LeakyReLU(0.1),
nn.Linear(256, img_dim), # 28 * 28 * 1 -> 784
nn.Tanh(),
)
def forward(self, x):
return self.gen(x)
z_dim
과 출력인img_dim
을 받는다. Tanh( )
함수를 쓰면 출력되는 값을 -1 ~ 1 사이로 맞출 수 있다.device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64 # 128, 256
image_dim = 28 * 28 * 1 # 784
batch_size = 32
num_epochs = 50
torch.cuda.is_available()
로 GPU를 활성화한다.batch_size
는 32num_epochs
는 50으로 학습한다.MNIST 데이터 특징
- 2차원의 이미지 데이터이지만, 실제로는 가로 28픽셀, 세로 28픽셀, 총 784픽셀의 각각에 화소값이 입력되어 있다.
784차원의 캔버스에 이미지를 생성하는 의미를 생각해보자.
그런데, 생성모델 이란 각 픽셀 값이 임의의 값을 취했을 때의 동시확률분포(joint probability)를 구하는 것과 같다.
예로 각 픽셀이 전부 0이 될 확률이라던가, 전부 1이 될 확률이라던가, 더욱 중간에 가까운 값, 784차원의 모든 값의 분포에 대해서 확률을 구하는 것이 된다.
즉, 전부 x분포에 대한 밀도 분포를 알 수 있게 된다.
그렇다면, MNIST는 처음에 넣을 때 784차원으로 지정해야하나?
아니다.
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),]
)
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real= SummaryWriter(f"runs/GAN_MNIST/real")
step = 0
Discriminator
와 Generator
를 선언하고 noise를 만든다.transform
을 통해서 Tensor형으로 바꾸고 정규화해준다.dataset
을 불러온다.Adam
을 Optimizer로 설정Loss(criterion)
는 논문에서 나왔듯 Real/Fake를 구분하기 위해 Binary cross entropy를 사용한다.for epoch in range(num_epochs):
for batch_idx, (real, _) in enumerate(loader):
real = real.view(-1, 784).to(device)
batch_size = real.shape[0]
### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
noise = torch.randn(batch_size, z_dim).to(device)
fake = gen(noise)
disc_real = disc(real).view(-1)
lossD_real = criterion(disc_real, torch.ones_like(disc_real))
disc_fake = disc(fake).view(-1)
lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
lossD = (lossD_real + lossD_fake) / 2
disc.zero_grad()
lossD.backward(retain_graph=True)
opt_disc.step()
### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
# where the second option of maximizing doesn't suffer from
# saturating gradients
output = disc(fake).view(-1)
lossG = criterion(output, torch.ones_like(output))
gen.zero_grad()
lossG.backward()
opt_gen.step()
if batch_idx == 0:
print(
f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
)
with torch.no_grad():
fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
data = real.reshape(-1, 1, 28, 28)
img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
img_grid_real = torchvision.utils.make_grid(data, normalize=True)
writer_fake.add_image(
"Mnist Fake Images", img_grid_fake, global_step=step
)
writer_real.add_image(
"Mnist Real Images", img_grid_real, global_step=step
)
step += 1
fake
)lossD_real = criterion(disc_real, torch.ones_like(disc_real))
(진짜 이미지에 대한 Loss)와 lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
(가짜 이미지에 대한 Loss)를 구해서lossG
를 구하고 backwardTensorBoard로 학습 진행 상황을 볼 수 있다.
성능이 그렇게 좋지는 않다.
당연하다. 신경망을 단순하게 구현 했으니...
이걸 SimpleGAN이라고도 한다.
실제 GAN은 종류가 엄청 많은데
CNN을 통하여 구현한 DCGAN
Wasserstein distance를 적용한 WGAN
CycleGAN, ESRGAN, Pix2Pix, ProGAN, SRGAN, StyleGAN 등등....
목표는 GAN에 대한 여러가지 기법을 구현해보고 내 벨로그에 모두 올리는 것이다.
GAN에 관심을 가지고 공부하고 싶어하는 사람들에게 도움이 되고싶다...ㅎㅎ
다음은 아마 DCGAN과 StyleGAN, 그리고 요즘 공부하는 SinGAN에 대해서 업로드할 예정이다.
기대하시라~~!~!