Conditional GAN - MNIST

송용호·2024년 4월 26일
0

GAN

목록 보기
5/5

GAN의 골치 중 하나는 모드 붕괴 현상이다. 학습하기 쉬운 지점을 찾아서 그 지점만 학습해, 한 종류의 이미지만 생성하는 현상이다.

생성하는 이미지를 하나의 클래스로 고정하고, 다양한 이미지를 만든다고 생각해보자. 예를 들어, GAN에게 3을 표현하는 다양한 이미지를 생성하라 or 훈련 데이터에 감정을 나타내는 클래스가 있다면, 행복한 표정의 얼굴만 만들라 이런식으로 말이다.

이런 구조를 조건부 GAN이라고 한다. 생성기는, 주어진 클래스에 해당하는 이미지를 생성하게 하려면, 어떤 클래스를 목표로 하는지 알려줘야한다. 판별기는 클래스와 이미지 사이의 관계를 학습해야한다. 학습하지 못하면, 생성기에 피드백도 못하며 클래스랑 이미지를 연관 짓지도 못한다.
여기서 요는, 판별기와 생성기 모두 이미지 데이터 외에도 클래스 레이블을 추가로 입력 받아야 한다는 것이다.

MNIST GAN가지고 실습해보자.

판별기


이미지와, 클래스 레이블 정보를 동시에 받아야한다.
간단한 방법은 그냥 결합하는거다.

   def forward(self, image_tensor, label_tensor):
       # combine seed and label
       inputs = torch.cat((image_tensor, label_tensor))
       return self.model(inputs)

고로 이미지 텐서의 길이는 784 + 10 = 794가 된다. 모델 define 부분도 수정해주자

        self.model = nn.Sequential(
            nn.Linear(784+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 1),
            nn.Sigmoid()
        )

train 부분도 수정해주자. forward를 바꿨으니 호출하는 함수도 바뀌어야한다.

    def train(self, inputs, label_tensor, targets):
        # calculate the output of the network
        outputs = self.forward(inputs, label_tensor)

이미지와 함께 클래스 레이블이 필요하니, 원핫 인코딩된 레이블 벡터를 만들게하자.

def generate_random_one_hot(size):
    label_tensor = torch.zeros((size))
    random_idx = random.randint(0,size-1)
    label_tensor[random_idx] = 1.0
    return label_tensor

생성기


얜 seed와 레이블 텐서를 투입해야한다. 판별기를 수정했던 방식 그대로 수정해주면 된다.

    def forward(self, seed_tensor, label_tensor):        
        # combine seed and label
        inputs = torch.cat((seed_tensor, label_tensor))
        return self.model(inputs)
        self.model = nn.Sequential(
            nn.Linear(100+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 784),
            nn.Sigmoid()
        )
    def train(self, D, inputs, label_tensor, targets):
        # calculate the output of the network
        g_output = self.forward(inputs, label_tensor)

forward함수에 label과 함께 넘겨주도록 했으니, 생성기가 다른 레이블로 잘못 판단하는걸 막을 수 있을거다.

훈련


%%time 

# train Discriminator and Generator

epochs = 12

for epoch in range(epochs):
  print ("epoch = ", epoch + 1)

  # train Discriminator and Generator

  for label, image_data_tensor, label_tensor in mnist_dataset:
    # train discriminator on true
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))

    # random 1-hot label for generator
    random_label = generate_random_one_hot(10)
    
    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random_seed(100), random_label).detach(), random_label, torch.FloatTensor([0.0]))
    
    # different random 1-hot label for generator
    random_label = generate_random_one_hot(10)

    # train generator
    G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))

    pass
    
  pass

훈련 반복문을 보면 label_tensor가 들어갔다.

Plot을 봐보자.

판별기 손실



기존과 비슷하고, 손실이 상승하고 있는 것 처럼 보인다.
GAN의 목적 손실 값이 0이 아니기 때문에, 긍정적인 그래프 이다.

생성기 손실



평균이 0이 아니니 이것도 좋은 징조이다.
GAN을 훈련할 때 추가적인 레이블 정보가 도움이 된다는 것을 시사한다.

결과


G.plot_images(9)
G.plot_images(3)
G.plot_images(1)
G.plot_images(5)

다 똑같지 않은 같은 레이블 이미지를 그렸다.

결론


  • 조건부 GAN은 원하는 클래스의 데이터 생성이 가능하다.
  • 판별기에 이미지를 보강해서 전잘하며, 생성기엔 클래스 레이블을 통해 시드가 투입되어야한다.
  • 조건부 GAN은 레이블 정보를 받지 않는 GAN보다 더 나은 데이터를 생산한다.

0개의 댓글