Sequence to Sequence with Attention
Simple neural machine translation training.
References
- Sequence to Sequence Learning with Neural Networks
- Effective Approaches to Attention-based Neural Machine Translation
- Neural Machine Translation with Attention from TensorFlow
from __future__ import absolute_import, division, print_function
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.preprocessing.sequence import pad_sequences
from pprint import pprint
import numpy as np
import os
print(tf.__version__)
sources = [['I', 'feel', 'hungry'],
           ['tensorflow', 'is', 'very', 'difficult'],
           ['tensorflow', 'is', 'a', 'framework', 'for', 'deep', 'learning'],
           ['tensorflow', 'is', 'very', 'fast', 'changing']]
targets = [['나는', '배가', '고프다'],
           ['텐서플로우는', '매우', '어렵다'],
           ['텐서플로우는', '딥러닝을', '위한', '프레임워크이다'],
           ['텐서플로우는', '매우', '빠르게', '변화한다']]
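# targets are the Korean translations of the corresponding source sentences.
# Build word-level vocabularies: the source vocabulary reserves index 0 for <pad>,
# the target vocabulary reserves indices 0-2 for <pad>, <bos>, and <eos>.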
s_vocab = list(set(sum(sources, [])))
s_vocab.sort()
s_vocab = ['<pad>'] + s_vocab
source2idx = {word : idx for idx, word in enumerate(s_vocab)}
idx2source = {idx : word for idx, word in enumerate(s_vocab)}
pprint(source2idx)
t_vocab = list(set(sum(targets, [])))
t_vocab.sort()
t_vocab = ['<pad>', '<bos>', '<eos>'] + t_vocab
target2idx = {word : idx for idx, word in enumerate(t_vocab)}
idx2target = {idx : word for idx, word in enumerate(t_vocab)}
pprint(target2idx)
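# preprocess() maps token sequences to index sequences, records the original lengths,
# and pads/truncates them to max_len. For targets it also builds the decoder input
# (<bos> + sentence + <eos>) and the decoder target (sentence + <eos>) used for teacher forcing.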
def preprocess(sequences, max_len, dic, mode = 'source'):
    assert mode in ['source', 'target'], 'Please choose either "source" or "target".'

    if mode == 'source':
        # encoder input: token indices padded/truncated to max_len
        s_input = list(map(lambda sentence : [dic.get(token) for token in sentence], sequences))
        s_len = list(map(lambda sentence : len(sentence), s_input))
        s_input = pad_sequences(sequences = s_input, maxlen = max_len, padding = 'post', truncating = 'post')
        return s_len, s_input

    elif mode == 'target':
        # decoder input: <bos> + sentence + <eos>
        t_input = list(map(lambda sentence : ['<bos>'] + sentence + ['<eos>'], sequences))
        t_input = list(map(lambda sentence : [dic.get(token) for token in sentence], t_input))
        t_len = list(map(lambda sentence : len(sentence), t_input))
        t_input = pad_sequences(sequences = t_input, maxlen = max_len, padding = 'post', truncating = 'post')

        # decoder target: sentence + <eos>
        t_output = list(map(lambda sentence : sentence + ['<eos>'], sequences))
        t_output = list(map(lambda sentence : [dic.get(token) for token in sentence], t_output))
        t_output = pad_sequences(sequences = t_output, maxlen = max_len, padding = 'post', truncating = 'post')

        return t_len, t_input, t_output
s_max_len = 10
s_len, s_input = preprocess(sequences = sources,
                            max_len = s_max_len, dic = source2idx, mode = 'source')
print(s_len, s_input)
t_max_len = 12
t_len, t_input, t_output = preprocess(sequences = targets,
                                      max_len = t_max_len, dic = target2idx, mode = 'target')
print(t_len, t_input, t_output)
Hyper-parameters
epochs = 100
batch_size = 4
learning_rate = .005
buffer_size = 100
n_batch = len(sources) // batch_size    # batches per epoch: 4 sentences / batch of 4 = 1
total_step = epochs * n_batch           # total number of training steps
embedding_dim = 32
units = 128
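# Wrap the preprocessed arrays in a tf.data pipeline: shuffle the (tiny) dataset and batch it.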
data = tf.data.Dataset.from_tensor_slices((s_len, s_input, t_len, t_input, t_output))
data = data.shuffle(buffer_size = buffer_size)
data = data.batch(batch_size = batch_size)
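# gru(): GRU layer that returns both the full output sequence (needed for attention)
# and the final hidden state (used to carry the recurrent state between encoder and decoder).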
def gru(units):
    return tf.keras.layers.GRU(units,
                               return_sequences=True,
                               return_state=True,
                               recurrent_activation='sigmoid',
                               recurrent_initializer='glorot_uniform')
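# Encoder: embeds source token ids and runs them through a GRU, returning the
# per-timestep outputs (attended over by the decoder) and the final hidden state.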
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = gru(self.enc_units)

    def call(self, x, hidden):
        x = self.embedding(x)
        output, state = self.gru(x, initial_state = hidden)
        return output, state

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))
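# Decoder with Bahdanau-style additive attention. At every decoding step:
#   score   = V(tanh(W1(enc_output) + W2(hidden)))   -> shape (batch, src_len, 1)
#   weights = softmax(score, axis=1)                 -> attention over source positions
#   context = sum over src_len of weights*enc_output -> shape (batch, enc_units)
# The context vector is concatenated with the embedded input token, fed through the GRU,
# and a final Dense layer projects the GRU output to target-vocabulary logits.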
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = gru(self.dec_units)
        self.fc = tf.keras.layers.Dense(vocab_size)

        # layers used for the additive attention score
        self.W1 = tf.keras.layers.Dense(self.dec_units)
        self.W2 = tf.keras.layers.Dense(self.dec_units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, x, hidden, enc_output):
        # enc_output shape: (batch_size, src_len, hidden_size)
        # hidden shape: (batch_size, hidden_size) -> (batch_size, 1, hidden_size) to broadcast over time
        hidden_with_time_axis = tf.expand_dims(hidden, 1)

        # score shape: (batch_size, src_len, 1)
        score = self.V(tf.nn.tanh(self.W1(enc_output) + self.W2(hidden_with_time_axis)))

        # attention_weights shape: (batch_size, src_len, 1)
        attention_weights = tf.nn.softmax(score, axis=1)

        # context_vector shape after the sum: (batch_size, hidden_size)
        context_vector = attention_weights * enc_output
        context_vector = tf.reduce_sum(context_vector, axis=1)

        # x shape after concatenation: (batch_size, 1, embedding_dim + hidden_size)
        x = self.embedding(x)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # note: the previous decoder state enters only through the attention context;
        # the GRU itself is run on the single concatenated step
        output, state = self.gru(x)

        # output shape: (batch_size, hidden_size) -> logits over the target vocabulary
        output = tf.reshape(output, (-1, output.shape[2]))
        x = self.fc(output)

        return x, state, attention_weights

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.dec_units))
encoder = Encoder(len(source2idx), embedding_dim, units, batch_size)
decoder = Decoder(len(target2idx), embedding_dim, units, batch_size)
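# Masked cross-entropy loss: positions where the target is <pad> (index 0) do not contribute.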
def loss_function(real, pred):
    mask = tf.cast(tf.not_equal(real, 0), dtype=tf.float32)
    loss_ = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=pred) * mask
    return tf.reduce_mean(loss_)
optimizer = tf.keras.optimizers.Adam(learning_rate)
checkpoint_dir = './data_out/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
summary_writer = tf.summary.create_file_writer(logdir=checkpoint_dir)
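# Training loop with teacher forcing: at each target position the decoder receives the
# ground-truth previous token (t_input[:, t]) instead of its own prediction, and the
# per-step masked losses are accumulated and averaged over the target length.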
for epoch in range(epochs):
    hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for i, (s_len, s_input, t_len, t_input, t_output) in enumerate(data):
        loss = 0
        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(s_input, hidden)
            dec_hidden = enc_hidden
            dec_input = tf.expand_dims([target2idx['<bos>']] * batch_size, 1)

            for t in range(1, t_input.shape[1]):
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
                loss += loss_function(t_input[:, t], predictions)
                dec_input = tf.expand_dims(t_input[:, t], 1)

        batch_loss = (loss / int(t_input.shape[1]))
        total_loss += batch_loss

        variables = encoder.variables + decoder.variables
        gradient = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(gradient, variables))

    if epoch % 10 == 0:
        print('Epoch {} Loss {:.4f} Batch Loss {:.4f}'.format(epoch,
                                                              total_loss / n_batch,
                                                              batch_loss.numpy()))
        checkpoint.save(file_prefix = checkpoint_prefix)
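# evaluate(): greedy decoding for a single sentence. Starting from <bos>, the argmax
# prediction is fed back as the next decoder input until <eos> (or max_length_targ) is
# reached; the attention weights of every step are stored for plotting.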
def evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):
    attention_plot = np.zeros((max_length_targ, max_length_inp))

    inputs = [inp_lang[i] for i in sentence.split(' ')]
    inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs], maxlen=max_length_inp, padding='post')
    inputs = tf.convert_to_tensor(inputs)

    result = ''
    hidden = [tf.zeros((1, units))]
    enc_out, enc_hidden = encoder(inputs, hidden)

    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([targ_lang['<bos>']], 0)

    for t in range(max_length_targ):
        predictions, dec_hidden, attention_weights = decoder(dec_input, dec_hidden, enc_out)

        attention_weights = tf.reshape(attention_weights, (-1, ))
        attention_plot[t] = attention_weights.numpy()

        predicted_id = tf.argmax(predictions[0]).numpy()
        result += idx2target[predicted_id] + ' '

        if idx2target.get(predicted_id) == '<eos>':
            return result, sentence, attention_plot

        dec_input = tf.expand_dims([predicted_id], 0)

    return result, sentence, attention_plot
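# plot_attention(): draws the attention weights as a (target step x source token) heat map.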
def plot_attention(attention, sentence, predicted_sentence):
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='viridis')

    fontdict = {'fontsize': 14}

    ax.set_xticklabels([''] + sentence, fontdict=fontdict, rotation=90)
    ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict)

    plt.show()
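# translate(): end-to-end helper that encodes the raw sentence, decodes it greedily,
# prints the result, and plots the attention map cropped to the actual sequence lengths.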
def translate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ):
    result, sentence, attention_plot = evaluate(sentence, encoder, decoder, inp_lang, targ_lang, max_length_inp, max_length_targ)

    print('Input: {}'.format(sentence))
    print('Predicted translation: {}'.format(result))

    attention_plot = attention_plot[:len(result.split(' ')), :len(sentence.split(' '))]
    plot_attention(attention_plot, sentence.split(' '), result.split(' '))
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
sentence = 'I feel hungry'
translate(sentence, encoder, decoder, source2idx, target2idx, s_max_len, t_max_len)