loss functions

안준모 · March 4, 2024

CTCLoss

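Targets are un-padded here: the target sequences of the whole batch are concatenated into a single 1-D tensor, and target_lengths records how long each one is.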
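Besides this un-padded form, CTCLoss also accepts targets padded into an (N, S) tensor, where S is the longest target length and entries past each target_lengths value are ignored. The sketch below uses made-up sizes and labels, and padded and unpadded are illustrative names. It shows that concatenating the valid part of each padded row gives the un-padded layout and the same loss:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # illustrative seed
T, C, N = 6, 4, 2     # time steps, classes (incl. blank), batch size (illustrative)

log_probs = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)

# Padded layout, shape (N, S): values past each target length are ignored
padded = torch.tensor([[1, 2, 3],
                       [2, 2, 0]])  # second target is [2, 2]; the trailing 0 is padding
target_lengths = torch.tensor([3, 2])

# Un-padded layout: keep only the valid prefix of each row and concatenate
unpadded = torch.cat([row[:length] for row, length in zip(padded, target_lengths.tolist())])
print(unpadded)  # tensor([1, 2, 3, 2, 2])

loss_padded = F.ctc_loss(log_probs, padded, input_lengths, target_lengths)
loss_unpadded = F.ctc_loss(log_probs, unpadded, input_lengths, target_lengths)
print(torch.allclose(loss_padded, loss_unpadded))  # True
```

The padded form is often more convenient when targets come out of a DataLoader already batched; the un-padded form avoids storing padding.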

import torch

T = 5  # Input sequence length
C = 4  # Number of classes (including blank)
N = 3  # Batch size

Initialize a random batch of input vectors of size (T, N, C)

input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
print('input : ', input)
print('input size : ', input.size())
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
print('input lengths : ', input_lengths)

input :  tensor([[[-1.6540, -1.2314, -1.0986, -1.6956],
         [-1.3459, -1.3507, -1.4736, -1.3802],
         [-1.5542, -1.1642, -2.2226, -0.9993]],

        [[-1.4216, -1.8971, -2.3815, -0.6611],
         [-1.0920, -1.0589, -2.5187, -1.4395],
         [-1.7087, -0.6169, -1.8833, -2.0619]],

        [[-1.5675, -0.8891, -1.0697, -3.2885],
         [-2.0517, -1.6354, -2.1639, -0.5767],
         [-1.5850, -1.1202, -1.7741, -1.2066]],

        [[-1.4258, -0.6757, -1.5976, -3.0264],
         [-1.5270, -0.7361, -1.5506, -2.3891],
         [-0.7249, -1.4743, -1.4756, -2.8466]],

        [[-1.2064, -1.5388, -0.8826, -2.6257],
         [-0.7286, -2.3960, -1.1572, -2.1894],
         [-2.5590, -1.2184, -0.5809, -2.6953]]], requires_grad=True)

input size : torch.Size([5, 3, 4])
input lengths : tensor([5, 5, 5])
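The input has to be log-probabilities over the class dimension, which is why log_softmax is applied along dim 2 (C). The .detach().requires_grad_() chain then turns the result into a leaf tensor, so loss.backward() stores gradients in input.grad. A small sketch checking both properties, assuming the same T, N, C as above:

```python
import torch

T, N, C = 5, 3, 4
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()

# Each (t, n) row is a log-distribution over the C classes,
# so logsumexp over dim 2 is (numerically) zero everywhere
row_lse = input.logsumexp(dim=2)
print(row_lse.shape)  # torch.Size([5, 3])
print(torch.allclose(row_lse, torch.zeros(T, N), atol=1e-5))  # True

# detach() drops the log_softmax history; requires_grad_() makes it a trainable leaf
print(input.is_leaf, input.requires_grad)  # True True
```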

Initialize a random batch of targets (0 = blank, 1:C = classes, i.e. labels 1 to C-1)

target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
print('target length', target_lengths)
target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
print('target', target)
ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
print('loss', loss)
loss.backward()

target length tensor([1, 1, 3])
target tensor([1, 1, 1, 2, 2])
loss tensor(2.9141, grad_fn=<MeanBackward0>)
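With the default reduction='mean', the printed loss is each sample's negative log-likelihood divided by its target length, then averaged over the batch (as described in the CTCLoss documentation). To make that number less of a black box, here is a minimal, unoptimized sketch of the standard CTC forward (alpha) recursion in log space. ctc_nll is a hypothetical helper written for illustration, not a PyTorch API. It assumes blank index 0 and full-length inputs (input_lengths == T), and it reuses the fixed targets printed above instead of random ones so every target is feasible. The seed is illustrative, so the loss values will differ from the output above.

```python
import torch

def ctc_nll(log_probs: torch.Tensor, target: list[int], blank: int = 0) -> torch.Tensor:
    """Hypothetical helper: CTC negative log-likelihood of one target given (T, C) log-probs."""
    T = log_probs.shape[0]
    # Extended label sequence with blanks: [b, y1, b, y2, ..., yS, b]
    ext = [blank]
    for label in target:
        ext += [label, blank]
    L = len(ext)

    # alpha[s]: log-probability of all alignments ending in ext[s] at the current frame
    alpha = torch.full((L,), float('-inf'))
    alpha[0] = log_probs[0, blank]
    if L > 1:
        alpha[1] = log_probs[0, ext[1]]

    for t in range(1, T):
        new = torch.full((L,), float('-inf'))
        for s in range(L):
            terms = [alpha[s]]                  # stay on the same symbol
            if s >= 1:
                terms.append(alpha[s - 1])      # advance by one position
            # skipping a blank is only allowed between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])
            new[s] = torch.logsumexp(torch.stack(terms), dim=0) + log_probs[t, ext[s]]
        alpha = new

    # valid alignments end on the last label or on the trailing blank
    end = alpha[-1] if L == 1 else torch.logsumexp(alpha[-2:], dim=0)
    return -end


torch.manual_seed(0)  # illustrative seed
T, C, N = 5, 4, 3
log_probs = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.tensor([1, 1, 3])
targets = torch.tensor([1, 1, 1, 2, 2])  # un-padded: [1], [1], [1, 2, 2]

# Per-sample losses from the hand-written recursion
ours, offset = [], 0
for n, length in enumerate(target_lengths.tolist()):
    ours.append(ctc_nll(log_probs[:, n, :], targets[offset:offset + length].tolist()))
    offset += length
ours = torch.stack(ours)

per_sample = torch.nn.CTCLoss(reduction='none')(log_probs, targets, input_lengths, target_lengths)
mean_loss = torch.nn.CTCLoss()(log_probs, targets, input_lengths, target_lengths)

print(torch.allclose(ours, per_sample, atol=1e-5))  # True
# Default reduction='mean': divide by target length, then average over the batch
print(torch.allclose(mean_loss, (per_sample / target_lengths).mean()))  # True
```

The nested Python loops cost O(T·(2S+1)) operations per sample and are far slower than the built-in kernel. The sketch is only meant as a readable cross-check of what CTCLoss computes.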
