Keras SimpleRNN

yousmile·2021년 8월 2일

자연어처리

목록 보기
3/6

1. 임의의 입력 생성

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import SimpleRNN, LSTM, Bidirectional

#단어 벡터의 차원은 5이고, 문장의 길이(timesteps)가 4인 경우
train_X = [[0.1, 4.2, 1.5, 1.1, 2.8], [1.0, 3.1, 2.5, 0.7, 1.1], [0.3, 2.1, 1.5, 2.1, 0.1], [2.2, 1.4, 0.5, 0.9, 1.1]]
print(np.shape(train_X))
(4, 5)

RNN은 3D 텐서를 입력 받는다고 앞서 말함.
따라서 위의 2D 텐서를 batch_size 1을 추가해 3D 텐서로 변경

train_X = [[[0.1, 4.2, 1.5, 1.1, 2.8], [1.0, 3.1, 2.5, 0.7, 1.1], [0.3, 2.1, 1.5, 2.1, 0.1], [2.2, 1.4, 0.5, 0.9, 1.1]]]
train_X = np.array(train_X, dtype=np.float32)
print(train_X.shape)
(1, 4, 5)

2. SimpleRNN 적용

대표적인 인자

  • return_sequences : False인 경우 마지막 시점 은닉 상태만 출력. True인 경우 모든 시점의 은닉 상태 출력
  • return_state : True인 경우 return_sequences 여부와 상관없이 마지막 시점 은닉상태를 출력

1. 둘 다 False인 경우: 마지막 시점 은닉 상태 출력

rnn = SimpleRNN(3)
hidden_state = rnn(train_X)
print('hidden state : {}, shape: {}'.format(hidden_state, hidden_state.shape)) 
hidden state : [[-0.9144172  -0.99990714 -0.5751257 ]], shape: (1, 3)

2. return_sequences만 True인 경우: 모든 시점의 은닉 상태 출력

rnn = SimpleRNN(3, return_sequences=True) 
hidden_states = rnn(train_X)
hidden states : [[[-0.5834232  -0.89815223  0.9967079 ]
  [ 0.7584444  -0.97730017  0.9737648 ]
  [ 0.5499788  -0.95319057 -0.9243298 ]
  [ 0.69824535 -0.7607714   0.8957305 ]]], shape: (1, 4, 3)

3. return_state만 True인 경우: 두 개의 출력 리턴(두 다 마지막 시점 은닉 상태 출력)

rnn = SimpleRNN(3, return_sequences=False, return_state=True)
hidden_state, last_state = rnn(train_X)

print('hidden state : {}, shape: {}'.format(hidden_state, hidden_state.shape))
print('last hidden state : {}, shape: {}'.format(last_state, last_state.shape))
hidden state : [[-0.26085535 -0.8020567  -0.78999335]], shape: (1, 3)
last hidden state : [[-0.26085535 -0.8020567  -0.78999335]], shape: (1, 3)

4. 둘 다 True인 경우: 두 개의 출력 리턴(모든 시점, 마지막 시점 은닉 상태 출력)

rnn = SimpleRNN(3, return_sequences=True, return_state=True)
hidden_states, last_state = rnn(train_X)

print('hidden states : {}, shape: {}'.format(hidden_states, hidden_states.shape))
print('last hidden state : {}, shape: {}'.format(last_state, last_state.shape))
hidden states : [[[ 0.640092    0.9975969   0.74564785]
  [ 0.5521973   0.99620426  0.17306131]
  [ 0.43504977  0.5067752  -0.5658185 ]
  [ 0.9774429   0.93501943  0.08292837]]], shape: (1, 4, 3)
last hidden state : [[0.9774429  0.93501943 0.08292837]], shape: (1, 3)```

0개의 댓글