GAN

seongyong·2021년 7월 7일
0

컴퓨터 비전

목록 보기
3/3

학습내용

GAN

특정 이미지를 모방하여 생성하는 generator, 생성된 이미지가 원본인지를 판별하는 discriminator로 구성되어있다. generator는 자신이 생성한 이미지를 discriminator가 원본으로 인식하도록 하기위해 이미지 생성에 대한 학습을 진행하고 discriminator는 원본인지를 더 잘 판별할 수 있도록 학습되는 구조이다.

훈련과정 동안 Generator는 점차 실제같은 이미지를 더 잘 생성하게 되고, Descriminator는 점차 진짜와 가짜를 더 잘 구별하게된다. 이 과정은 Descriminator가 가짜 이미지에서 진짜 이미지를 더이상 구별하지 못하게 될때, 평형상태에 도달하게 된다.

Deep Convolutional GAN(DCGAN)

generator

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # 주목: 배치사이즈로 None이 주어집니다.

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

학습되지않은 generator를 통해 이미지 생성

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')

discriminator

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

학습되지않은 discriminator를 이용해 생성된 이미지 판별

모델은 진짜 이미지에는 양수의 값 (positive values)을, 가짜 이미지에는 음수의 값 (negative values)을 출력

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)

Reference

https://sensibilityit.tistory.com/506

0개의 댓글