# 라이브러리 임포트 및 train, test 데이터 분류
from tensorflow.keras.datasets import imdb
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.sequence import pad_sequences
(train_input, train_target), (test_input, test_target) = imdb.load_data(
num_words=500)
train_input, val_input, train_target, val_target = train_test_split(
train_input, train_target, test_size=0.2, random_state=42)
train_seq = pad_sequences(train_input, maxlen=100)
val_seq = pad_sequences(val_input, maxlen=100)
# sequential 모델 객체 생성
from tensorflow import keras
model = keras.Sequential()
model.add(keras.layers.Embedding(500, 16, input_length=100))
model.add(keras.layers.LSTM(8))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.summary()
rmsprop = keras.optimizers.RMSprop(learning_rate=1e-4)
model.compile(optimizer=rmsprop, loss='binary_crossentropy',
metrics=['accuracy'])
checkpoint_cb = keras.callbacks.ModelCheckpoint('best-lstm-model.h5')
early_stopping_cb = keras.callbacks.EarlyStopping(patience=5,
restore_best_weights=True)
history = model.fit(train_seq, train_target, epochs=200, batch_size=128,
validation_data=(val_seq, val_target),
callbacks=[checkpoint_cb, early_stopping_cb])
import matplotlib.pyplot as plt
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'val'])
plt.show()
LSTM 사용이유 요약
- RNN은 시간이 길어지면 은닉층에서 데이터를 까먹기 때문에 LSTM을 사용
- 하지만 계산식이 복잡하다는 단점 보유
model4 = keras.Sequential()
model4.add(keras.layers.Embedding(500, 16, input_length=100))
model4.add(keras.layers.GRU(8))
model4.add(keras.layers.Dense(1, activation='sigmoid'))
model4.summary()
rmsprop = keras.optimizers.RMSprop(learning_rate=1e-4)
model4.compile(optimizer=rmsprop, loss='binary_crossentropy',
metrics=['accuracy'])
checkpoint_cb = keras.callbacks.ModelCheckpoint('best-gru-model.h5')
early_stopping_cb = keras.callbacks.EarlyStopping(patience=5,
restore_best_weights=True)
history = model4.fit(train_seq, train_target, epochs=200, batch_size=512,
validation_data=(val_seq, val_target),
callbacks=[checkpoint_cb, early_stopping_cb])
LSTM & GRU 차이점
구조
- LSTM : 입력게이트, 망각게이트, 출력게이트 등의 메커니즘을 사용하여 순차적 입력, 이전 상태 정보 조회 -> 장기 기억 상태로 저장
- GRU : 입력게이트와 망각게이트를 업데이트 게이트로 통합, 훈련속도가 빠르고 더 간단한 구조
게이트 수
- LSTM : 총 3개 (입력, 망각, 출력)
- GRU : 총 2개 (업데이트, 출력)
정보의 흐름
- LSTM : 입력 게이트를 통해 어떤 정보를 기억할지 결정 후 망각 게이트를 통해 이전 상태의 어떤 정보를 잊을지 결정, 출력 게이트는 현재 상태의 어떤 정보를 다음 상태의 전달할 지 결정
- GRU : 업데이트 게이트를 통해 현재 입력과 이전 상태의 정보를 조합하여 현재 상태를 업데이트
참고자료
https://blog.naver.com/winddori2002/221992543837
https://wooono.tistory.com/242