Custom Loss Function 구현 시 유용하게 종종 사용할 듯
def label_to_one_hot_label(
labels: torch.Tensor,
num_classes: int,
device = None,
dtype = None,
eps: float = 1e-6,
ignore_index=-100,
) -> torch.Tensor:
# ignore_index에 해당하는 위치 확인하고 이를 처리하기 위해 vocab_size+1의 레이블을 할당한다.
# vocab_size+1의 레이블에 해당하는 부분은 잘라내어 해당 토큰의 대한 one-hot 레이블이 모두 0이 되도록 할 것임
check = (labels == ignore_index)
labels[check] = num_classes
# labels : (Batch, input_length)
shape = labels.shape
# one hot : (Batch, input_length, Vocab Size+1(ignore_index))
one_hot = torch.zeros((shape[0], shape[1], num_classes+1), device=device, dtype=dtype)
# labels.unsqueeze(2) : (Batch, input_length, 1)
one_hot = one_hot.scatter_(2, labels.unsqueeze(2), 1.0)
# ignore_index 부분 떼어내기
ret = torch.split(one_hot, [num_classes, 1], dim=2)[0]
return ret