What is GAN?

  • ์ ๋Œ€์  ์ƒ์„ฑ ์‹ ๊ฒฝ๋ง(Generative Adversarial Networks)
  • ๋‘๊ฐœ์˜์‹ ๊ฒฝ๋ง๋ชจ๋ธ์ด์„œ๋กœ๊ฒฝ์Ÿ => ๋”๋‚˜์€๊ฒฐ๊ณผ๋ฅผ๋งŒ๋“ฆ
    • Generator(์ƒ์„ฑ์ž) : ๊ฑฐ์ง“ ๋ฐ์ดํ„ฐ ์ƒ์„ฑํ•˜๋Š” ๋ชจ๋ธ
    • Discriminator(๊ฐ๋ณ„์ž) : ์‹ค์ œ ๋ฐ์ดํ„ฐ์™€ ๊ฑฐ์ง“ ๋ฐ์ดํ„ฐ ๊ตฌ๋ถ„ํ•˜๋Š” ๋ชจ๋ธ
  • ๊ฒฝ์Ÿ์œผ๋กœ์ธํ•œ๋‘๋ชจ๋ธ์˜์„ฑ๋Šฅโ†‘
  • How to train?โžก๏ธ ๋ชจ๋ธ์ด ์ƒ์„ฑํ•œ ํ™•๋ฅ ๋ถ„ํฌ๊ฐ€ ์‹ค์ œ ๋ฐ์ดํ„ฐ์˜ ๋ถ„ํฌ์™€ ๋™์ผํ•ด์กŒ์Œ D(X) = 1/2 => ํ™•๋ฅ  50%(์ฐ๊ธฐ)
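As a toy check of the equilibrium claim above (a sketch assuming the standard per-sample minimax value log D(x) + log(1 - D(G(z))) from the original GAN paper):

```python
import math

def gan_value(d_real, d_fake):
    """Per-sample GAN value: log D(x) + log(1 - D(G(z)))."""
    return math.log(d_real) + math.log(1 - d_fake)

# At equilibrium the discriminator can only guess: D(.) = 1/2 everywhere,
# and the value function bottoms out at -log 4.
v = gan_value(0.5, 0.5)
print(v)  # ≈ -1.3863 = -log 4
```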

What is DCGAN?

  • Deep Convolutional Generative Adversarial Network
  • Applies convolutional networks to the original GAN

Generator

  • No pooling layers
    • Unpooling tends to produce blocky images
    • Transposed convolution ("deconvolution") is used instead
  • Uses a stride of 2 or more to grow the feature maps
  • Uses batch normalization to stabilize training
  • Uses tanh at the final output -> to scale the output values to [-1, 1]
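The upsampling path above can be sanity-checked with the transposed-convolution size formula H_out = (H_in - 1)·stride - 2·padding + kernel (assuming dilation = 1 and output_padding = 0, as in the generator code later in this post):

```python
def convtranspose2d_out(h_in, kernel, stride, padding):
    """Output size of nn.ConvTranspose2d (dilation=1, output_padding=0)."""
    return (h_in - 1) * stride - 2 * padding + kernel

# The generator's five layers as (kernel, stride, padding) triples.
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
h = 1  # the latent vector z enters as a 1x1 "image"
sizes = []
for k, s, p in layers:
    h = convtranspose2d_out(h, k, s, p)
    sizes.append(h)
print(sizes)  # [4, 8, 16, 32, 64]
```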

Discriminator

  • Takes a 64x64x3 image as input
  • Processes the data through Conv2d, BatchNorm2d, and LeakyReLU layers
  • Uses a sigmoid at the end to squash the output into a probability between 0 and 1
  • Why stride 2? Strided convolutions let the network learn its own pooling function, which the DCGAN paper reports works better than inserting explicit pooling layers (MaxPool, AvgPool) into the pipeline.
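The mirror-image downsampling path can be checked with the standard convolution size formula H_out = ⌊(H_in + 2·padding - kernel) / stride⌋ + 1 (dilation = 1), matching the discriminator code below:

```python
def conv2d_out(h_in, kernel, stride, padding):
    """Output size of nn.Conv2d (dilation=1)."""
    return (h_in + 2 * padding - kernel) // stride + 1

# The discriminator's five layers as (kernel, stride, padding) triples.
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]
h = 64  # a (nc) x 64 x 64 input image
sizes = []
for k, s, p in layers:
    h = conv2d_out(h, k, s, p)
    sizes.append(h)
print(sizes)  # [32, 16, 8, 4, 1]
```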

Code

Data

# Load the image dataset with the settings we defined
# First, create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Define the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# ํ•™์Šต ๋ฐ์ดํ„ฐ๋“ค ์ค‘ ๋ช‡๊ฐ€์ง€ ์ด๋ฏธ์ง€๋“ค์„ ํ™”๋ฉด์— ๋„์›Œ๋ด…์‹œ๋‹ค
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

Weight Initialization

# Custom weight initialization function applied to ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

ํ‰๊ท ์ด 0์ด๊ณ  ๋ถ„์‚ฐ์ด 0.02์ธ ์ •๊ทœ๋ถ„ํฌ๋ฅผ ์‚ฌ์šฉํ•ด์„œ G์™€ D๋ฅผ ๋ชจ๋‘ ๋ฌด์ž‘์œ„ ์ดˆ๊ธฐํ™”๋ฅผ ์ง„ํ–‰ํ•˜๋Š” ๊ฒƒ์ด ์ข‹๋‹ค

The function above takes a model as its argument and reinitializes all of its weights. It is called on each model right after the model is created.

Generator

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # ์ž…๋ ฅ๋ฐ์ดํ„ฐ Z๊ฐ€ ๊ฐ€์žฅ ์ฒ˜์Œ ํ†ต๊ณผํ•˜๋Š” ์ „์น˜ ํ•ฉ์„ฑ๊ณฑ ๊ณ„์ธต์ž…๋‹ˆ๋‹ค.
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)

nz: length of the input vector z
ngf: size of the feature maps carried through the generator
nc: number of channels in the output image (nc = 3 in this code, since the images are RGB)

# ์ƒ์„ฑ์ž๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค
netG = Generator(ngpu).to(device)

# Set up multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# ๋ชจ๋“  ๊ฐ€์ค‘์น˜์˜ ํ‰๊ท ์„ 0( ``mean=0`` ), ๋ถ„์‚ฐ์„ 0.02( ``stdev=0.02`` )๋กœ ์ดˆ๊ธฐํ™”ํ•˜๊ธฐ ์œ„ํ•ด
# ``weight_init`` ํ•จ์ˆ˜๋ฅผ ์ ์šฉ์‹œํ‚ต๋‹ˆ๋‹ค
netG.apply(weights_init)

# ๋ชจ๋ธ์˜ ๊ตฌ์กฐ๋ฅผ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค
print(netG)

์ƒ์„ฑ์ž ๋ชจ๋ธ์˜ ์ธ์Šคํ„ด์Šค๋ฅผ ๋งŒ๋“ค์–ด weights_init ํ•จ์ˆ˜๋ฅผ ์ ์šฉ

Discriminator

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ๋Š” ``(nc) x 64 x 64`` ์ž…๋‹ˆ๋‹ค
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # ์œ„์˜ ๊ณ„์ธต์„ ํ†ต๊ณผํ•œ ๋ฐ์ดํ„ฐ์˜ ํฌ๊ธฐ. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

If needed, the Discriminator could stack a wider variety of layers, but here it uses only batch normalization, LeakyReLU, and strided convolution layers.

Why? => The DCGAN paper found that strided convolution layers let the network learn its own pooling function, which works better than inserting explicit pooling layers (MaxPool, AvgPool).

# ๊ตฌ๋ถ„์ž๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค
netD = Discriminator(ngpu).to(device)

# Set up multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# ๋ชจ๋“  ๊ฐ€์ค‘์น˜์˜ ํ‰๊ท ์„ 0( ``mean=0`` ), ๋ถ„์‚ฐ์„ 0.02( ``stdev=0.02`` )๋กœ ์ดˆ๊ธฐํ™”ํ•˜๊ธฐ ์œ„ํ•ด
# ``weight_init`` ํ•จ์ˆ˜๋ฅผ ์ ์šฉ์‹œํ‚ต๋‹ˆ๋‹ค
netD.apply(weights_init)

# ๋ชจ๋ธ์˜ ๊ตฌ์กฐ๋ฅผ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค
print(netD)

Defining the Loss Function and Optimizer

# Initialize an instance of the ``BCELoss`` function
criterion = nn.BCELoss()

# ์ƒ์„ฑ์ž์˜ ํ•™์Šต์ƒํƒœ๋ฅผ ํ™•์ธํ•  ์ž ์žฌ ๊ณต๊ฐ„ ๋ฒกํ„ฐ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# ํ•™์Šต์— ์‚ฌ์šฉ๋˜๋Š” ์ฐธ/๊ฑฐ์ง“์˜ ๋ผ๋ฒจ์„ ์ •ํ•ฉ๋‹ˆ๋‹ค
real_label = 1.
fake_label = 0.

# G์™€ D์—์„œ ์‚ฌ์šฉํ•  Adam์˜ตํ‹ฐ๋งˆ์ด์ €๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

Pytorch์—์„œ๋Š” ์†์‹คํ•จ์ˆ˜๋ฅผ Binary Cross Entropy Loss (BCELoss) ์‚ฌ์šฉ ์†์‹คํ•จ์ˆ˜๋Š” ์œ„์—์„œ ์„ค๋ช…ํ•œ ํ•จ์ˆ˜์™€ ๋น„์Šท

  • Discriminator ์†์‹คํ•จ์ˆ˜
    • ( 1 ) y๊ฐ€ 1์ผ ๊ฒฝ์šฐ(real data์ผ ๊ฒฝ์šฐ), Xn์— ์‹ค์ œ ๋ฐ์ดํ„ฐ๋ฅผ ๋„ฃ์–ด์ค€๋‹ค. (y = 1, Xn = D(X))
    • ( 2 ) y๊ฐ€ 0์ผ ๊ฒฝ์šฐ(fake data์ผ ๊ฒฝ์šฐ), Xn์— ์ƒˆ๋กœ ์ƒ์„ฑํ•œ ๋ฐ์ดํ„ฐ๋ฅผ ๋„ฃ์–ด์ค€๋‹ค. (y= 0, Xn = D(G(Z)))
  • ์œ„ ( 1 ), ( 2 )๋ฅผ ๊ฐ๊ฐ ln์— ๋„ฃ์–ด์ฃผ๋ฉด,
    ( 1 ) -logD(X)
    ( 2 ) โ€“log(1-D(G(Z)))
    ์ด ๋‘˜์„ ๋”ํ•œ โ€“logD(X)โ€“log(1-D(G(Z))) ๋ฅผ ์ตœ์†Œํ™” ์‹œ์ผœ์•ผํ•จ โ†”๏ธ logD(X)+log(1-D(G(Z))) ๋ฅผ ์ตœ๋Œ€ํ™” ์‹œํ‚ค๋ฉด ๋จ!!
  • Generator ์†์‹ค ํ•จ์ˆ˜
    • Original GAN: ์ƒ์„ฑ์ž๋Š” log(1-D(G(Z)))๋ฅผ ์ตœ์†Œํ™” ์‹œํ‚ค๋Š” ๋ฐฉํ–ฅ์œผ๋กœ ํ•™์Šต์‹œํ‚ด
    • But, ์ถฉ๋ถ„ํ•œ ๋ณ€ํ™” x, ํ•™์Šต์ดˆ๊ธฐ์— ๋ฌธ์ œ๊ฐ€ ์ƒ๊ฒจ ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด log(D(G(Z)))๋ฅผ ์ตœ๋Œ€ํ™” ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ํ•™์Šตํ•œ๋‹ค.
    • log(D(G(Z)))๋ฅผ ์ตœ๋Œ€ํ™”ํ•˜๋Š” ๊ฒƒ์€ D(G(Z)) ๊ฐ’์„ ์ตœ๋Œ€๋กœ ํ•˜๋Š” ๊ฒƒ๊ณผ ๋™์ผํ•œ ์˜๋ฏธ์ด๋ฉฐ, ์ด๋Š” Discriminator๊ฐ€ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๋ฅผ "์ง„์งœ๋กœ ํŒ๋ณ„ํ•˜๋„ ๋ก ์†์ด๋Š”" ๊ฒƒ์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
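The two losses can be written out numerically (a sketch with plain Python standing in for nn.BCELoss; d_real and d_fake are made-up discriminator outputs):

```python
import math

def bce(x, y):
    """Binary cross entropy for one sample: -[y*log(x) + (1-y)*log(1-x)]."""
    return -(y * math.log(x) + (1 - y) * math.log(1 - x))

d_real = 0.9   # D(x): the discriminator's score on a real image (made up)
d_fake = 0.2   # D(G(z)): its score on a generated image (made up)

# Discriminator: -log D(x) - log(1 - D(G(z))) = BCE with labels 1 and 0.
loss_d = bce(d_real, 1.0) + bce(d_fake, 0.0)

# Generator (non-saturating form): -log D(G(z)) = BCE with the *real* label.
loss_g = bce(d_fake, 1.0)

print(loss_d, loss_g)
```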

Training

# ํ•™์Šต์ƒํƒœ๋ฅผ ์ฒดํฌํ•˜๊ธฐ ์œ„ํ•ด ์†์‹ค๊ฐ’๋“ค์„ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # ํ•œ ์—ํญ ๋‚ด์—์„œ ๋ฐฐ์น˜ ๋ฐ˜๋ณต
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) D ์‹ ๊ฒฝ๋ง์„ ์—…๋ฐ์ดํŠธ ํ•ฉ๋‹ˆ๋‹ค: log(D(x)) + log(1 - D(G(z)))๋ฅผ ์ตœ๋Œ€ํ™” ํ•ฉ๋‹ˆ๋‹ค
        ###########################
        ## ์ง„์งœ ๋ฐ์ดํ„ฐ๋“ค๋กœ ํ•™์Šต์„ ํ•ฉ๋‹ˆ๋‹ค
        netD.zero_grad()
        # ๋ฐฐ์น˜๋“ค์˜ ์‚ฌ์ด์ฆˆ๋‚˜ ์‚ฌ์šฉํ•  ๋””๋ฐ”์ด์Šค์— ๋งž๊ฒŒ ์กฐ์ •ํ•ฉ๋‹ˆ๋‹ค
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label,
                           dtype=torch.float, device=device)
        # ์ง„์งœ ๋ฐ์ดํ„ฐ๋“ค๋กœ ์ด๋ฃจ์–ด์ง„ ๋ฐฐ์น˜๋ฅผ D์— ํ†ต๊ณผ์‹œํ‚ต๋‹ˆ๋‹ค
        output = netD(real_cpu).view(-1)
        # ์†์‹ค๊ฐ’์„ ๊ตฌํ•ฉ๋‹ˆ๋‹ค
        errD_real = criterion(output, label)
        # ์—ญ์ „ํŒŒ์˜ ๊ณผ์ •์—์„œ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
        errD_real.backward()
        D_x = output.mean().item()

        ## ๊ฐ€์งœ ๋ฐ์ดํ„ฐ๋“ค๋กœ ํ•™์Šต์„ ํ•ฉ๋‹ˆ๋‹ค
        # ์ƒ์„ฑ์ž์— ์‚ฌ์šฉํ•  ์ž ์žฌ๊ณต๊ฐ„ ๋ฒกํ„ฐ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # G๋ฅผ ์ด์šฉํ•ด ๊ฐ€์งœ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค
        fake = netG(noise)
        label.fill_(fake_label)
        # D๋ฅผ ์ด์šฉํ•ด ๋ฐ์ดํ„ฐ์˜ ์ง„์œ„๋ฅผ ํŒ๋ณ„ํ•ฉ๋‹ˆ๋‹ค
        output = netD(fake.detach()).view(-1)
        # D์˜ ์†์‹ค๊ฐ’์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
        errD_fake = criterion(output, label)
        # Calculate the gradients in the backward pass, accumulated with the earlier gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # ๊ฐ€์งœ ์ด๋ฏธ์ง€์™€ ์ง„์งœ ์ด๋ฏธ์ง€ ๋ชจ๋‘์—์„œ ๊ตฌํ•œ ์†์‹ค๊ฐ’๋“ค์„ ๋”ํ•ฉ๋‹ˆ๋‹ค
        # ์ด๋•Œ errD๋Š” ์—ญ์ „ํŒŒ์—์„œ ์‚ฌ์šฉ๋˜์ง€ ์•Š๊ณ , ์ดํ›„ ํ•™์Šต ์ƒํƒœ๋ฅผ ๋ฆฌํฌํŒ…(reporting)ํ•  ๋•Œ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค
        errD = errD_real + errD_fake
        # D๋ฅผ ์—…๋ฐ์ดํŠธ ํ•ฉ๋‹ˆ๋‹ค
        optimizerD.step()

        ############################
        # (2) G ์‹ ๊ฒฝ๋ง์„ ์—…๋ฐ์ดํŠธ ํ•ฉ๋‹ˆ๋‹ค: log(D(G(z)))๋ฅผ ์ตœ๋Œ€ํ™” ํ•ฉ๋‹ˆ๋‹ค
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # use the real labels to compute the generator's loss
        # Since we just updated D, pass the fake batch through D again.
        # G has not been updated, but because D has, the loss comes out different than before
        output = netD(fake).view(-1)
        # G์˜ ์†์‹ค๊ฐ’์„ ๊ตฌํ•ฉ๋‹ˆ๋‹ค
        errG = criterion(output, label)
        # G์˜ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค
        errG.backward()
        D_G_z2 = output.mean().item()
        # G๋ฅผ ์—…๋ฐ์ดํŠธ ํ•ฉ๋‹ˆ๋‹ค
        optimizerG.step()

        # ํ›ˆ๋ จ ์ƒํƒœ๋ฅผ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save the losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Save G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

Part 1) Training the Discriminator

1) ์ง„์งœ ๋ฐ์ดํ„ฐ๋“ค๋กœ๋งŒ ์ด๋ฃจ์–ด์ง„ ๋ฐฐ์น˜๋ฅผ ๋งŒ๋“ค์–ด D์— ํ†ต ๊ณผ์‹œํ‚จ๋‹ค.

  • ์ถœ๋ ฅ๊ฐ’์œผ๋กœ log(D(x))์˜ ์†์‹ค๊ฐ’์„ ๊ณ„์‚ฐ
  • backpropagation ๊ณ ์ •์—์„œ์˜ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐ

2) ๊ฐ€์งœ ๋ฐ์ดํ„ฐ๋“ค๋กœ๋งŒ ์ด๋ฃจ์–ด์ง„ ๋ฐฐ์น˜๋ฅผ ๋งŒ๋“ค์–ด D์— ํ†ต ๊ณผ์‹œํ‚จ๋‹ค.

  • ๊ทธ ์ถœ๋ ฅ๊ฐ’์œผ๋กœ log(1-D(G(z)))์˜ ์†์‹ค๊ฐ’ ๊ณ„์‚ฐ
  • backprobagation ๋ณ€ํ™”๋„ ๊ณ„์‚ฐ

์ด๋•Œ, ๋‘ ๊ฐ€์ง€ ์Šคํ…์—์„œ ๋‚˜์˜ค๋Š” ๋ณ€ํ™”๋„๋“ค์„ ์ถ•์ ์‹œ์ผœ์•ผํ•œ๋‹ค.

  • backpropagation ๊ณ„์‚ฐํ–ˆ์œผ๋‹ˆ, ์˜ตํ‹ฐ๋งˆ์ด์ € ์‚ฌ์šฉํ•ด์„œ backpropagation ์ ์šฉ

Part 2) Training the Generator
Use D to classify G's output, and compute G's loss against the real labels.
(That is, feed the data generated by G into the discriminator and compute the loss of the resulting scores against the real labels.)

Compute the gradients from that loss, and finally update G's weights with the optimizer.

Result

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

Plotting the loss values as a graph

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

Real vs. Fake

Source: PyTorch Tutorials https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
