[7주차] 조건부 GAN - 20220120

김동영·2022년 1월 20일
0

이제까지 만들어온 GAN 출력물들은,

임의의 랜덤 시드값에 대해서 결과를 만들어냈다.

mnist 예제로 들자면,

3을 목표로 3 이미지를 만들어내는 것이 아니라,

어떠한 시드 값에 대해 만들어진 이미지가 3 이미지였다.

즉, 내가 결과에 대한 의도를 생성기에 전달하지는 않았다.

이번에는 특정한 결과를 만들어내는 GAN을 만들어보자.

이전에 만든 판별기에서,

forward 함수에서 입력값으로 들어오는 seed에 label을 결합시켜주면 된다.

    def forward(self, image_tensor, label_tensor):
        inputs = torch.cat((image_tensor, label_tensor))
        return self.model(inputs)

torch.cat은 하나의 텐서를 다른 텐서의 뒤에 이어준다.

이미지 텐서는 길이가 784, 레이블 텐서의 길이는 10이므로

합해진 텐서는 길이가 794가 된다.

이에 맞추어서, 신경망 또한 길이를 10 늘려준다

        self.model = nn.Sequential(
            nn.Linear(784+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 1),
            nn.Sigmoid()
        )

또한 train 함수에서 입력으로 label_tensor를 전달하도록 한다.

    def train(self, inputs, label_tensor, targets):
      
        outputs = self.forward(inputs, label_tensor)
        
        loss = self.loss_function(outputs, targets)

        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())
            
        if (self.counter % 10000 == 0):
            print("counter = ", self.counter)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

생성기에서도 마찬가지로 label_tensor를 입력 받아 seed 와 합쳐주어야 한다.

    def forward(self, seed_tensor, label_tensor):        
        inputs = torch.cat((seed_tensor, label_tensor))
        return self.model(inputs)

네트워크에도 10개의 label_tensor만큼 사이즈를 늘려준다

        self.model = nn.Sequential(
            nn.Linear(100+10, 200),
            nn.LeakyReLU(0.02),

            nn.LayerNorm(200),

            nn.Linear(200, 784),
            nn.Sigmoid()
        )

그리고 train 함수에서도 label_tensor를 받도록 수정한다.

def train(self, D, inputs, label_tensor, targets):
        g_output = self.forward(inputs, label_tensor)
        
        d_output = D.forward(g_output, label_tensor)
        
        loss = D.loss_function(d_output, targets)

        self.counter += 1;
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

마지막으로, 학습 후, 생성기를 통해 이미지를 확인할 때도 label_tensor를 받을 수 있도록 수정한다.

    def plot_images(self, label):
        label_tensor = torch.zeros((10))
        label_tensor[label] = 1.0
        # plot a 3 column, 2 row array of sample images
        f, axarr = plt.subplots(2,3, figsize=(16,8))
        for i in range(2):
            for j in range(3):
                axarr[i,j].imshow(G.forward(generate_random_seed(100), label_tensor).detach().cpu().numpy().reshape(28,28), interpolation='none', cmap='Blues')

학습을 위해 임의의 random label_tensor를 만들 수 있도록 하고

def generate_random_one_hot(size):
    label_tensor = torch.zeros((size))
    random_idx = random.randint(0,size-1)
    label_tensor[random_idx] = 1.0
    return label_tensor
epochs = 12

for epoch in range(epochs):
  print ("epoch = ", epoch + 1)

  for label, image_data_tensor, label_tensor in mnist_dataset:
    D.train(image_data_tensor, label_tensor, torch.FloatTensor([1.0]))

    random_label = generate_random_one_hot(10)
    
    D.train(G.forward(generate_random_seed(100), random_label).detach(), random_label, torch.FloatTensor([0.0]))
    
    random_label = generate_random_one_hot(10)

    G.train(D, generate_random_seed(100), random_label, torch.FloatTensor([1.0]))

이제까지 한 것 처럼 판별기와 생성기의 훈련시에 label_tensor를 추가적으로 제공한다.

학습된 생성기에 목표 label을 제공했을 때 생성되는 이미지를 확인해보면,

G.plot_images(9)

G.plot_images(3)

이와 같이 목표 이미지를 생성할 수 있는 것을 볼 수 있다.

이를 다른 도메인에 적용하면,

감정이 있는 얼굴, 특정 색의 꽃 등의 목적을 가지고 이미지를 생성할 수 있을 것이다.

이를 위해서는 사전 훈련 데이터에 레이블링이 잘 되어있어야 할 것이다.

profile
오래 공부하는 사람

0개의 댓글