문자단위 RNN(Char RNN)
- 주어진 문장을 RNN으로 학습시킨 뒤 얼마나 비슷하게 텍스트를 생성하는지 살펴본다.
1. 훈련 데이터 전터리 하기
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
sentence = ("if you want to build a ship, don't drum up people together to "
"collect wood and don't assign them tasks and work, but rather "
"teach them to long for the endless immensity of the sea.")
char_set = list(set(sentence))
char_dic = {c: i for i, c in enumerate(char_set)}
print(char_dic)
{'o': 0, 'w': 1, 'y': 2, 'd': 3, 'c': 4, 'i': 5, 'h': 6, ' ': 7, 'g': 8, 'b': 9, 's': 10, 'a': 11, "'": 12, 'n': 13, 't': 14, 'e': 15, 'm': 16, 'u': 17, '.': 18, 'k': 19, 'p': 20, 'f': 21, ',': 22, 'r': 23, 'l': 24}
dic_size = len(char_dic)
print('문자 집합의 크기 : {}'.format(dic_size))
문자 집합의 크기 : 25
hidden_size = dic_size
sequence_length = 10
learning_rate = 0.1
x_data = []
y_data = []
for i in range(0, len(sentence) - sequence_length):
x_str = sentence[i:i + sequence_length]
y_str = sentence[i + 1: i + sequence_length + 1]
x_data.append([char_dic[c] for c in x_str])
y_data.append([char_dic[c] for c in y_str])
if i < 4 or i > 165:
print(i, x_str, '->', y_str)
0 if you wan -> f you want
1 f you want -> you want
2 you want -> you want t
3 you want t -> ou want to
166 ty of the -> y of the s
167 y of the s -> of the se
168 of the se -> of the sea
169 of the sea -> f the sea.
print(x_data[0])
print(y_data[0])
[5, 21, 7, 2, 0, 17, 7, 1, 11, 13]
[21, 7, 2, 0, 17, 7, 1, 11, 13, 14]
x_one_hot = [np.eye(dic_size)[x] for x in x_data]
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)
print('훈련 데이터의 크기 : {}'.format(X.shape))
print('레이블의 크기 : {}'.format(Y.shape))
훈련 데이터의 크기 : torch.Size([170, 10, 25])
레이블의 크기 : torch.Size([170, 10])
print(X[0])
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.]])
2. 모델 구현하기
class Net(torch.nn.Module):
def __init__(self, input_dim, hidden_dim, layers):
super(Net, self).__init__()
self.rnn = torch.nn.RNN(input_dim, hidden_dim, num_layers=layers, batch_first=True)
self.fc = torch.nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x):
x, _status = self.rnn(x)
x = self.fc(x)
return x
net = Net(dic_size, hidden_size, 2)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), learning_rate)
outputs = net(X)
print(outputs.shape)
torch.Size([170, 10, 25])
print(outputs.view(-1, dic_size).shape)
torch.Size([1700, 25])
print(Y.shape)
print(Y.view(-1).shape)
torch.Size([170, 10])
torch.Size([1700])
for i in range(100):
optimizer.zero_grad()
outputs = net(X)
loss = criterion(outputs.view(-1, dic_size), Y.view(-1))
loss.backward()
optimizer.step()
results = outputs.argmax(dim=2)
predict_str = ""
for j, result in enumerate(results):
if j == 0:
predict_str += ''.join([char_set[t] for t in result])
else:
predict_str += char_set[result[-1]]
if i == 0 or i == 99:
print(predict_str)
boybocccolhowoobcwoooocbooooooooccolobccbhcloolooohoooooooowooboooooooccoooooooyooccoloooolccoocywoocbooooocoooloccwoyoooooowoohoocoocywooooobocyoowoooboyoooooccchbcooobycbcoooboo
l you wont to build a ship, don't drum up people together te collect wood and don't assign them tosks and work, but rather teach them to bong for the endless immensity of the seac
- 처음엔 이상한 예측을 하지만 마지막엔 꽤 정확한 문자를 생성하는 것을 볼 수 있다.