📒 RNN
📝 RNN?
- 순서도 데이터의 일부가 되는 sequential data를 위해 만들어졌다.
- 이전 입력 값의 처리를 반영하여 모델이 데이터의 순서를 이해한다.
- 긴 sequence가 들어와도 이를 처리할 A의 parameter만 알면 된다.
rnn = torch.nn.RNN(input_size, hidden_size)
outputs, _status = rnn(input_data)
📝 'hihello' Problem
- 하나의 character가 들어오면 다음 character를 예측하는 모델
- 각 charater에 대한 one-hot encoding이 필요하다.
sample = "hihello'
char_set = list(set(sample))
char_dic = {c : i for i, c in enumerate(char_set)}
dic_size = len(char_dic)
hidden_size = len(char_dic)
learning_rate = 0.1
sample_idx = [char_dic[c] for c in sample]
x_data = [sample_idx[:-1]]
x_one_hot = [np.eye(dic_size)[x] for x in x_data]
y_data = [sample_idx[1:]]
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)
rnn = torch.nn.RNN(input_size, hidden_size, batch_first = True)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn.parameters(), learning_rate)
for i in range(100):
optimizer.zero_grad()
outputs, _status = rnn(X)
loss = criterion(outputs.view(-1, input_size), Y.view(-1))
loss.backward()
optimizer.step()
rresult = outputs.data.numpy().agmax(axis=2)
result_str = ''.join([char_set[c] for c in np.squeeze(result)])
print(f"{i} loss: {loss.item()} prediction: {result} true Y: {y_data}, prediction str: {result_str}"