torch.nn.CTCLoss

안준모 · March 11, 2024

Targets are padded

import torch

T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
N = 16      # Batch size
S = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10  # Minimum target length, for demonstration purposes
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
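The padded layout and the concatenated layout are two encodings of the same targets: row `target[i]` is only read up to `target_lengths[i]`, and whatever follows is ignored. The sketch below checks this by building a 1-D concatenated target from a padded one with `torch.cat` and comparing the two losses. The seed and the names `padded`, `concatenated`, `loss_padded` and `loss_concat` are illustrative choices, not part of the original example.

```python
import torch

# Illustrative sizes, matching the padded example
torch.manual_seed(0)
T, C, N, S, S_min = 50, 20, 16, 30, 10

log_probs = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)

# Padded layout (N, S): only the first target_lengths[i] entries of row i are used
padded = torch.randint(1, C, (N, S), dtype=torch.long)
target_lengths = torch.randint(S_min, S, (N,), dtype=torch.long)

# Concatenated layout: strip the padding from each row and join the rows end to end
concatenated = torch.cat([padded[i, : target_lengths[i]] for i in range(N)])

ctc_loss = torch.nn.CTCLoss()
loss_padded = ctc_loss(log_probs, padded, input_lengths, target_lengths)
loss_concat = ctc_loss(log_probs, concatenated, input_lengths, target_lengths)
print(loss_padded.item(), loss_concat.item())  # should match up to float rounding
```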

Targets are unpadded (concatenated into one 1-D tensor)

import torch

T = 50  # Input sequence length
C = 20  # Number of classes (including blank)
N = 16  # Batch size
# Initialize random batch of input vectors, for *size = (T,N,C)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# Initialize random batch of targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
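Both examples use the default `torch.nn.CTCLoss()`, whose reduction is `'mean'`. According to the PyTorch documentation, `'mean'` divides each sample's loss by its target length and then averages over the batch, while `'sum'` adds the per-sample losses. The sketch below compares both against the unreduced losses from `reduction='none'`. Capping target lengths at 29 instead of `T` is an assumption made here so that every sample has a valid alignment and all losses stay finite; the seed is arbitrary.

```python
import torch

torch.manual_seed(0)
T, C, N = 50, 20, 16

log_probs = torch.randn(T, N, C).log_softmax(2)
input_lengths = torch.full((N,), T, dtype=torch.long)
# Assumption: target lengths capped at 29 so every sample stays alignable
target_lengths = torch.randint(1, 30, (N,), dtype=torch.long)
target = torch.randint(1, C, (int(target_lengths.sum()),), dtype=torch.long)

args = (log_probs, target, input_lengths, target_lengths)
per_sample = torch.nn.CTCLoss(reduction='none')(*args)  # shape (N,)
mean_loss = torch.nn.CTCLoss(reduction='mean')(*args)
sum_loss = torch.nn.CTCLoss(reduction='sum')(*args)

# 'mean': divide each loss by its target length, then average over the batch
manual_mean = (per_sample / target_lengths).mean()
# 'sum': plain sum of the per-sample losses
manual_sum = per_sample.sum()
print(mean_loss.item(), manual_mean.item())
print(sum_loss.item(), manual_sum.item())
```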

Targets are unpadded and unbatched (effectively N=1)

import torch

T = 50  # Input sequence length
C = 20  # Number of classes (including blank)
# Initialize random (unbatched) input vectors, for *size = (T,C)
input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
input_lengths = torch.tensor(T, dtype=torch.long)
# Initialize random targets (0 = blank, 1:C = classes)
target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
ctc_loss = torch.nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
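The unbatched form is described as effectively N=1. One way to see this, sketched below, is to add a batch dimension of size 1 to every argument with `unsqueeze` and compare the two losses. The fixed target length `S = 20` and the seed are arbitrary choices for the illustration, not values from the original.

```python
import torch

torch.manual_seed(0)
T, C, S = 50, 20, 20  # S: fixed target length chosen for the illustration

log_probs = torch.randn(T, C).log_softmax(1)          # unbatched input: (T, C)
target = torch.randint(1, C, (S,), dtype=torch.long)  # unbatched target: (S,)
input_length = torch.tensor(T, dtype=torch.long)      # 0-dim length
target_length = torch.tensor(S, dtype=torch.long)     # 0-dim length

ctc_loss = torch.nn.CTCLoss()
loss_unbatched = ctc_loss(log_probs, target, input_length, target_length)

# The same data with an explicit batch dimension of size 1
loss_batched = ctc_loss(
    log_probs.unsqueeze(1),      # (T, 1, C)
    target.unsqueeze(0),         # (1, S)
    input_length.unsqueeze(0),   # (1,)
    target_length.unsqueeze(0),  # (1,)
)
print(loss_unbatched.item(), loss_batched.item())
```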

One example run

import torch

# Targets are padded
T = 50          # Input sequence length
B = 16          # Batch size
C = 20          # Number of classes (including blank)

S_max = 30      # Target sequence length of longest target in batch (padding length)
S_min = 10      # Minimum target length, for demonstration purposes

# Initialize random batch of input vectors, for *size = (T,B,C)
input = torch.randn(T, B, C).log_softmax(2)  #.detach().requires_grad_()
print('input : ', input.size())

# Initialize random batch of targets (0 = blank, 1:C = classes)
target = torch.randint(low=1, high=C, size=(B, S_max), dtype=torch.long)
print('target : ', target.size())
input_length = torch.LongTensor([T for i in range(B)])
print('input length: ', input_length.size(), input_length)
target_length = torch.randint(low=S_min, high=S_max, size=(B,), dtype=torch.long)
print('target length: ', target_length.size(), target_length)

ctc_loss = torch.nn.CTCLoss(reduction='mean', zero_infinity=True)
print(ctc_loss(input, target, input_length, target_length))

Output (the inputs are random, so the numbers differ from run to run):
input :  torch.Size([50, 16, 20])
target :  torch.Size([16, 30])
input length:  torch.Size([16]) tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50])
target length:  torch.Size([16]) tensor([12, 28, 25, 17, 19, 23, 14, 11, 11, 13, 21, 26, 23, 17, 17, 26])
tensor(6.7078)
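The run above sets `zero_infinity=True`. An infinite CTC loss appears when an input is too short to be aligned with its target, for example when the target has more labels than the input has time steps. The sketch below builds such a case on purpose (5 time steps, 10 target labels, both chosen only for illustration) and compares the default setting with `zero_infinity=True`, using `reduction='none'` to expose the raw per-sample value. Per the PyTorch documentation, the associated gradient is zeroed as well.

```python
import torch

torch.manual_seed(0)
T, C, S = 5, 20, 10  # the target (10 labels) is longer than the input (5 steps)

log_probs = torch.randn(T, 1, C).log_softmax(2).detach().requires_grad_()
target = torch.randint(1, C, (1, S), dtype=torch.long)
input_lengths = torch.tensor([T], dtype=torch.long)
target_lengths = torch.tensor([S], dtype=torch.long)

args = (log_probs, target, input_lengths, target_lengths)
default_loss = torch.nn.CTCLoss(reduction='none')(*args)                    # inf
safe_loss = torch.nn.CTCLoss(reduction='none', zero_infinity=True)(*args)   # 0
print(default_loss, safe_loss)

# With zero_infinity=True, the gradient of the impossible sample is zeroed too
safe_loss.sum().backward()
print(log_probs.grad.abs().sum())
```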