비지도 학습: 생산적 적대 신경망 (GAN)

pppanghyun·2022년 8월 1일
0

Pytorch 기본

목록 보기
19/21

생성적 적대 신경망(GAN)은 재미있는 아이디어. 얼굴 변환, 생성, 음성 변조, 그림 스타일 변환, 사진 복원 등 다양한 도메인에 응용이 되는 중 !

G: 생성적 Generative, 이미지를 만들어내는
A: 적대 Adversarial, 적대적인
N: 신경망 Network, 뉴럴 네트워크!

GAN은 크게 generator와 discriminator로 구성되어 있음 (범인과 경찰)
Generative 모델(범인): 실제 이미지의 분포 (data distribution)을 알기 위해 노력
Discriminator 모델(경찰): 현재 자기가 보고 있는 샘플이 real인지 fake인지 구별!

결과적으로 P(generator)=P(real)가 되면 구별 못하는 상황이 옴 (D(x) =1/2 가 되는것!)
(도둑이 만든 지폐가 실제와 같아지면 구별 못함 -> 반반 확률로 구분 -> D(x) = 1/2)

GAN은 수식적으로 이해하는 것이 더 직관적임. GAN의 가치함수 V는 아래와 같음 (결국 minmax 문제)

두 가지 경우를 가정해보자 (범인이 완벽한 위조 지폐를 만들기 / 경찰이 완벽하게 구별하기)

(나만 직관적인 ㅠ)

아래 수식이 좀 더 깔끔함

1. 라이브러리

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import Variable
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import imageio

import numpy as np
from matplotlib import pyplot as plt

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

2. 데이터 불러오기 (Fashion MNIST)

transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,),(0.5,))
                ])
to_image = transforms.ToPILImage()
trainset = FashionMNIST(root='./data/', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=100, shuffle=True)

# Class 종류는 다음과 같음
# 'T-Shirt','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot'

3. GAN 모델 정의하기 (Discriminator and Generator)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.n_features = 128
        self.n_out = 784
        self.linear = nn.Sequential(
                    nn.Linear(self.n_features, 256),
                    nn.LeakyReLU(0.2),
                    nn.Linear(256, 512),
                    nn.LeakyReLU(0.2),
                    nn.Linear(512, 1024),
                    nn.LeakyReLU(0.2),
                    nn.Linear(1024, self.n_out),
                    nn.Tanh()
                    )
 
    def forward(self, x):
        x = self.linear(x)
        x = x.view(-1, 1, 28, 28)# 이미지 한개의 차원 (1 x 28 x 28)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_in = 784
        self.n_out = 1
        self.linear = nn.Sequential(
                    nn.Linear(self.n_in, 1024),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(1024, 512),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(512, 256),
                    nn.LeakyReLU(0.2),
                    nn.Dropout(0.3),
                    nn.Linear(256, self.n_out),
                    nn.Sigmoid()
                    )
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.linear(x)
        return x

4. 손실함수 및 최적화 방법 정의

generator = Generator().to(device)
discriminator = Discriminator().to(device)

pretrained =  False
if pretrained == True:
    discriminator.load_state_dict(torch.load('./models/fmnist_disc.pth'))
    generator.load_state_dict(torch.load('./models/fmnist_gner.pth'))

g_optim = optim.Adam(generator.parameters(), lr=2e-4)
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4)

g_losses = []
d_losses = []
images = []

criterion = nn.BCELoss() # binary classification

def noise(n, n_features=128):
    return Variable(torch.randn(n, n_features)).to(device)

def label_ones(size):
    data = Variable(torch.ones(size, 1))
    return data.to(device)

def label_zeros(size):
    data = Variable(torch.zeros(size, 1))
    return data.to(device)

5. 학습방법 정의 (생성기, 분류기 만들어 각각 학습)

def train_discriminator(optimizer, real_data, fake_data):
    n = real_data.size(0)

    optimizer.zero_grad()
    
    prediction_real = discriminator(real_data) # real data 얼마나 잘 구분하는지
    d_loss = criterion(prediction_real, label_ones(n)) # real data 구분

    prediction_fake = discriminator(fake_data) # fake data 얼마나 잘 구분하는지
    g_loss = criterion(prediction_fake, label_zeros(n))
    
    loss = d_loss + g_loss # real/fake data loss 줄여야 함

    loss.backward()
    optimizer.step()
    
    return loss.item()

def train_generator(optimizer, fake_data):
    n = fake_data.size(0)
    optimizer.zero_grad()
    
    prediction = discriminator(fake_data) # 구별기가 fake data를 구별한 결과
    loss = criterion(prediction, label_ones(n)) # 그 결과가 작아야 함 ! (생성기 입장)
    
    loss.backward()
    optimizer.step()
    
    return loss.item()

6. 학습하기

num_epochs = 201
test_noise = noise(64)

l = len(trainloader)

for epoch in range(num_epochs):
    g_loss = 0.0
    d_loss = 0.0

    for data in trainloader:
        imgs, _ = data
        n = len(imgs)
        
        fake_data = generator(noise(n)).detach() # tensor 추적 중단 (뒤에서 또 써야됨)
        real_data = imgs.to(device)
        
        d_loss += train_discriminator(d_optim, real_data, fake_data)
        
        fake_data = generator(noise(n))
        
        g_loss += train_generator(g_optim, fake_data)

    img = generator(test_noise).cpu().detach()
    img = make_grid(img)
    images.append(img)
    g_losses.append(g_loss/l)
    d_losses.append(d_loss/l)

    if epoch % 10 == 0:
        print('Epoch {}: g_loss: {:.3f} d_loss: {:.3f}\r'.format(epoch, g_loss/l, d_loss/l))

torch.save(discriminator.state_dict(), './models/fmnist_disc.pth') 
torch.save(generator.state_dict(), './models/fmnist_gner.pth')   
print('Training Finished')

7. 학습 시각화

결과적으로 학습이 진행될 수록 Discriminator와 Genarator가 수렴!! + 개념 이해하고 코드 보면 재밌는 모델인듯ㅎㅎ

*참고영상

  1. Naver D2 1시간만에 GAN(Generative Adversarial Network) 완전 정복하기
    https://www.youtube.com/watch?v=odpjk7_tGY0
profile
pppanghyun

0개의 댓글