MLM 레이블을 one-hot 레이블로 바꾸기

황준하·2023년 10월 1일
0

각 MLM 레이블을 one-hot 레이블로 바꾸기

Custom Loss Function 구현 시 유용하게 종종 사용할 듯

def label_to_one_hot_label(
    labels: torch.Tensor,
    num_classes: int,
    device = None,
    dtype = None,
    eps: float = 1e-6,
    ignore_index=-100,
) -> torch.Tensor:
    
    # ignore_index에 해당하는 위치 확인하고 이를 처리하기 위해 vocab_size+1의 레이블을 할당한다.
    # vocab_size+1의 레이블에 해당하는 부분은 잘라내어 해당 토큰의 대한 one-hot 레이블이 모두 0이 되도록 할 것임
    
    check = (labels == ignore_index)
    labels[check] = num_classes
    
    # labels : (Batch, input_length)
    shape = labels.shape
    # one hot : (Batch, input_length, Vocab Size+1(ignore_index))
    one_hot = torch.zeros((shape[0], shape[1], num_classes+1), device=device, dtype=dtype)
    
    # labels.unsqueeze(2) : (Batch, input_length, 1)
    one_hot = one_hot.scatter_(2, labels.unsqueeze(2), 1.0)
    
    # ignore_index 부분 떼어내기
    ret = torch.split(one_hot, [num_classes, 1], dim=2)[0]
    
    return ret

0개의 댓글

관련 채용 정보