GAN Mode collapse, Wasserstein Loss, Weight Clipping, Gradient Penalty

Jomii·2024년 4월 26일
0

Computer Vision

목록 보기
15/24

Mode collapse


generator가 discriminator가 못맞추는 클래스를 파악해서 그 클래스만 계속 생성해서 discriminator가 전부 오분류하도록 하는것

즉 generator가 local minima에 갇힌 것이다.


Problem with BCE loss

GAN에서 binary cross entropy를 사용할 경우, 학습 초기에는 discriminator의 성능이 좋지 않기 때문에 상관이 없지만 generator보다 비교적 학습이 쉽기 때문에 학습 속도가 빠를 수 있다. 즉 generator가 이미지를 생성해내는 것보다 discriminator가 real/fake이미지 분류를 잘하게 된다. 이때 discriminator가 분류를 잘하게 되면 0에 가까운 gradient를 넘겨주기 때문에 generator의 성능향상이 어려워지는 상황에서 결국 학습이 종료된다.
Generator는 discriminator가 잘못 분류하는 특정한 이미지 모드에만 집중하여 이 모드에 해당하는 이미지를 생성하려고 시도한다. 이렇게 되면 generator가 다양한 이미지를 생성하는 대신 특정한 이미지 패턴에만 집중하게 되고, 결과적으로 mode collapse가 발생할 수 있다.


이를 해결하기 위해 Earth Mover's Distance를 사용한다.

Earth Mover's Distance


Earth mover's distance는 두 분포를 동일하게 만들기 위해서 얼만큼 옮겨야하는지를 계산하는 것

두 분포 사이의 거리를 측정하는 cost function으로 일반적으로 GAN을 훈련할 때 BCE 관련 cost function보다 성능이 뛰어나다.

예를 들어 분포가 흙더미라고 생각하면, 그 흙더미를 움직여 실제 분포의 모양과 위치로 만드는 것은 얼마나 어려울까? ⇒ 이것이 Earth Mover’s Distance

BCE loss에서는 0과 1사이의 값을 가져 1에 가까운 값이 점점 0에 가까워지면서 학습을 멈추게 되는데, Earth Movers’ Distance에서는 이러한 상한선이 없기 때문에 계속 cost를 증가시킬 수 있음


Wasserstein Loss


Earth Mover’s distance를 근사화한 loss

함수는 discriminator의 prediction의 예상치 차이를 계산한다.
여기서 discriminator가 평가하는 역할을 하기 때문에 critic이라고 한다.

discriminator는 이 두 가지를 보고 진짜에 대한 생각과 가짜에 대한 생각 사이의 거리를 최대화하려고 하는 한편, generator는 가짜 이미지가 진짜와 최대한 가깝다고 discriminator가 생각하기를 원하기 때문에 이 차이를 최소화하려고 한다.
거리 기반이기 때문에 0과 1의 한계가 없다.


1-Lipschitz Continuous

다만 신경망에서 너무 큰 숫자는 피해야 하기 때문에 Lipshitz 제약이라는 제약조건을 걸어 critic을 제한한다.

critic에서 이 condition이 W-loss에 중요한 이유는 W-Loss 함수가 연속적이게 될 뿐 아니라 훈련 중에 너무 많이 성장하지 않고 어느 정도 안정성을 유지하도록 보장하기 때문이다.

위 condition을 만족하기 위한 기법으로는 weight clipping과 gradient penalty가 있다.

Weight Clipping

gradient의 norm을 강제적으로 1보다 크지않게 하는 것

넘어가는 값들을 아예 clipping해버리기 때문에 다양한 가중치값을 받아들이지 못해 최적을 찾지 못할 수 있고, critic을 지나치게 제한할 수도 있다는 단점이 있다.

Gradient Penalty

regularization term을 붙임으로서 1-L 연속성을 좀더 부드럽게 강제하는 방식

weight clipping처럼 값을 자르는 것이 아니라 많이 넘어갈수록 제약을 걸어주는 방식이다.

regularization term은 진짜 이미지와 생성된 가짜 이미지를 이용해 interpolation한 중간이미지를 통해 줄수 있다.

즉 x hat은 진짜와 가짜에 대한 가중치를 부여한 이미지이기 때문에 1-L continuous를 엄격하게 강제하는 것이 아니라 권장하는 것이다.

⚙ interpolated image를 사용해 수식 구현

  1. 보간된 이미지 생성
    실제 이미지(real)와 가짜 이미지(fake)를 일정한 epsilon을 주고 섞어서 보간된 이미지를 생성

    x^=ϵ×real+(1ϵ)×fake\hat{x} = ϵ × real + (1−ϵ) × fake

  2. Critic 모델에 보간된 이미지 입력

  3. Critic의 그래디언트 계산 x^D(x^)\nabla _{\hat{x}} D(\hat{x})

  4. Gradient Penalty 계산
    Gradient Penalty는 critic의 그래디언트 norm이 1에서 벗어나는 정도를 측정해 페널티 부여
    penalty=λExPx[(xD(x)21)2]penalty=λEx^∼Px^[(∥∇x^D(x^)∥2−1)2]







Coursera의 Build Basic Generative Adversarial Networks (GANs) 강의를 바탕으로 작성하였습니다.

profile
📩 qtly_u@naver.com

0개의 댓글