# Study material: https://www.youtube.com/watch?v=7zDARfKVm7s (thank you!)
# Hidden Markov model (HMM) posterior inference with the forward-backward algorithm.
def forward_backward(log_likelihoods, trans, init_probs=None, eps=1e-10):
    """Compute HMM state posteriors with the forward-backward algorithm in log space.

    Args:
        log_likelihoods: ``(T, K)`` tensor; ``log_likelihoods[t, k]`` is the
            log-probability of observation ``t`` given hidden state ``k``.
        trans: ``(K, K)`` row-stochastic transition matrix,
            ``trans[i, j] = P(z_{t+1} = j | z_t = i)``.
        init_probs: optional ``(K,)`` initial-state distribution. ``None``
            (the default) means uniform, matching the original behaviour.
        eps: constant added before every ``log`` so zero probabilities map to
            a large negative number instead of ``-inf``.

    Returns:
        ``(gamma, xi)`` where ``gamma`` is ``(T, K)`` with
        ``gamma[t, k] = P(z_t = k | obs)`` and ``xi`` is ``(T - 1, K, K)`` with
        ``xi[t, i, j] = P(z_t = i, z_{t+1} = j | obs)`` (empty when ``T < 2``).
        Both use the dtype obtained by promoting ``log_likelihoods`` with
        ``trans`` and live on ``log_likelihoods``' device.
    """
    T, K = log_likelihoods.shape
    dtype = torch.promote_types(log_likelihoods.dtype, trans.dtype)
    device = log_likelihoods.device
    if T == 0:
        # Nothing to infer; previously this crashed on torch.zeros(-1, K, K).
        return (
            torch.zeros(0, K, dtype=dtype, device=device),
            torch.zeros(0, K, K, dtype=dtype, device=device),
        )

    log_likelihoods = log_likelihoods.to(dtype)
    # Hoisted: the log-transition matrix is identical at every time step.
    log_trans = torch.log(trans.to(dtype=dtype, device=device) + eps)
    if init_probs is None:
        init_probs = torch.full((K,), 1.0 / K, dtype=dtype, device=device)
    log_init = torch.log(init_probs.to(dtype=dtype, device=device) + eps)

    # Forward pass, run once: alpha[t, j] = log P(obs_0..obs_t, z_t = j).
    alphas = [log_likelihoods[0] + log_init]
    for t in range(1, T):
        # Marginalise over the *previous* state i, which indexes the rows of
        # trans: broadcast alpha down the columns and reduce over dim 0.
        prev = alphas[-1].unsqueeze(1) + log_trans  # indexed [i, j]
        alphas.append(log_likelihoods[t] + torch.logsumexp(prev, dim=0))
    alpha = torch.stack(alphas)

    # Backward pass, run once: beta[t, i] = log P(obs_{t+1}..obs_{T-1} | z_t = i),
    # with beta[T - 1] = log 1 = 0.
    betas = [torch.zeros_like(alphas[-1])]
    for t in range(T - 2, -1, -1):
        nxt = log_trans + log_likelihoods[t + 1] + betas[-1]  # indexed [i, j]
        betas.append(torch.logsumexp(nxt, dim=1))
    beta = torch.stack(betas[::-1])

    # gamma[t] is proportional to exp(alpha[t] + beta[t]); softmax normalises each step.
    gamma = torch.softmax(alpha + beta, dim=1)

    # log xi[t, i, j] = alpha[t, i] + log A[i, j] + log b_j(obs_{t+1}) + beta[t + 1, j]
    # -- note beta at t + 1, not t.
    log_xi = (
        alpha[:-1].unsqueeze(2)
        + log_trans.unsqueeze(0)
        + (log_likelihoods[1:] + beta[1:]).unsqueeze(1)
    )
    # xi[t] is a joint distribution over (i, j), so normalise across all
    # K*K entries rather than row by row.
    xi = torch.softmax(log_xi.flatten(1), dim=1).reshape(T - 1, K, K)
    return gamma, xi