[Pytorch를 통한 인공신경망 모델 구현] 3. GAN (Generative Adversarial Network)

김성욱·2023년 5월 15일
0

1. 논문 읽기 / 모르는 내용 정리

특별히 모르는 내용은 존재하지 않았다.

2. 핵심 내용 , insight 요약

  • 두 가지 Network를 학습시키는 것
  • 다른(적대적인) 목표를 가지고 있다.
  • Generator는 Discriminator를 속이기 위해 / Discriminator는 Generator에게 속지 않기 위해 학습
  • 하지만 Discriminator를 학습하는 것도 결국 뛰어난 Generator를 만들기 위함임을 알아야 한다.

3. 구조 파악 / 필요하다면 sketch

  • Generator / Discriminator 두 가지 Network를 각각 구성
  • pseudocode에 따라서 discriminator와 generator를 번갈아 가며 학습

4. 구현

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import tqdm

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        
        self.net = nn.Sequential(
                nn.Linear(28*28,512),
                nn.ReLU(),
                nn.Linear(512,256),
                nn.ReLU(),
                nn.Linear(256,1),
                nn.Sigmoid()
        )
        
    def forward(self,x):
        
        x = x.view(x.size(0), -1)
        output = self.net(x)
        return output
        
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        
        self.net = nn.Sequential(
            nn.Linear(100,256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256,512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512,1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024,28*28),
            nn.Tanh()
        )
        
    def forward(self,x):
        output = self.net(x)
        
        return output.view(x.size(0),28,28)
        
  

5. 어려웠던 점 / 찾아본 내용

구현 자체는 간단하게 했지만 시간이 매우 오래 걸렸고.. 이유는 GAN의 학습이 잘 이뤄지지 않아서 였다.
구글링을 해보니 간단하게 학습이 되었다고 말하는 사람은 거의 없었다.

Generator에 Dropout layer를 추가하기 전까지는 심지어 전혀 이미지를 생성하지 못하는 모습 까지 보였다.
Adversarial 이라는 특성이 서로의 학습을 방해하므로 이런 모델을 구현할 때는 특히 더 학습이 잘 되도록 BN , Dropout 의 중요성이 올라갈 것 같다.

그리고 GAN이라는 개념 자체가 모델을 설명한다기보다는 방법론의 일종이므로 정해진 모델이 없었다.
이후로 나온 수많은 GAN의 변형들은 아마 D와 G를 어떻게 설정하고, 어떤 도메인에 적용할 수 있는지를 연구한 결과일 것이다.

6. 간단한 실험

MNIST dataset을 활용해 간단하게 이미지를 생성해 봤다.

transforms = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize([0.5],[0.5])])  

mnist_train = datasets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms,
                          download=True)

mnist_test = datasets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms,
                         download=True)

mnist_loader = DataLoader(mnist_train,batch_size=128,shuffle=True)

D= Discriminator().to('cuda')
G= Generator().to('cuda')

optimizer_G = optim.Adam(G.parameters(),lr=0.0002)
optimizer_D = optim.Adam(D.parameters(),lr=0.0002)
criterion = nn.BCELoss().to('cuda')

epochs = 200
test_noise = torch.randn(16,100)

for epoch in range(epochs):
    g_losses= []
    d_losses= []
    
    if epoch % 5 ==0:
        plt.figure(figsize=(8,4))
        fake_imgs = G(test_noise.to('cuda'))
        for i in range(1,17):
            plt.subplot(2,8,i)    
            plt.imshow(fake_imgs[i-1].reshape(28,28).cpu().detach().numpy(),cmap='gray')
        plt.show()
        
    for img , _ in mnist_loader:
        real_label = torch.ones(img.size(0),1).to('cuda')
        fake_label = torch.zeros(img.size(0),1).to('cuda')
        
        optimizer_G.zero_grad()
        
        z = torch.Tensor(np.random.normal(0, 1, (img.size(0), 100)))
        fake_imgs = G(z.to('cuda'))
        
        g_loss = criterion(D(fake_imgs.unsqueeze(1).to('cuda')), real_label)
        g_loss.backward()
        optimizer_G.step()
        
        optimizer_D.zero_grad()

        output_real = D(img.to('cuda'))
        output_fake = D(fake_imgs.unsqueeze(1).detach().to('cuda'))
        
        real_loss = criterion(output_real,real_label)
        fake_loss = criterion(output_fake,fake_label)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        
        g_losses.append(g_loss)
        d_losses.append(d_loss)
    print("Epochs : ", epoch, "D_loss : ",sum(d_losses)/len(d_losses) , "G_loss : ", sum(g_losses)/len(g_losses))      
    

학습 코드는 위와 같다.

학습 도중 생성한 숫자들이다.

최종 모델로 100개의 숫자 생성. GAN을 통한 augmentation이 통하는 이유가 이처럼 원본과 비슷한 이미지를 잘 만들기 때문 아닐까? 몇몇 이미지를 제외하고는 dataset에 그대로 추가해도 손색이 없을 정도라고 생각한다.

7. 여담

사실 GAN에 대한 구현은 SBS에서 면접을 보고 나온 후로 갑작스럽게 결정했다.
면접에서 SBS 컨텐츠에 대한 생성형 모델을 적용해서 가치를 창출하겠다고 답변했는데,
어떻게 라는 방법이 없었다. 생성형 모델을 쓰면 되는데요? 라는게 아니라 어떤 모델을 써서 어떤 방식으로 적용하면 결과물이 나올 것 같다 라고 답했다면 좋았을텐데 그런 부분이 아쉬웠다.

면접 결과와 상관없이 당분간은 생성형 모델 쪽을 공부해보려고 한다.

다음 모델은 예상대로 G / D 에 Convolutional layer를 적용한 DCGAN(2015) / citation : 14000+ 이다

profile
someone

0개의 댓글