PyTorch 실전 - CycleGAN 학습 과정 구현하기

sp·2022년 3월 22일
0

PyTorch 실전

목록 보기
2/4
post-thumbnail

이전 포스트에서는 CycleGAN 모델을 구현하였습니다. 그 다음에는 구현된 모델을 순방향 통과시키고, 계산된 손실로 역전파 알고리즘을 수행하는 코드를 구현해 보도록 하겠습니다.

전체적인 학습과 손실 구조 파악하기

논문에서 제시하는 전체적인 데이터의 흐름은 다음과 같습니다.

  • 2개의 생성자(GG, FF), 2개의 구별자(DxD_x, DyD_y)로 총 4개의 모델이 학습됩니다.

  • (a)에서 GGFF를 통과한 데이터는 구별자로 들어가서 판별됩니다.

  • (b)와 (c)에서 GFG \circ F, FGF \circ G를 통과한 x^\hat{x}y^\hat{y}에 대해 원본 xx, yy로 손실을 계산합니다.

이를 논문에서는 다음과 같은 수식으로 나타내었습니다.

L\mathcal{L}를 생성자와 구별자는 다음과 같은 목적을 가집니다.

간단하게 이야기하면 생성자는 진짜같은 이미지를 생성하는 것이고, 구별자는 진짜와 가짜 이미지를 구별하도록 학습합니다.

전체적인 목적 함수는 크게 두 부분인 LGAN\mathcal{L}_{GAN}Lcyc\mathcal{L}_{cyc}로 이루어져 있습니다. 먼저 GAN 손실입니다.

이 수식은 구별자만을 고려했고, negative log likelihood를 사용해서 큰 페널티를 더 잘 주도록 학습합니다. 그런데, 이를 실질적으로는 활용하지 않고 GAN 문제에서 더 좋은 퀄리티의 이미지를 생성하기 위해 다음과 같이 손실을 정의합니다.

여기서는 생성자와 구별자에서 모두 손실로 고려되었습니다. 생성자에서 추가적으로 생성된 이미지를 진짜로 구별하기 위해 역전파를 수행합니다. 구별자는 진/가짜 구별을 그대로 수행하면서 MSE 함수를 활용합니다.

다음은 cycle-consistancy 손실입니다.

GGFF를 통과해서 원본과 일치하는지 비교합니다. 이는 두 클래스 간의 특징을 활용해 unpaired한 상태에서도 학습을 가능하게 만들어줍니다.

또한, 일부 상황에서는 identity 손실도 다음과 같이 정의합니다.

이는 같은 성질을 가지는 생성자를 통과해도 동일한 상태로 유지함을 나타냅니다.

위에서 정의한 손실들을 정리했을 때, 생성자 학습과 구별자 학습으로 나누어 구현합니다. 생성자와 구별자에서 사용할 손실들을 정리하겠습니다.

  • 생성자에서는 구별자를 통과했을 때 진짜 이미지이도록 하는 GAN 손실, 두 생성자를 통과했을 때의 cycle-consistancy 손실, 일부 경우에서의 identity 손실

  • 구별자에서는 진/가짜 이미지가 주어졌을 때 이를 잘 판단하도록 하는 GAN 손실

Nonation 재정의하기

논문에서 사용하는 nonation이 헷갈릴 수 있으므로, 조금 더 알아보기 쉽게 바꿔보겠습니다. 먼저 모델과 입력 텐서의 이름입니다.

  • GG, FFnetG_A2B, netG_B2A

  • DxD_x, DyD_ynetD_A, netD_B

  • xx, yyreal_A, real_B

두 이미지 스타일을 A와 B로 정의하고, 생성자와 구분자를 G와 D로 사용하고, 텐서 또한 진짜 이미라는 것을 강조하기 위해 real을 사용하였습니다.

다음은 이를 기반으로 손실을 계산하기 위해 모델에 통과한 텐서들도 정의하겠습니다.

  • fake_B = G_A2B(real_A), fake_A = G_B2A(real_B)
  • cycle_A = G_B2A(fake_A), cycle_B = G_A2B(fake_A)
  • identity_A = G_B2A(real_A), identity_B = G_A2B(real_B)

생성자 학습하기

모델의 학습에는 데이터의 순방향 전달로 손실을 계산하고, 역전파 알고리즘으로 파라메터들을 조정하는 단계로 이루어집니다. GAN은 생성자와 구별자 두 종류의 모델, CycleGAN에서는 총 4개의 모델이 학습됩니다. 그래서 생성자와 구별자를 별도로 학습하는데, 먼저 생성자를 학습하는 코드를 구현해보겠습니다.

# 1
real_A = real_A.to(device)
real_B = real_B.to(device)

fake_B = netG_A2B(real_A)
fake_A = netG_B2A(real_B)
cycle_A = netG_B2A(fake_B)
cycle_B = netG_A2B(fake_A)

# 2
pred_fake_A = netD_A(fake_A)
pred_fake_B = netD_B(fake_B)

# 3
loss_cycle_A = criterion_cycle(cycle_A, real_A)
loss_cycle_B = criterion_cycle(cycle_B, real_B)
loss_GAN_A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))
loss_GAN_B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))

# 4
loss_G = lamb * (loss_cycle_A + loss_cycle_B) + loss_GAN_A + loss_GAN_B

# 5
if use_identity_loss:
    identity_A = netG_B2A(real_A)
    identity_B = netG_A2B(real_B)
    loss_identity_A = criterion_identity(identity_A, real_A)
    loss_identity_B = criterion_identity(identity_B, real_B)
    loss_G += 0.5 * lamb * (loss_identity_A + loss_identity_B)

# 6
optim_G.zero_grad()
loss_G.backward()
optim_G.step()
  1. 진짜 두 스타일의 이미지로부터 가짜 이미지, cycle 이미지를 생성합니다.

  2. 생성된 가짜 이미지를 판별자를 통과해 점수를 계산합니다.

  3. cycle과 GAN 손실을 계산합니다. 여기서 생성자가 학습할 때에는 판별자의 점수가 1에 가깝게 학습되어야 하는 점에서 ones_like 함수를 사용해 타겟 벡터를 생성합니다.

  4. λ\lambda로 가중치를 조절해 생성자 손실을 계산합니다.

  5. identity loss를 사용할 때 추가적으로 동작을 수행해 손실을 추가합니다.

  6. optimizer로 역전파 알고리즘을 수행합니다.

구별자 학습하기

구별자는 조금 더 간단합니다. 이 두 구별자의 경우 분리해서 학습할 수 있는 점을 고려해서, 먼저 netD_A를 학습하는 코드를 아래에 구현하였습니다. A와 B의 순서만 바꾸면 net_D_B를 학습하도록 동작합니다.

# 7
pred_real_A = netD_A(real_A)
pred_fake_A = netD_A(fake_A.detach())

# 8
loss_D_A = 0.5 * (
    criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
    + criterion_GAN(pred_fake_A, torch.zeros_like(pred_fake_A))
)

optim_D_A.zero_grad()
loss_D_A.backward()
optim_D_A.step()
  1. 진짜와 가짜 이미지를 구별자에 통과합니다. 가짜 이미지는 이미 생성자를 통과한 상태이기 때문에 내부 기울기 정보를 가질 수 있는데, 구별자에서는 이를 사용하지 않으므로 detach로 정보를 없애 이미지 정보만 입력으로 주게 됩니다.

  2. GAN 손실을 계산합니다. 여기서는 진짜 이미지에 대해서는 1, 생성된 가짜 이미지에 대해 0으로 추정하도록 학습하게 됩니다. 여기서 2를 나누어 주게 되는데, 이는 논문 중 학습 디테일 일부를 참고한 것입니다.

입력 데이터를 받은 상태에서 모델에 통과시키는 방식으로 학습하는 코드는 구현이 되었습니다. 추가적으로 논문에서 구현해야 하는 부분들은 다음과 같습니다.

  • 모델 가중치 초기화

  • 생성한 이미지들 history로부터 학습

  • 학습률 스케쥴링

이를 다음 포스트부터 하나씩 알아본 다음에, 전체적인 학습 파이프라인에 전체적으로 연결해 학습해보겠습니다.

0개의 댓글