Generative Adversarial Nets 논문 리뷰

박민수·2023년 4월 2일
1

Generative Adversarial Nets


Introduction

GAN 논문이 나올 당시까지는 생성 모델들이 많지 않았다. 그에 비해 Classification과 같은 Task를 수행하는 모델들이 발전하고 있는 시기라고 할 수 있었는데, GAN의 등장으로 생성 모델들이 비약적으로 발전했다고 해도 과언이 아닐 것 같다.


Adversarial Nets

Generator & Discriminator

이름에서 알 수 있듯이 GAN에서는 두 가지 네트워크가 경쟁하듯 학습한다. 논문에서는 경찰과 위조 지폐범을 예시로 GAN의 개념을 설명한다. 위조 지폐범은 가짜 지폐를 만들어 경찰을 속이려고 하고, 경찰은 가짜 지폐인지 진짜 지폐인지를 판단할 수 있는 능력을 길러 위조지폐범을 막으려고 할 것이다.

처음에는 위조 지폐범이 가짜 지폐를 잘 만들어내지 못할 것이다. 그렇다면 경찰 입장에서는 가짜 지폐와 진짜 지폐를 구별하기 쉬울 것이고, 구별을 위해 아주아주 쉬운 기준만 세워도 충분히 잘 구별해낼 수 있을 것이다.
그렇지만 시간이 지나면 위조 지폐범이 경찰을 속이기 위해 진짜와 더 비슷한 가짜 지폐를 만들 것이고, 경찰도 그에 맞는 더 구체적인 기준을 세워 진짜 지폐와 가짜 지폐를 구별해야 할 것이다. 둘은 이렇게 서로 경쟁하면서 발전하게 된다.
정말 긴 시간이 지나면 위조 지폐범은 진짜와 다를 바가 없는 거의 똑같은 지폐를 만들게 될 것이고, 경찰은 아무리 좋은 기준을 세워도 구별할 수 없어 50% 확률로 찍는 것과 다름없는 정확도를 가지게 될 것이다.

GAN에서 Generator는 위조지폐범, Discriminator는 경찰의 역할을 수행하게 된다고 생각할 수 있다. Generator는 주어진 데이터의 분포를 파악해 그것과 유사한 데이터를 생성하는 것이 목표이고, Discriminator는 Generator가 생성한 가짜 데이터와 진짜 데이터를 잘 구별하는 것이 목표이다.

원본 데이터 xx가 갖는 분포가 pgp_g라고 하자. 그렇다면 Generator는 pgp_g를 잘 표현할 수 있어야 한다. 그런데 여기서 pgp_g를 표현하는 것이 어렵기 때문에 위해 우리가 잘 알고있는 분포인 prior pz(z)p_z(z)의 분포를 하나 정하고, 그것을 pgp_g로 매핑시켜줄 수 있는 함수 G(z;θg)G(z;\theta_g)를 잘 학습하여 그것으로 pgp_g를 구한다. GAN 논문에서는 이 GG로 MLP를 사용한다. 여기서 prior pzp_z를 통해 샘플링한 zz를 latent vector 라고도 많이 이야기한다.

Discriminator는 Binary Classification Task를 해결하는 것과 같은 구조의 MLP를 사용한다. 진짜 데이터인지 가짜 데이터인지 구분하는 것이기 때문에 둘 중 하나로 분류하는 것으로 생각할 수 있기 때문이다. 결국 데이터를 넣었을 때 그것이 진짜일 확률을 계산하는 함수라고 생각하면 된다.


loss

GAN 에서 사용하는 loss함수는 다음과 같이 표현된다. 기존의 loss 함수는 보통 loss를 줄이는 것에 집중하지만, 여기서는 두 네트워크가 경쟁하는 것이 목표이기 때문에 하나의 네트워크에서는 loss를 줄이고, 나머지 하나의 네트워크에서는 loss를 늘리는 방향으로 학습한다.

minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x\sim p_{data}(x)}[log D(x)] + \mathbb{E}_{z\sim p_z(z)}[log(1 - D(G(z)))]

이제부터 수식을 하나씩 뜯어보도록 하자. DD, GG는 각각 Discriminator와 Generator 함수를 의미하고, pdata(x)p_{data(x)}는 진짜 데이터들의 분포를 의미한다.

그렇다면 V(D,G)V(D, G) 식의 첫 번째 항인
Expdata(x)[logD(x)]\mathbb{E}_{x\sim p_{data}(x)}[log D(x)] 는 실제 데이터들을 Discriminator에 넣어 구한 확률값(진짜 데이터일 확률)에 로그를 씌우는 것이다. 물론 평균 기호이므로 로그값들의 평균을 내준다. 이 항에서는 Generator가 관여할 여지가 없다(식에 GG가 없다는 뜻이다). Generator로 생성한 데이터가 아니기 때문이다. 따라서 이 항으로는 Discrimiator의 파라미터들만 학습하게 된다.

이제 V(D,G)V(D, G) 식의 두 번째 항을 살펴보자.
pz(z)p_z(z) 분포를 우리가 잘 알고 있는 분포로 골랐기 때문에 zz를 얼마든지 샘플링해낼 수 있다. 그렇게 샘플링 한 zzGG에 통과시키면 Generator가 생성한 이미지가 될 것이다.
Ezpz(z)[log(1D(G(z)))]\mathbb{E}_{z\sim p_z(z)}[log(1 - D(G(z)))]는 Generator가 생성한 데이터들을 Discriminator에 넣어 구한 확률값을 1에서 뺀 것(가짜 데이터일 확률)에 로그를 씌우는 것이다. 여기서도 당연히 로그값들의 평균을 내준다. 이 항에서는 Generator가 생성한 데이터를 Discriminator에 통과시키기 때문에 Generator와 Discriminator 모두의 파라미터가 학습된다.

식에서 로그를 씌우는 것은 그냥 정보이론 관점에서 정보량을 구하기 위해 로그를 씌우는 것과 같은 느낌이라고 보면 좋은데, Binary Cross Entropy에서 식에 로그를 씌워야 하는 것과 같은 이유이다. 그러나 후속 논문들에서 꼭 이렇게 loss를 구성하는 것이 좋지만은 않다는 주장이 있어서 크게 중요하지는 않은 것 같다(잘 모르겠으면 넘어가도 괜찮을 것 같다).


여기까지 식이 어떻게 생겼고 각각의 의미가 무엇인지 살짝 살펴 보았는데, 이제부터는 Generator와 Discriminator의 각각의 목표에 대해 살펴보도록 하겠다.

Generator의 목표는

minGEzpz(z)[log(1D(G(z)))]\min_G \mathbb{E}_{z\sim p_z(z)}[log(1 - D(G(z)))]

이다.

의미를 생각해 보면,

Generator가 생성한 데이터를 Discriminator 입장에서 보았을 때 가짜 데이터라고 판단할 확률(혹은 정보량)을 줄인다.

인데, 이는 원래부터 우리가 생각했던 Generator의 목표와 같다. Discriminator를 속이려고 노력한다는 말과 같은 의미이다.

그렇다면 Discriminator의 목표는 아래와 같은데,

maxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]\max_{D} V(D, G) = \mathbb{E}_{x\sim p_{data}(x)}[log D(x)] + \mathbb{E}_{z\sim p_z(z)}[log(1 - D(G(z)))]

이 식의 의미를 생각해 보면,

Generator가 생성한 데이터를 Discriminator 입장에서 보았을 때 가짜 데이터라고 판단할 확률(혹은 정보량)을 늘린다. 또, 실제 데이터를 Discriminator 입장에서 보았을 때 진짜 데이터라고 판단할 확률(혹은 정보량)을 늘린다.

인데, 이것도 원래부터 우리가 생각했던 Discriminator의 목표와 같다. 진짜 데이터인지 가짜 데이터인지 잘 구분하려고 노력한다는 말과 같은 의미인 것이다.

그런데 Discriminator를 학습시킬 때는, 식의 값을 최대로 만드는 것이 목표이기 때문에 당연하게도 Gradient Descent가 아니라 Gradient Ascent를 해야한다. 실제로 계산할 때는 식에 -를 붙여서 Gradient Descent를 한다.

Algorithm

전체 훈련 과정을 표현하면

이렇다. 조금 특이한 점은 Discriminator를 k번 훈련시킬 때 Generator를 한 번만 훈련시킨다는 점이다. Discriminator가 Generator보다 조금 더 앞서 있어야 원활하게 학습될 것이기 때문에 이렇게 하는 것이라고 생각한다.(정확한 이유가 언급되어있지는 않은 것 같다. 그렇지만 ChatGPT가 나랑 똑같은 생각을 하는 것 같다.)
본 논문에서는 k=1을 사용하지만, Generator가 학습되는 속도에 비해 Discriminator가 너무 느리다면 k로 2나 3정도의 값도 넣어봐도 괜찮지 않을까 하는 생각을 한다.
그러나 생각해 봤을 때 Discriminator가 학습해야 하는 것은 단순히 진짜인지 가짜인지를 판단하는 것이고, Generator가 학습되는 것은 무려 데이터의 분포를 파악하는 것이라 Generator가 학습하는 것이 훨씬 어렵고 학습에 걸리는 시간이 길다고 판단되는 경우가 많기에 k=1이더라도 Generator보다는 Discriminator가 빨리 학습되어 문제가 없는 것 같다.

그렇다면 반대로 Discriminator가 너무 빨리 학습되면 어떨까? 그러니까 Generator가 그럴듯한 데이터를 생성해내기 전에 Discriminator가 진짜인지 가짜인지 너무 완벽하게 맞춰 버리는 경우에 어떻게 될까를 생각해 보자는 것이다. 이 경우에는 Ezpz(z)[log(1D(G(z)))]\mathbb{E}_{z\sim p_z(z)}[log(1 - D(G(z)))]항이 제대로 작동하지 않을 것이다(라고 논문에 언급되어 있는데, 사실 완전히 마음에 들지는 않는 설명이다). Generator 입장에서 생성하는 전부가 실패했다는 피드백만으로 학습하는 것은 매우 어려운 일일 것이다.
우리가 뭔가를 공부할 때도 뭐가 맞고, 뭐가 틀린지를 판단하는 눈을 기르려고 문제를 풀면서 공부를 한다. 시간이 지날수록 점점 어려운 문제들을 풀면서 실력을 기르고 디테일들을 알아가게 되는데, 처음부터 너무 어려운 문제들만 풀어서 다 틀리기만 한다면 문제풀이에서 피드백을 얻어 실력이 올라가는 것은 정말 어려운 일이다.
아마 Generator도 이번 면에서 우리랑 비슷한 입장인 것 같다.

이러한 이유들로 사실 GAN을 학습시키는 데에는 Generator와 Discriminator의 학습 속도를 적절히 조절하는 노력이 필요하고, 그렇기에 GAN은 학습 난이도가 대체로 높다고 평가받는다.
나중에 어떤 프로젝트를 진행하면서 문제 해결을 위해 GAN을 사용하려고 한다면 다른 좋은 방법들이 있나 미리 살펴보고 사용하는 것을 추천한다,,


Global Optimality & Convergence of Algorithm

논문에서 제시한 학습법이 유효한지를 판단하기 위해서 논문에서는 두 가지 증명을 하였다.

첫 번째는 pg=pdatap_g=p_{data}로 학습된 경우가 loss의 global optima가 맞는가

두 번째는 위에서 제시한 알고리즘으로 학습하면 수렴하는게 맞는가

이다. 이 내용들에 대한 증명은 논문의 4, 5페이지를 정독해 보는 것이 좋을 것 같다.


Experiments

학습시키고 inference 해 본 결과는 아래와 같다.

a가 MNIST, b가 TFD, c와 d가 CIFAR10에 대해 학습한 것인데, c는 MLP, d는 CNN Discriminator와 deconvolution layer를 이용한 Generator로 학습한 결과이다.
노란 테두리로 되어있는 것들은 데이터셋에서 가장 비슷한 샘플을 찾은 것이다. 봤을 때 MNIST는 정말 잘한다는 것을 알 수 있고 다른 데이터셋들은 조금 흐릿하긴 하지만, 그럼에도 불구하고 데이터들의 분포를 잘 학습했음을 알 수 있다. 솔직히 CIFAR-10은 너무 저해상도라 우리가 봐도 못 맞출법한 데이터가 많다고 생각하는데, 그것들로 학습해서 이정도 퀄리티면 잘 한게 아닐까?

마무리

GAN에는 정말 많은 문제점들이 있지만, 문제점들을 해결하기 위한 GAN에 대한 선행 연구도 너무나도 많다. 그래서 GAN을 사용해 문제를 해결하려고 한다면 선행 연구에 대한 공부가 정말정말 중요하다고 생각한다. 추후에 GAN의 후속 논문들도 다뤄볼 예정이다.

아래는 필자가 GAN 구현 연습해 본 Github이다. MNIST 데이터이지만 아무리 봐도 학습이 잘 안된 것 같다. 그렇지만 구현 자체에는 큰 문제가 없는 것 같고, 파라미터 튜닝 잘 해보면 될 것 같은데 필자는 실패했다...
Batch size가 GAN 학습에 큰 영향을 주는 것 같고, 지금 생각해보면 latent vector의 길이를 너무 길게 주고 학습시켜서 문제가 생겼을 수도 있을 것 같다. 추후에 줄여서 시도해보면 수정하겠다.
Github

0개의 댓글