[RuntimeError]: Expected floating point type for target with class probabilities, got Long

Hyun·2022년 8월 1일
0

ERROR

목록 보기
3/6

float 타입을 예상했는데, Long 타입의 변수가 입력되었다.
해결 : dtype을 float로 바꾸자! 정밀도가 많이 필요하지 않기 때문에 명시적으로 float16으로 선언하였다.(float로 해도 된다.)

여기서,
dtype이 float64일 때, tensor은 torch.(cuda.)DoubleTensor로 인식한다.
마찬가지로 dtype이 float16일 때, tensor은 torch.(cuda.)HalfTensor로 인식한다.

이렇게 같은 float라도 다른 object를 가질 수 있기 때문에 유의해야한다.

start_loss = self.loss_fct(torch.argmax(start_logits, dim=-1), start_positions)

=>

start_loss = self.loss_fct(torch.tensor(torch.argmax(start_logits, dim=-1), dtype=torch.float16), start_positions)

start_loss = self.loss_fct(torch.argmax(start_logits, dim=-1).float16(), start_positions.float16())

참고

[Pytorch] Tensor에서 혼동되는 여러 메서드와 함수

0개의 댓글