cs231n 과제 3 Q3- Generative Adversarial Networks

이준학·2024년 7월 17일

cs231n 과제

목록 보기
13/15

  이번 과제는 GAN을 구현해보는 과제이다. GAN은 generator와 discriminator의 2-player game 구조를 가진 재밌는 모델이다. Generator는 실제 사진과 비슷한 가짜 사진을 만들어 discriminator가 헷갈리게 하는 것이 목표이고, Discriminator는 Generator가 만든 사진과 실제 사진을 잘 구별해 내는 것이 목표이다. 이 둘이 서로를 속고 속이며 훈련하고, generator 네트워크를 이용해 새로운 이미지들을 만들어내는 방식이다. 이번에도 역시, 구현하면서 중요하거나 헷갈렸던 부분들을 짚어보겠다. pytorch를 이용해서 구현해 그닥 어렵진 않았다.

1. discriminator_loss()

  이 함수에서는 reshape을 해주는 것이 키 포인트였다. 그냥 loss 계산을 하면 두 matrix의 차원이 맞지 않는다는 오류 메세지가 뜬다. 따라서 다 reshape를 해주면 문제를 해결할 수 있다.

def discriminator_loss(logits_real, logits_fake):
    """
    Computes the discriminator loss described above.

    Inputs:
    - logits_real: PyTorch Tensor of shape (N,) giving scores for the real data.
    - logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.

    Returns:
    - loss: PyTorch Tensor containing (scalar) the loss for the discriminator.
    """
    loss = None
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    loss_real=bce_loss(logits_real.reshape(-1), torch.ones_like(logits_real,dtype=float).reshape(-1))
    loss_fake=bce_loss(logits_fake.reshape(-1),torch.zeros_like(logits_fake,dtype=float).reshape(-1))
    loss=loss_fake+loss_real

    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    return loss

  또 하나 중요한 부분은 real loss와 fake loss를 계산할 때 true target 값의 설정이다. 함수에서 logits_real이 E[logD(x)]E[logD(x)]이므로 1을 target 값으로 주고, E[log(1D(G(Z)))]E[log(1-D(G(Z)))]는 logits_fake에 해당하므로 target 값을 0으로 준다. Discriminator와 Generator를 구현한 후, 이미 구현되어 있는 run_a_gan 함수를 통해 네트워크를 학습시킨다. solver를 D,G 2개에 대해 따로 만들어주고, 각각 훈련시키는 방식이다. 이전의 모델 solver와 같이 loss를 계산하고, backward를 한 후, step을 하나 앞으로 가는 과정을 반복한다. 과제 코드를 보면 더 자세히 이해할 수 있다.
 이렇게 Vanilla GAN을 구현한 후, LSGAN과 DCGAN도 구현한다. 과제에 적혀 있는대로 Discriminator와 Generator를 만들면 되기 때문에 어렵지 않다.

2. Pytorch

  Pytorch 함수들 중 헷갈렸던 것들을 정리하려고 한다.

1) nn.BatchNorm1d VS nn.BatchNorm2d

  위의 함수 모두 BatchNorm 함수들이지만, 차원에 차이가 있다. 정확히 말하면, BN을 적용하는 data에 따라 사용할 함수가 적용하는 것이다. 말 그대로 1d는 1d data인 FC layer나 1D convolution에 적용하고, 2d는 2d convolution이나 image data에 적용하면 되는 것이다.

2)nn.ConvTranspose2d()

   이 함수는 우리가 흔히 말하는 Deconvolution, 혹은 transposed convolution이라고 부르는 연산을 지원해준다. transposed convolution은 upsample된 output을 얻을 수 있는 장점이 있다. 자세히는 적지 않겠다.

3. 정리

  이번 과제는 생각보다 적을 것이 많지 않았다. pytorch를 오랜만에 해봐서 조금 헷갈리는 부분들이 있었던 것 같다. noise에서 시작해 학습을 거듭할수록 유의미한 image를 만들어내는 네트워크를 실제로 보니 신기하기도 했다.

내 과제 풀이:

https://github.com/danlee0113/cs231n

profile
AI/ Computer Vision

0개의 댓글