이번 과제는 GAN을 구현해보는 과제이다. GAN은 generator와 discriminator의 2-player game 구조를 가진 재밌는 모델이다. Generator는 실제 사진과 비슷한 가짜 사진을 만들어 discriminator가 헷갈리게 하는 것이 목표이고, Discriminator는 Generator가 만든 사진과 실제 사진을 잘 구별해 내는 것이 목표이다. 이 둘이 서로를 속고 속이며 훈련하고, generator 네트워크를 이용해 새로운 이미지들을 만들어내는 방식이다. 이번에도 역시, 구현하면서 중요하거나 헷갈렸던 부분들을 짚어보겠다. pytorch를 이용해서 구현해 그닥 어렵진 않았다.
이 함수에서는 reshape을 해주는 것이 키 포인트였다. 그냥 loss 계산을 하면 두 matrix의 차원이 맞지 않는다는 오류 메세지가 뜬다. 따라서 다 reshape를 해주면 문제를 해결할 수 있다.
def discriminator_loss(logits_real, logits_fake):
"""
Computes the discriminator loss described above.
Inputs:
- logits_real: PyTorch Tensor of shape (N,) giving scores for the real data.
- logits_fake: PyTorch Tensor of shape (N,) giving scores for the fake data.
Returns:
- loss: PyTorch Tensor containing (scalar) the loss for the discriminator.
"""
loss = None
# *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
loss_real=bce_loss(logits_real.reshape(-1), torch.ones_like(logits_real,dtype=float).reshape(-1))
loss_fake=bce_loss(logits_fake.reshape(-1),torch.zeros_like(logits_fake,dtype=float).reshape(-1))
loss=loss_fake+loss_real
# *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
return loss
또 하나 중요한 부분은 real loss와 fake loss를 계산할 때 true target 값의 설정이다. 함수에서 logits_real이 이므로 1을 target 값으로 주고, 는 logits_fake에 해당하므로 target 값을 0으로 준다. Discriminator와 Generator를 구현한 후, 이미 구현되어 있는 run_a_gan 함수를 통해 네트워크를 학습시킨다. solver를 D,G 2개에 대해 따로 만들어주고, 각각 훈련시키는 방식이다. 이전의 모델 solver와 같이 loss를 계산하고, backward를 한 후, step을 하나 앞으로 가는 과정을 반복한다. 과제 코드를 보면 더 자세히 이해할 수 있다.
이렇게 Vanilla GAN을 구현한 후, LSGAN과 DCGAN도 구현한다. 과제에 적혀 있는대로 Discriminator와 Generator를 만들면 되기 때문에 어렵지 않다.
Pytorch 함수들 중 헷갈렸던 것들을 정리하려고 한다.
위의 함수 모두 BatchNorm 함수들이지만, 차원에 차이가 있다. 정확히 말하면, BN을 적용하는 data에 따라 사용할 함수가 적용하는 것이다. 말 그대로 1d는 1d data인 FC layer나 1D convolution에 적용하고, 2d는 2d convolution이나 image data에 적용하면 되는 것이다.
이 함수는 우리가 흔히 말하는 Deconvolution, 혹은 transposed convolution이라고 부르는 연산을 지원해준다. transposed convolution은 upsample된 output을 얻을 수 있는 장점이 있다. 자세히는 적지 않겠다.
이번 과제는 생각보다 적을 것이 많지 않았다. pytorch를 오랜만에 해봐서 조금 헷갈리는 부분들이 있었던 것 같다. noise에서 시작해 학습을 거듭할수록 유의미한 image를 만들어내는 네트워크를 실제로 보니 신기하기도 했다.