인코딩 : 특정 파일을 특정 형태로 변화하는 과정
인코더 : 잠재공간과 동일 , 특정 구조화된 값으로 매핑하는 함수
디코딩 : 잠재공간에 의해 변형된 파일을 특정 형태로 복원하는 과정
디코더 : 구조화된 값을 원래 형태로 복원하는 함수
입력을 저차원 잠재공간으로 인코딩한 후 디코딩하여 복원하는 네트워크
즉, 이미지를 입력받아 인코더 모듈을 사용하여 잠재 벡터 공간으로 매핑하고,
디코더 모듈을 사용하여 원본 이미지와 동일한 차원으로 복원하여 출력
원본 입력을 재구성하는 방법으로 학습 (원본 입력과 재구성된 자료와의 차이를 비교)
고전적인 방식은 구조화가 잘된 잠재 공간을 만들지 못하고,
압축도 뛰어나지 않음
## 라이브러리
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.losses import MeanSquaredError
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('seaborn-white')
# 데이터 가져오기
(x_train, _), (x_test, _) = fashion_mnist.load_data()
# 정규화 - normalize
x_train = x_train.astype('float32')/255.
x_test = x_test.astype('float32')/255.
print(x_train.shape)
print(x_test.shape)
latent_dim = 64
class Autoencoder(Model):
def __init__ (self,latent_dim):
super(Autoencoder, self).__init__()
self.latent_dim = latent_dim
# encode - compress
self.encoder = tf.keras.Sequential([Flatten(), Dense(latent_dim,activation='relu')])
#decode - decompress
self.decoder = tf.keras.Sequential([Dense(784,activation = 'relu'), Reshape((28,28))])
def call(self,x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
Flatten
으로 압축Dense
로 다시 784로 바꾸고 Resahpe
으로 형태 변환
autoencoder = Autoencoder(latent_dim)
autoencoder.compile(optimizer='adam',loss= MeanSquaredError())
autoencoder.fit(x_train,x_train,
epochs =10,
shuffle=True,
validation_data=(x_test,x_test))
# 넘파이 리스트 형태로 결과값 만들기
encoded_imgs = autoencoder.encoder(x_test).numpy()
decoded_imgs = autoencoder.decoder(encoded_imgs).numpy()
## 시각화
n = 10
plt.figure(figsize=(20,4))
for i in range(n):
ax= plt.subplot(2, n, i+1)
plt.imshow(x_test[i])
plt.title('original')
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax= plt.subplot(2, n, i+1+n)
plt.imshow(decoded_imgs[i])
plt.title('reconstructed')
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()