This post is written in a hierarchical structure so that its overall context is easy to follow.
All sources cited (referenced) via CSF (Curation Service for Facilitation) are omitted.
import numpy as np
import tensorflow as tf
import glob
txt_file_path = './lyricist/data/lyrics/*'
# set the path to the data
txt_list = glob.glob(txt_file_path)
# https://wikidocs.net/3746
# collect every file found under that path into txt_list
raw_corpus = []
# read every txt file and collect its lines into raw_corpus
for txt_file in txt_list:
    with open(txt_file, "r") as f:
        raw = f.read().splitlines()
        raw_corpus.extend(raw)
print("Data size:", len(raw_corpus))
print("Examples:\n", raw_corpus[:3])
Data size: 187088
Examples:
['The first words that come out', 'And I can see this song will be about you', "I can't believe that I can breathe without you"]
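Note that glob with a single '*' only matches the entries directly inside the lyrics folder. If the txt files were nested in subfolders, a recursive pattern could be used instead (a sketch, assuming the same base path):
txt_list = glob.glob('./lyricist/data/lyrics/**/*.txt', recursive=True)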
# The input sentence is processed as follows:
# 1. lowercase it and strip leading/trailing whitespace
# 2. put a space on both sides of special characters
# 3. collapse multiple spaces into a single space
# 4. replace every character that is not a-zA-Z?.!,¿ with a single space
# 5. strip whitespace again
# 6. add <start> at the beginning and <end> at the end of the sentence
# Processing in this order avoids most problematic cases!
import re

def preprocess_sentence(sentence):
    sentence = sentence.lower().strip()                    # 1
    sentence = re.sub(r"([?.!,¿])", r" \1 ", sentence)     # 2
    sentence = re.sub(r'[" "]+', " ", sentence)            # 3
    sentence = re.sub(r"[^a-zA-Z?.!,¿]+", " ", sentence)   # 4
    sentence = sentence.strip()                            # 5
    sentence = '<start> ' + sentence + ' <end>'            # 6
    return sentence
# check how this example sentence gets filtered
print(preprocess_sentence("This @_is ;;;sample 23423 sentence."))
<start> this is sample sentence . <end>
#############################################
# additionally, remove overly long sentences, since they force excessive padding onto the rest of the data
corpus = []  # holds the cleaned sentences

for sentence in raw_corpus:
    if len(sentence) == 0: continue
    if len(sentence) > 100: continue   # drop overly long sentences
    if sentence[-1] == ":": continue
    # clean the sentence
    preprocessed_sentence = preprocess_sentence(sentence)
    # skip sentences with more than 15 tokens
    if len(preprocessed_sentence.split()) > 15: continue
    # collect it
    corpus.append(preprocessed_sentence)
# check the cleaned result
corpus[:1]
['<start> the first words that come out <end>']
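For reference, comparing the raw and cleaned corpus sizes shows how many lines the filtering above removed (both counts appear in the outputs printed elsewhere in this post):
print(len(raw_corpus), len(corpus))  # 187088 156013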
def tokenize(corpus):
    tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=12000,
        filters=' ',
        oov_token="<unk>"
    )
    # build the tokenizer's vocabulary from the corpus
    tokenizer.fit_on_texts(corpus)
    # convert the corpus into sequences of word indices with the fitted tokenizer
    tensor = tokenizer.texts_to_sequences(corpus)
    # pad every sequence to the same length
    # if a sequence is shorter, padding is appended after the sentence
    # use padding='pre' instead to pad in front of the sentence
    tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor, padding='post', maxlen=15)
    print(tensor, tokenizer)
    return tensor, tokenizer
tensor, tokenizer = tokenize(corpus)
[[ 2 6 248 ... 0 0 0]
[ 2 8 4 ... 0 0 0]
[ 2 4 35 ... 0 0 0]
...
[ 2 124 112 ... 0 0 0]
[ 2 124 112 ... 0 0 0]
[ 2 124 112 ... 0 0 0]] <keras_preprocessing.text.Tokenizer object at 0x7fca320329d0>
for idx, sentence in enumerate(raw_corpus):
    if len(sentence) == 0: continue     # skip sentences of length 0
    if sentence[-1] == ":": continue    # skip sentences that end with ":"
    if idx > 9: break                   # only look at the first 10 sentences for now
    print(sentence)
The first words that come out
And I can see this song will be about you
I can't believe that I can breathe without you
But all I need to do is carry on
The next line I write down
And there's a tear that falls between the pages
I know that pain's supposed to heal in stages
But it depends which one I'm standing on I write lines down, then rip them up
Describing love can't be this tough I could set this song on fire, send it up in smoke
I could throw it in the river and watch it sink in slowly
# print the first rows of word indices produced with the tokenizer's vocabulary
print(tensor[:5, :])
[[ 2 6 248 436 15 68 57 3 0 0 0 0 0 0 0]
[ 2 8 4 35 63 41 357 84 27 111 7 3 0 0 0]
[ 2 4 35 16 218 15 4 35 767 257 7 3 0 0 0]
[ 2 33 25 4 92 10 48 26 829 18 3 0 0 0 0]
[ 2 6 331 441 4 759 58 3 0 0 0 0 0 0 0]]
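To sanity-check the round trip, a row of the tensor can be decoded back into words with tokenizer.index_word (a small sketch; index 0 is padding and is skipped because it has no dictionary entry):
decoded = ' '.join(tokenizer.index_word[int(idx)] for idx in tensor[0] if idx != 0)
print(decoded)  # should match corpus[0]: '<start> the first words that come out <end>'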
# print the sizes of the cleaned tensor and corpus
print(len(tensor), len(corpus))
156013 156013
# check how the vocabulary was built
for idx in tokenizer.index_word:
    print(idx, ":", tokenizer.index_word[idx])
    if idx >= 10: break
1 : <unk>
2 : <start>
3 : <end>
4 : i
5 : ,
6 : the
7 : you
8 : and
9 : a
10 : to
# cut off the last token of each row in the tensor to build the source sentences
# the last token is far more likely to be <pad> than <end>
src_input = tensor[:, :-1]
# cut off <start> to build the target sentences
tgt_input = tensor[:, 1:]
print(src_input[0])
print(tgt_input[0])
[ 2 6 248 436 15 68 57 3 0 0 0 0 0 0]
[ 6 248 436 15 68 57 3 0 0 0 0 0 0 0]
# split off an evaluation dataset
# 20% for evaluation
from sklearn.model_selection import train_test_split
enc_train, enc_val, dec_train, dec_val = train_test_split(src_input,
                                                           tgt_input,
                                                           test_size=0.2,
                                                           random_state=42)
print("Source Train:", enc_train.shape)
print("Target Train:", dec_train.shape)
Source Train: (124810, 14)
Target Train: (124810, 14)
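The remaining 20% is the validation split; its shape can be checked the same way (expected size derived from 156,013 − 124,810 = 31,203):
print("Source Val:", enc_val.shape)  # expected (31203, 14)
print("Target Val:", dec_val.shape)  # expected (31203, 14)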
BUFFER_SIZE = len(src_input)
BATCH_SIZE = 256
steps_per_epoch = len(src_input) // BATCH_SIZE
# the 12,000 words in the vocabulary built by the tokenizer, plus 0:<pad> which is not among them, for 12,001 in total
VOCAB_SIZE = tokenizer.num_words + 1
# build a tf.data.Dataset from the data we prepared
# for details on Dataset, see the documentation below
# it is an important document that repays careful study
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset
# train dataset
train_dataset = tf.data.Dataset.from_tensor_slices((enc_train, dec_train))
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True)
print(train_dataset)
# test (validation) dataset
test_dataset = tf.data.Dataset.from_tensor_slices((enc_val, dec_val))
test_dataset = test_dataset.shuffle(BUFFER_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE, drop_remainder=True)
print(test_dataset)
<BatchDataset shapes: ((256, 14), (256, 14)), types: (tf.int32, tf.int32)>
<BatchDataset shapes: ((256, 14), (256, 14)), types: (tf.int32, tf.int32)>
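The 487 steps per epoch that appear in the training log further down follow directly from the size of the train split and the batch size (a quick check):
print(len(enc_train) // BATCH_SIZE)  # 124810 // 256 = 487 batches per epoch with drop_remainder=True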
# build the model (the lyricist AI)
class TextGenerator(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, hidden_size):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.rnn_1 = tf.keras.layers.LSTM(hidden_size, return_sequences=True)
        self.rnn_2 = tf.keras.layers.LSTM(hidden_size, return_sequences=True)
        self.linear = tf.keras.layers.Dense(vocab_size)

    def call(self, x):
        out = self.embedding(x)
        out = self.rnn_1(out)
        out = self.rnn_2(out)
        out = self.linear(out)
        return out
embedding_size = 256   # larger values can capture more abstract features of words, but need more data
hidden_size = 1024     # how much capacity the model gets; without enough data, extra capacity just leads it astray
lyricist = TextGenerator(tokenizer.num_words + 1, embedding_size, hidden_size)
# this is how to pull just one batch out of the dataset
# don't get too caught up in how it works for now~
for src_sample, tgt_sample in train_dataset.take(1): break

# feed the single batch we pulled into the model
lyricist(src_sample)
<tf.Tensor: shape=(256, 14, 12001), dtype=float32, numpy=
array([[[-1.44549223e-04, -1.15473573e-04, -6.30542418e-05, ...,
-4.81275529e-05, 3.25414061e-04, 2.03498232e-04],
[-1.86848993e-04, -2.30316669e-04, -2.25731535e-04, ...,
-6.46176923e-05, 5.04140393e-04, 3.21763946e-04],
[-1.03205770e-04, -4.84468270e-04, -1.17466385e-04, ...,
-4.00223624e-04, 6.36043027e-04, 5.16410160e-04],
...,
[-1.38640136e-03, -4.47364670e-04, -2.00071765e-04, ...,
-2.70744367e-05, 5.63458947e-04, 1.36734487e-03],
[-1.38968823e-03, -2.18201923e-04, -5.19401219e-05, ...,
6.40518032e-04, 7.30469066e-04, 1.37798791e-03],
[-1.20439660e-03, -3.50640657e-05, 1.76810892e-04, ...,
1.40748988e-03, 9.31126706e-04, 1.26391207e-03]],
[[-1.44549223e-04, -1.15473573e-04, -6.30542418e-05, ...,
-4.81275529e-05, 3.25414061e-04, 2.03498232e-04],
[-2.94670928e-04, -8.99762235e-05, -2.69432086e-04, ...,
8.99600991e-07, 4.78153932e-04, 9.61799306e-05],
[-4.10166569e-04, -5.20658679e-04, -2.53264647e-04, ...,
3.23973043e-04, 5.47863834e-04, -1.30365908e-04],
...,
[ 1.13907992e-03, -9.15171870e-04, 8.65232723e-04, ...,
3.22713796e-03, 6.88177999e-04, 6.72521652e-04],
[ 1.39695173e-03, -7.76036293e-04, 1.17004605e-03, ...,
3.59163154e-03, 7.81425682e-04, 4.78076487e-04],
[ 1.59162667e-03, -6.71598536e-04, 1.45776011e-03, ...,
3.88893508e-03, 8.60124943e-04, 2.58787506e-04]],
[[-1.44549223e-04, -1.15473573e-04, -6.30542418e-05, ...,
-4.81275529e-05, 3.25414061e-04, 2.03498232e-04],
[-9.57229131e-05, 4.74091621e-06, -6.54251926e-05, ...,
-2.03965265e-05, 3.87494307e-04, 2.87009287e-04],
[ 7.70175102e-05, 2.29143661e-05, -1.12108835e-04, ...,
-6.15792305e-05, 7.54494278e-04, 2.59941327e-04],
...,
[-3.39211168e-04, -7.62147596e-04, 7.07902480e-04, ...,
1.03856390e-03, 2.22608724e-04, 5.74758451e-04],
[-1.15578106e-04, -5.49393706e-04, 9.38904588e-04, ...,
1.59561902e-03, 3.60214763e-04, 6.75352698e-04],
[ 1.59531934e-04, -3.65949731e-04, 1.20420614e-03, ...,
2.15193047e-03, 5.24571515e-04, 6.74697279e-04]],
...,
[[-1.44549223e-04, -1.15473573e-04, -6.30542418e-05, ...,
-4.81275529e-05, 3.25414061e-04, 2.03498232e-04],
[-1.11118985e-04, -1.70493659e-04, 7.63675689e-06, ...,
-1.08443775e-04, 5.13266714e-04, 2.32881168e-04],
[-2.32437969e-07, -3.77636781e-04, 2.63728609e-04, ...,
6.40669605e-05, 7.68769416e-04, -7.61850824e-05],
...,
[-3.61508537e-05, 2.39523928e-04, 1.20974041e-03, ...,
1.75854599e-03, 1.03051902e-03, -2.45660427e-04],
[ 2.07302044e-04, 2.38900306e-04, 1.42566522e-03, ...,
2.29557604e-03, 1.19230570e-03, -2.90194992e-04],
[ 4.27131599e-04, 2.16009008e-04, 1.64476351e-03, ...,
2.76157586e-03, 1.31183036e-03, -3.64537846e-04]],
[[-1.44549223e-04, -1.15473573e-04, -6.30542418e-05, ...,
-4.81275529e-05, 3.25414061e-04, 2.03498232e-04],
[-1.79736468e-04, -8.09120102e-05, -1.06378997e-04, ...,
-5.81318745e-05, 2.21763956e-04, 8.87998249e-05],
[-5.21522888e-04, -4.01913829e-04, 8.20151035e-05, ...,
-7.17598687e-06, 2.78848631e-04, -2.14788488e-05],
...,
[-1.21572160e-03, -6.60044549e-04, -2.85983660e-05, ...,
1.35207723e-03, 6.91875233e-04, -8.14947416e-04],
[-8.80337728e-04, -5.37527958e-04, 1.78775561e-04, ...,
1.90995191e-03, 9.32036142e-04, -6.58150821e-04],
[-5.19666821e-04, -4.26618091e-04, 4.55988979e-04, ...,
2.46126065e-03, 1.15186919e-03, -5.66828065e-04]],
[[-1.44549223e-04, -1.15473573e-04, -6.30542418e-05, ...,
-4.81275529e-05, 3.25414061e-04, 2.03498232e-04],
[-3.42270418e-04, -2.84511450e-04, 7.37812079e-05, ...,
1.81375945e-04, 3.50613933e-04, 2.93300021e-04],
[-5.02196606e-04, -4.85470606e-04, 4.54048219e-04, ...,
4.33529960e-04, 1.44814345e-04, 4.72123036e-04],
...,
[ 3.60102160e-04, 3.53560550e-04, 1.83492934e-03, ...,
3.30801494e-03, 8.18313390e-04, -1.35901757e-03],
[ 5.72396151e-04, 2.90043768e-04, 1.98598928e-03, ...,
3.61621543e-03, 8.88703857e-04, -1.44164031e-03],
[ 7.50529172e-04, 2.19747817e-04, 2.14289455e-03, ...,
3.86781618e-03, 9.42276383e-04, -1.49855157e-03]]],
dtype=float32)>
lyricist.summary()
Model: "text_generator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
embedding (Embedding) multiple 3072256
_________________________________________________________________
lstm (LSTM) multiple 5246976
_________________________________________________________________
lstm_1 (LSTM) multiple 8392704
_________________________________________________________________
dense (Dense) multiple 12301025
=================================================================
Total params: 29,012,961
Trainable params: 29,012,961
Non-trainable params: 0
_________________________________________________________________
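As a sanity check, the parameter counts in the summary can be reproduced from the layer sizes with the standard Keras formulas (a worked example, not part of the original run):
vocab, emb, hidden = 12001, 256, 1024
print(vocab * emb)                         # embedding: 3,072,256
print(4 * hidden * (emb + hidden + 1))     # lstm:      5,246,976
print(4 * hidden * (hidden + hidden + 1))  # lstm_1:    8,392,704
print((hidden + 1) * vocab)                # dense:    12,301,025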
optimizer = tf.keras.optimizers.Adam()
# Loss
loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,    # the Dense layer outputs raw logits, not probabilities
    reduction='none'
)

lyricist.compile(loss=loss,
                 optimizer=optimizer,
                 metrics=['accuracy'])  # track accuracy

lyrics_history = lyricist.fit(train_dataset,                  # training data
                              validation_data=test_dataset,   # validation data
                              epochs=10)
Epoch 1/10
487/487 [==============================] - 100s 201ms/step - loss: 3.5322 - accuracy: 0.4808 - val_loss: 3.1559 - val_accuracy: 0.5119
Epoch 2/10
487/487 [==============================] - 98s 201ms/step - loss: 3.0327 - accuracy: 0.5202 - val_loss: 2.9604 - val_accuracy: 0.5263
Epoch 3/10
487/487 [==============================] - 98s 202ms/step - loss: 2.8649 - accuracy: 0.5308 - val_loss: 2.8512 - val_accuracy: 0.5338
Epoch 4/10
487/487 [==============================] - 98s 202ms/step - loss: 2.7386 - accuracy: 0.5393 - val_loss: 2.7715 - val_accuracy: 0.5406
Epoch 5/10
487/487 [==============================] - 98s 202ms/step - loss: 2.6322 - accuracy: 0.5470 - val_loss: 2.7075 - val_accuracy: 0.5464
Epoch 6/10
487/487 [==============================] - 98s 202ms/step - loss: 2.5367 - accuracy: 0.5544 - val_loss: 2.6574 - val_accuracy: 0.5517
Epoch 7/10
487/487 [==============================] - 98s 202ms/step - loss: 2.4491 - accuracy: 0.5618 - val_loss: 2.6150 - val_accuracy: 0.5572
Epoch 8/10
487/487 [==============================] - 98s 202ms/step - loss: 2.3682 - accuracy: 0.5695 - val_loss: 2.5784 - val_accuracy: 0.5624
Epoch 9/10
487/487 [==============================] - 98s 202ms/step - loss: 2.2921 - accuracy: 0.5774 - val_loss: 2.5480 - val_accuracy: 0.5676
Epoch 10/10
487/487 [==============================] - 98s 202ms/step - loss: 2.2198 - accuracy: 0.5858 - val_loss: 2.5225 - val_accuracy: 0.5724
# putting what we learned in EX06 to use
import matplotlib.pyplot as plt
acc = lyrics_history.history['accuracy']
val_acc = lyrics_history.history['val_accuracy']
loss = lyrics_history.history['loss']
val_loss = lyrics_history.history['val_loss']
epochs_range = range(len(acc))
plt.figure(figsize = (12, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label = 'Training Accuracy')
plt.plot(epochs_range, val_acc, label = 'Validation Accuracy')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label = 'Training Loss')
plt.plot(epochs_range, val_loss, label = 'Validation Loss')
plt.legend(loc = 'upper right')
plt.title('Training and Validation Loss')
plt.show()
Accuracy could still climb here and the loss could still come down, so running more epochs would be fine,
or alternatively it may be better to increase embedding_size and hidden_size.
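If training were extended, one option (a sketch, not part of the original run) is to let Keras stop automatically once val_loss stops improving instead of fixing the epoch count by hand:
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2, restore_best_weights=True)
lyrics_history = lyricist.fit(train_dataset, validation_data=test_dataset, epochs=30, callbacks=[early_stop])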
def generate_text(lyricist, tokenizer, init_sentence="<start>", max_len=20):
    # convert the init_sentence passed in for testing into a tensor as well
    test_input = tokenizer.texts_to_sequences([init_sentence])
    test_tensor = tf.convert_to_tensor(test_input, dtype=tf.int64)
    end_token = tokenizer.word_index["<end>"]

    # build the sentence by predicting one word at a time
    # 1. feed the tensor of the input sentence to the model
    # 2. take the word index with the highest predicted probability
    # 3. append the predicted word index to the end of the sentence
    # 4. stop when the model predicts <end> or the sentence reaches max_len
    while True:
        # 1
        predict = lyricist(test_tensor)
        # 2
        predict_word = tf.argmax(tf.nn.softmax(predict, axis=-1), axis=-1)[:, -1]
        # 3
        test_tensor = tf.concat([test_tensor, tf.expand_dims(predict_word, axis=0)], axis=-1)
        # 4
        if predict_word.numpy()[0] == end_token: break
        if test_tensor.shape[1] >= max_len: break

    generated = ""
    # convert each word index back into a word with the tokenizer
    for word_index in test_tensor[0].numpy():
        generated += tokenizer.index_word[word_index] + " "
    return generated
# generate a sentence
generate_text(lyricist, tokenizer, init_sentence="<start> I ", max_len=20)
'<start> i m a survivor <end> '
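The same helper can be called with other prompts as well; the line below is a usage sketch (the exact output depends on the trained weights and is not shown here):
generate_text(lyricist, tokenizer, init_sentence="<start> you", max_len=20)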