Honestly, I'm so busy just keeping up with the lectures that I haven't even properly gotten around to reviewing. Still, I'll take it as meaningful to see the course through to the end and complete it.
import torch, torch.nn as nn
tanh = torch.tanh
W_hh = nn.Linear(128, 128, bias=False)  # h_{t-1} -> h
W_xh = nn.Linear(64, 128, bias=False)   # x_t -> h
W_hy = nn.Linear(128, 10, bias=False)   # h_t -> y
h_prev = torch.zeros(32, 128)           # (batch, hidden)
x_t = torch.randn(32, 64)               # (batch, input)
h_t = tanh(W_hh(h_prev) + W_xh(x_t))    # vanilla RNN state update
y_t = W_hy(h_t)                         # per-step output
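To make the recurrence concrete, here is a rough sketch (my own addition, not from the lecture) of unrolling this single step over a whole sequence, reusing the same weights at every timestep:

# unrolling the cell above over T timesteps (sketch; reuses W_hh, W_xh, W_hy)
T = 50
x_seq = torch.randn(32, T, 64)             # (batch, seq_len, input)
h = torch.zeros(32, 128)                   # initial hidden state
outputs = []
for t in range(T):
    h = tanh(W_hh(h) + W_xh(x_seq[:, t]))  # same weights every step
    outputs.append(W_hy(h))
y_seq = torch.stack(outputs, dim=1)        # (32, 50, 10)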
import torch, torch.nn as nn
sigmoid, tanh = torch.sigmoid, torch.tanh
class LSTMCellManual(nn.Module):
    def __init__(self, in_dim, hid_dim):
        super().__init__()
        # every gate sees the concatenated [h_{t-1}, x_t]
        self.W_i = nn.Linear(in_dim + hid_dim, hid_dim)  # input gate
        self.W_f = nn.Linear(in_dim + hid_dim, hid_dim)  # forget gate
        self.W_o = nn.Linear(in_dim + hid_dim, hid_dim)  # output gate
        self.W_g = nn.Linear(in_dim + hid_dim, hid_dim)  # candidate cell state

    def forward(self, x_t, h_prev, c_prev):
        z = torch.cat([h_prev, x_t], dim=-1)
        i = sigmoid(self.W_i(z))
        f = sigmoid(self.W_f(z))
        o = sigmoid(self.W_o(z))
        g = tanh(self.W_g(z))
        c_t = f * c_prev + i * g  # forget old memory, add gated candidate
        h_t = o * tanh(c_t)       # gated view of the cell state
        return h_t, c_t
# usage example
cell = LSTMCellManual(64, 128)
h_prev = torch.zeros(32, 128)
c_prev = torch.zeros(32, 128)
x_t = torch.randn(32, 64)
h_t, c_t = cell(x_t, h_prev, c_prev)
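As a quick sanity check (my own assumption that the built-in cell is a fair reference), nn.LSTMCell should produce the same output shapes as the manual version:

ref_cell = nn.LSTMCell(64, 128)  # PyTorch's built-in LSTM cell
h_ref, c_ref = ref_cell(x_t, (h_prev, c_prev))
assert h_ref.shape == h_t.shape == (32, 128)
assert c_ref.shape == c_t.shape == (32, 128)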
# Bidirectional LSTM (Bi-LSTM), summarized in one line
import torch, torch.nn as nn
bilstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
                 bidirectional=True, batch_first=True)
x = torch.randn(32, 50, 64) # (batch, seq_len, input_size)
out, _ = bilstm(x) # out: (32, 50, 256) (2*hidden)
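The last dimension of `out` concatenates the forward and backward directions, so splitting them back apart is just slicing (a sketch using the hidden_size=128 configured above):

out_fwd = out[:, :, :128]   # forward direction,  (32, 50, 128)
out_bwd = out[:, :, 128:]   # backward direction, (32, 50, 128)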
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
lengths = torch.tensor([50, 45, 41, 30])  # assumes descending order
x = torch.randn(4, 50, 64)
x_packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=True)
out_packed, _ = bilstm(x_packed)
out, _ = pad_packed_sequence(out_packed, batch_first=True) # (4, 50, 256)
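If the batch cannot be pre-sorted, enforce_sorted=False lets pack_padded_sequence handle the sorting internally; a sketch with hypothetical unsorted lengths:

lengths_unsorted = torch.tensor([30, 50, 41, 45])  # not sorted
x2 = torch.randn(4, 50, 64)
x2_packed = pack_padded_sequence(x2, lengths_unsorted, batch_first=True,
                                 enforce_sorted=False)
out2_packed, _ = bilstm(x2_packed)
out2, _ = pad_packed_sequence(out2_packed, batch_first=True)  # (4, 50, 256)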
import torch, torch.nn as nn
T, N, C = 4, 2, 5  # time, batch, classes (characters + blank)
# CTCLoss expects log-probabilities; requires_grad so loss.backward() works
log_probs = torch.randn(T, N, C, requires_grad=True).log_softmax(2)  # (T, N, C)
# targets: each batch item's label sequence, concatenated into one 1-D tensor
targets = torch.tensor([1, 2, 3, 2, 2])  # e.g., first item '1,2,3', second item '2,2'
target_lengths = torch.tensor([3, 2])
input_lengths = torch.tensor([T, T])
ctc = nn.CTCLoss(blank=0, zero_infinity=True)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
loss.backward()
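One thing to keep in mind: nn.CTCLoss wants (T, N, C) log-probabilities, so a model that emits batch_first output needs a permute first (a sketch, with model_out standing in for a hypothetical model's output):

model_out = torch.randn(N, T, C, requires_grad=True)        # (batch, time, classes)
log_probs_tnc = model_out.log_softmax(-1).permute(1, 0, 2)  # -> (T, N, C)
loss2 = ctc(log_probs_tnc, targets, input_lengths, target_lengths)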
import torch

def ctc_greedy_decode(log_probs, blank_id=0):
    # log_probs: assumed (T, C); for (T, N, C), decode one batch item at a time
    path = log_probs.argmax(dim=-1).tolist()
    decoded = []
    prev = None
    for p in path:
        # collapse consecutive repeats, then drop blanks
        if p != blank_id and p != prev:
            decoded.append(p)
        prev = p  # update even on blanks, so repeats separated by a blank survive
    return decoded
# example
T, C = 6, 5
log_probs = torch.randn(T, C).log_softmax(-1)
tokens = ctc_greedy_decode(log_probs, blank_id=0)
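A hand-crafted path makes the collapse rule easy to verify: repeats merge unless a blank separates them, so the argmax path [1, 1, 0, 1, 2, 2] should decode to [1, 1, 2] (my own check, not from the lecture):

import torch.nn.functional as F
path = torch.tensor([1, 1, 0, 1, 2, 2])
fake_log_probs = F.one_hot(path, num_classes=5).float().log_softmax(-1)
assert ctc_greedy_decode(fake_log_probs, blank_id=0) == [1, 1, 2]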