[Pytorch] GAN 모델 구현(MNIST)

wh·2024년 8월 29일

GAN 모델 구현

Pytorch를 통해 GAN(Generative Adversarial Network)를 구현해볼 것이다.
데이터셋은 MNIST를 이용할 것이고, Noise로부터 MNIST의 Train 데이터셋과 매우 유사한 데이터를 생성해보겠다.

Import

import random
import torch
import torch.nn as nn
import torch.utils.data
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import os.path as osp

%matplotlib inline

필요한 것들을 import 해준다.



PATH = "./MNIST_dataset"

batch_size = 256

z_size = 100

epochs = 500

learning_rate = 0.001

# Beta1 hyperparameter(for Adam)
beta1 = 0.5

real_label = 1
fake_label = 0

그리고 필요한 변수들 설정도 해준다.


Dataset

train_dataset = dset.MNIST(root=PATH,
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)

data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

MNIST 데이터셋을 설정한 PATH에 다운로드 한 후 load 해준다.


Model 구현, 정의

Generator와 Discriminator를 구현할 것이다.

Generator에서는 Latent vector Z(size=100)가 input으로 사용되고, hidden layer 통과 이후에는 ReLU activation function, output layer 통과 이후에는 sigmoid activation function을 사용한다.

Discriminator에서는 실제 데이터 혹은 가짜 데이터가 input으로 사용되고, hidden layer 통과 이후에는 ReLU activation function, output layer 통과 이후에는 sigmoid activation function을 사용한다.

Generator와 Discriminator의 가중치를 모두 Xavier initialization 해줄 것이다.

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.g_fc1 = nn.Linear(in_features=100, out_features=128, bias=True)
        self.relu = nn.ReLU(inplace=True) # inplace=True 일 경우 메모리 소량 절약 가능, 그러나 원본 입력이 수정됨
        self.g_fc2 = nn.Linear(in_features=128, out_features=784, bias=True)
        self.sigmoid = nn.Sigmoid()

        # Initialize weight parameters
        nn.init.xavier_uniform_(self.g_fc1.weight, gain=1.0)
        nn.init.xavier_uniform_(self.g_fc2.weight, gain=1.0)

    def forward(self, x):
        x = self.g_fc1(x)
        x = self.relu(x)
        x = self.g_fc2(x)
        x = self.sigmoid(x)

        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
       
        self.d_fc1 = nn.Linear(in_features=784, out_features=128, bias=True)
        self.relu = nn.ReLU(inplace=True) # inplace=True 일 경우 메모리 소량 절약 가능, 그러나 원본 입력이 수정됨
        self.d_fc2 = nn.Linear(in_features=128, out_features=1, bias=True)
        self.sigmoid = nn.Sigmoid()

        # Initialize weight parameters
        nn.init.xavier_uniform_(self.d_fc1.weight, gain=1.0)
        nn.init.xavier_uniform_(self.d_fc2.weight, gain=1.0)

    def forward(self, x):
        x = self.d_fc1(x)
        x = self.relu(x)
        x = self.d_fc2(x)
        x = self.sigmoid(x)
        return x


model_G = Generator()
model_D = Discriminator()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_G.to(device)
model_D.to(device)
print(model_G)
print(model_D)

Device: cuda
------------------------------------------------------------------------------------------------------------
Generator(
  (g_fc1): Linear(in_features=100, out_features=128, bias=True)
  (relu): ReLU(inplace=True)
  (g_fc2): Linear(in_features=128, out_features=784, bias=True)
  (sigmoid): Sigmoid()
)
Discriminator(
  (d_fc1): Linear(in_features=784, out_features=128, bias=True)
  (relu): ReLU(inplace=True)
  (d_fc2): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)



마지막으로 Loss fucntion과 Optimizer를 정의해준다.
criterion = nn.BCELoss()

optimizer_G = torch.optim.Adam(model_G.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizer_D = torch.optim.Adam(model_D.parameters(), lr=learning_rate, betas=(beta1, 0.999))

Train & Test

먼저 생성된 데이터를 시각화하기 위한 함수를 구현해준다.

def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)
    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    return fig


그리고 Train과 Test를 위한 loop를 구성하여 학습을 진행한다.

# label for real & fake data
# batch_size 단위로 처리하기 위해 torch.full 기능 사용
label_real = torch.full((batch_size,), real_label, device=device, dtype=torch.float)
label_fake = torch.full((batch_size,), fake_label, device=device, dtype=torch.float)

fixed_noise = torch.randn(batch_size, z_size, device=device, dtype=torch.float)

for epoch in range(epochs):

    model_G.train()
    model_D.train()

    for i, data in enumerate(data_loader):

        data = data[0].to(device)
        data = data.view(batch_size, -1)


        # 노이즈 -> 가짜 데이터 생성(반복생성을 통해 오버피팅 방지)
        noise = torch.randn(batch_size, z_size, device=device, dtype=torch.float)
        fake_data = model_G(noise)

        # Discriminator 학습
        model_D.zero_grad()

        output_real = model_D(data).view(-1)
        Loss_D_real = criterion(output_real, label_real)
        Loss_D_real.backward()

        output_fake = model_D(fake_data.detach()).view(-1)
        Loss_D_fake = criterion(output_fake, label_fake)
        Loss_D_fake.backward()

        Loss_D = Loss_D_real + Loss_D_fake
        optimizer_D.step()


        # Generator 학습
        model_G.zero_grad()

        output = model_D(fake_data).view(-1)
        Loss_G = criterion(output, label_real)
        Loss_G.backward()
        optimizer_G.step()



    # Output training stats
    if (epoch+1) % 10 == 0:
        print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
              % ((epoch+1), epochs, i+1, len(data_loader), Loss_D.item(), Loss_G.item()))

    if (epoch) % 50 == 0:
        model_G.eval()
        model_D.eval()
        output = model_G(fixed_noise).detach().cpu().numpy()
        fig = plot(output[:16])

결과는 아래와 같다.



[10/500][234/234]	Loss_D: 0.4954	Loss_G: 2.0873
[20/500][234/234]	Loss_D: 0.1655	Loss_G: 4.0228
[30/500][234/234]	Loss_D: 0.1412	Loss_G: 4.1984
[40/500][234/234]	Loss_D: 0.0683	Loss_G: 6.2528
[50/500][234/234]	Loss_D: 1.0737	Loss_G: 2.0238
[60/500][234/234]	Loss_D: 0.9343	Loss_G: 2.6887
[70/500][234/234]	Loss_D: 0.8270	Loss_G: 1.4034
[80/500][234/234]	Loss_D: 0.7454	Loss_G: 1.7409
[90/500][234/234]	Loss_D: 0.7409	Loss_G: 2.0006
[100/500][234/234]	Loss_D: 0.7229	Loss_G: 1.9061
[110/500][234/234]	Loss_D: 0.7785	Loss_G: 1.9361
[120/500][234/234]	Loss_D: 0.8553	Loss_G: 1.6024
[130/500][234/234]	Loss_D: 0.7067	Loss_G: 1.6570
[140/500][234/234]	Loss_D: 0.7467	Loss_G: 1.7937
[150/500][234/234]	Loss_D: 0.8024	Loss_G: 2.6241
[160/500][234/234]	Loss_D: 0.7929	Loss_G: 1.6627
[170/500][234/234]	Loss_D: 1.0752	Loss_G: 1.1041
[180/500][234/234]	Loss_D: 0.8622	Loss_G: 2.0550
[190/500][234/234]	Loss_D: 0.6993	Loss_G: 1.9360
[200/500][234/234]	Loss_D: 0.8738	Loss_G: 1.4923
[210/500][234/234]	Loss_D: 0.8167	Loss_G: 1.7741
[220/500][234/234]	Loss_D: 0.7779	Loss_G: 2.4709
[230/500][234/234]	Loss_D: 0.7358	Loss_G: 2.4565
[240/500][234/234]	Loss_D: 0.7585	Loss_G: 2.1113
[250/500][234/234]	Loss_D: 0.6703	Loss_G: 2.0536
[260/500][234/234]	Loss_D: 0.7598	Loss_G: 1.5473
[270/500][234/234]	Loss_D: 0.6970	Loss_G: 1.9866
[280/500][234/234]	Loss_D: 0.6878	Loss_G: 1.9665
[290/500][234/234]	Loss_D: 0.7064	Loss_G: 2.1383
[300/500][234/234]	Loss_D: 0.7916	Loss_G: 1.6122
[310/500][234/234]	Loss_D: 0.6098	Loss_G: 1.9801
[320/500][234/234]	Loss_D: 0.8324	Loss_G: 1.8967
[330/500][234/234]	Loss_D: 0.6279	Loss_G: 2.1740
[340/500][234/234]	Loss_D: 0.6652	Loss_G: 1.8633
[350/500][234/234]	Loss_D: 0.5618	Loss_G: 2.4228
[360/500][234/234]	Loss_D: 0.6021	Loss_G: 2.1645
[370/500][234/234]	Loss_D: 0.6099	Loss_G: 2.3311
[380/500][234/234]	Loss_D: 0.5985	Loss_G: 3.0751
[390/500][234/234]	Loss_D: 0.6489	Loss_G: 2.2253
[400/500][234/234]	Loss_D: 0.5930	Loss_G: 1.9898
[410/500][234/234]	Loss_D: 0.5837	Loss_G: 2.2513
[420/500][234/234]	Loss_D: 0.6508	Loss_G: 2.4088
[430/500][234/234]	Loss_D: 0.6256	Loss_G: 2.0771
[440/500][234/234]	Loss_D: 0.6489	Loss_G: 2.4406
[450/500][234/234]	Loss_D: 0.6801	Loss_G: 2.5073
[460/500][234/234]	Loss_D: 0.6494	Loss_G: 3.0150
[470/500][234/234]	Loss_D: 0.5228	Loss_G: 2.6532
[480/500][234/234]	Loss_D: 0.6141	Loss_G: 3.2598
[490/500][234/234]	Loss_D: 0.6196	Loss_G: 1.9002
[500/500][234/234]	Loss_D: 0.5650	Loss_G: 2.8008

완벽하지는 않지만, 그래도 MNIST 데이터셋과 유사한 데이터를 생성할 수 있었다. 학습이 진행될수록 데이터가 점점 더 정교해지는 것을 확인할 수 있다. 더 정교한 데이터를 생성하기 위해서는 GAN 모델의 복잡성을 높이거나, 학습을 더 진행하면 될 것으로 보인다.

profile
열심히 배우는 중! 😌

0개의 댓글