[Basic Generative Model] GAN

이한결·2024년 11월 12일
0

CS492(D)

목록 보기
2/14

출처[LECTURE02]

강의 영상: https://www.youtube.com/watch?v=Nh9MIEbCJIw

해당 강의를 기반으로 추가적인 코드 구현과 설명을 정리했습니다.

Generative Adversarial Network

Adversarial Attack

빨간색 네모박스가 된 이미지들은 왼쪽의 이미지에 간단한 노이즈를 추가한 이미지입니다. 육안으로는 차이가 없다고 느끼지만 neural network는 해당 이미지를 ostrich(오른쪽 아래 그림)으로 판별합니다. 이처럼 간단한 노이즈만를 추가하는 adversarial attack으로인해 neural network는 잘못된 예측을 합니다. 이는 neural network의 보안, 취약성, 신뢰성의 문제를 보여주는 한 사례이기도 합니다.

Discriminator & Generator

Discriminator(Classifier)

진짜 이미지인지 AI가 생성한 가짜 이미지인지 판별하는 역할

  • 입력값으로 이미지가 들어오면 진짜인지 가짜인지 판별합니다.

Generator(Decoder)

Adverarial attack을 통해서 Discriminator가 구별하지 못하도록 하는 역할. 다시 말해 진짜 이미지를 생성하는 역할

  • latent(간단한 데이터 분포)는 Gaussian(정규 분포)로 부터 샘플된 분포이고, 이를 이용해서 이미지를 생성합니다.

데이터의 구조는 위와 같습니다. 다시 한번 설명하자면 맨 왼쪽 아래에 보이는 latent(z)로부터 Generator가 이미지를 생성합니다. 이후 실제 이미지와 Generator가 생성된 이미지는 Discriminator에 들어갑니다. 이때 Discriminator는 진짜 이미지인지 Generator가 생성한 이미지인지 즉 Real or Fake를 분류하게 됩니다.

Loss

  • 첫번째 부분의 Discriminator는 real or fake를 잘 구분해야 하기 때문에 이에 대한 값
  • 두번째 부분의 Generator는 Discriminator가 틀려야하기 때문에 이에 대한 값

GAN의 한계

Loss의 맨앞쪽에 나온것처럼 Loss는 Min-Max 문제입니다. 이는 상충되는 목표를 달성해야하기 때문에 어려운 문제로 여겨집니다.

  • Gradient dscent의 최적화 알고리즘에서 한 변수의 값이 변하며 다른 변수의 값이 다시 반대로 변하려는 상충이 발생
  • Generator와 Discriminator의 균형점을 찾기 힘듦
    • 예를 들어, Generator가 학습을 너무 잘하면 Discriminator는 판별하기 더욱 어려워질 것이고, 반대의 경우도 동일합니다.
  • Mode Collapse 발생
    • Generator는 단순히 Discriminator만 속이면되기 때문에 특정 데이터를 real로 잘만든다면 해당 데이터만 생성하게 될 것입니다.
    • 이는 다양한 출력대신 한정된 몇 가지 유형의 출력만을 하도록 유도합니다.

https://velog.io/@guts4/GANGenerative-Adversal-Network-%EB%85%BC%EB%AC%B8%EB%A6%AC%EB%B7%B0

더자세한 논문 설명은 해당 링크를 참고하시기 바랍니다.

GAN Code

자료 출처: https://github.com/Yangyangii/GAN-Tutorial/blob/master/MNIST/VanillaGAN.ipynb

get_sample_image

정규 분포로부터 Generator를 통해서 100개의 이미지 생성

def get_sample_image(G, n_noise):
    """
        save sample 100 images
    """
    z = torch.randn(100, n_noise).to(DEVICE) # 정규 분포에서 100개의 latent vector 생성
    y_hat = G(z).view(100, 28, 28) # Generator가 (28,28)크기의 이미지 생성
    result = y_hat.cpu().data.numpy() # 생성된 이미지를 cpu로 옮기고 numpy 배열로 변환
    img = np.zeros([280, 280]) #(280,280) 빈 배열 생성 -> (28,28) 100개를 합치기 위해 사용될 예정
    for j in range(10):
        img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)
    return img

Discriminator

Real or Fake 이미지를 구별하는 판별자

class Discriminator(nn.Module):
    """
        Simple Discriminator w/ MLP
    """
    def __init__(self, input_size=784, num_classes=1): # 28 * 28 = 784(input_size), 이진 분류이기때문에 class개수는 한개
        super(Discriminator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, num_classes),
            nn.Sigmoid(), # 이진 분류이기때문에 Sigmoid
        )
    
    def forward(self, x):
        y_ = x.view(x.size(0), -1) # [batch_size, 1, 28, 28] -> [batch_size, 784]
        y_ = self.layer(y_)
        return y_

Generator

Discriminator를 속이면서 fake 이미지를 생성하는 Generator

class Generator(nn.Module):
    """
        Simple Generator w/ MLP
    """
    def __init__(self, input_size=100, num_classes=784): # 크기가 100인 latent vector를 입력으로 받고, 크기가 784인 이미지를 출력
        super(Generator, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, num_classes),
            nn.Tanh()
        )
        
    def forward(self, x):
        y_ = self.layer(x)
        y_ = y_.view(x.size(0), 1, 28, 28)
        return y_

Train code

for epoch in range(max_epoch):
    for idx, (images, _) in enumerate(data_loader): # _는 레이블인데 여기서는 가져올 필요가 없으므로 무시
        # Training Discriminator
        x = images.to(DEVICE)
        x_outputs = D(x) # Discriminator가 실제 이미지에 대한 예측
        D_x_loss = criterion(x_outputs, D_labels)

        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = D(G(z)) # Discriminator가 가짜 이미지에 대한 예측
        D_z_loss = criterion(z_outputs, D_fakes) 
        D_loss = D_x_loss + D_z_loss # Discriminator가 예측한 진짜 + 가짜 이미지 예측에 대한 loss 합
        
        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        if step % n_critic == 0:
            # Training Generator
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = D(G(z)) # Generator가 가짜 이미지 생성
            G_loss = criterion(z_outputs, D_labels) 

            G.zero_grad()
            G_loss.backward()
            G_opt.step()
        
        if step % 500 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))
        
        if step % 1000 == 0:
            G.eval()
            img = get_sample_image(G, n_noise)
            imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')
            G.train()
        step += 1

MNIST 데이터셋의 결과

profile
열정으로 가득할 페이지

0개의 댓글

관련 채용 정보