# CTCLoss example: targets padded to a common length S.
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes

# Initialize random batch of input vectors, for *size = (T, N, C).
# CTCLoss expects log-probabilities; requires_grad_() (in-place, trailing
# underscore) re-enables gradient tracking after detach().
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
# Every sequence in the batch uses the full T input frames.
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# Per-sample target lengths drawn from [S_min, S).
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# CTCLoss example: un-padded targets, concatenated into one 1-D tensor.
T = 50  # Input sequence length
C = 20  # Number of classes (including blank)
N = 16  # Batch size

# Initialize random batch of input vectors, for *size = (T, N, C).
# log_softmax (with underscore in the name) produces the log-probabilities
# CTCLoss expects; requires_grad_() re-enables autograd after detach().
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes).
# Target lengths are drawn first; the targets themselves are then generated
# as a single concatenated tensor of sum(target_lengths) labels.
target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# CTCLoss example: un-batched input (no N dimension; requires torch >= 1.11).
T = 50  # Input sequence length
C = 20  # Number of classes (including blank)

# Initialize random batch of input vectors, for *size = (T, C).
# log_softmax over the class dimension; requires_grad_() re-enables
# autograd after detach().
input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
# Scalar length tensor: the single sequence spans all T frames.
input_lengths = torch.tensor(T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes).
# Scalar target length in [1, T), then a 1-D target of that many labels.
target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
An example run:
# Target are to be padded
T = 50      # Input sequence length
B = 16      # Batch size
C = 20      # Number of classes (including blank)
S_max = 30  # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes

# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, B, C).log_softmax(2) #.detach().requires_grad_()
print('input : ', input.size())
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(B, S_max), dtype=torch.long)
print('target : ', target.size())
# All B sequences use the full T input frames.
input_length = torch.LongTensor([T for i in range(B)])
print('input length: ', input_length.size(), input_length)
# Per-sample target lengths in [S_min, S_max).
# (A stray non-ASCII character at the end of this line was removed.)
target_length = torch.randint(low=S_min, high=S_max, size=(B,), dtype=torch.long)
print('target length: ', target_length.size(), target_length)
# zero_infinity=True replaces infinite losses (infeasible alignments) with 0.
ctc_loss = torch.nn.CTCLoss(reduction = 'mean', zero_infinity=True)
ctc_loss(input, target, input_length, target_length)
input : torch.Size([50, 16, 20])
target : torch.Size([16, 30])
input length: torch.Size([16]) tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50])
target length: torch.Size([16]) tensor([12, 28, 25, 17, 19, 23, 14, 11, 11, 13, 21, 26, 23, 17, 17, 26])
tensor(6.7078)