RuntimeError: expected scalar type Long but found Float

boingboing·2024년 11월 6일

현상

  • 정수형 타입인 Long으로 구현되어야 하는 함수에 float같은 다른 자료형의 텐서가 주어질 때 발생.

out_loss = nn.CrossEntropyLoss(preds, masks.float())  # Loss 

코드에서 발생.

  • 해당 함수에서는 라벨 텐서 y의 자료형이 Long인 상태에 대하여 지원하고 있는데, float 자료형으로 주어지고 있기 때문에 해당 오류가 발생함.

  • 이 경우, Long 자료형으로 라벨 텐서 y의 타입을 캐스팅하여 아래와 같이 다시 함수를 적용하면 정상적으로 작동됨.

해결

  out_loss = criterion(preds, masks.long())
참고

https://jimmy-ai.tistory.com/313

0개의 댓글