Bool value of Tensor with more than one value is ambiguous in Pytorch

xdfc1745·2021년 11월 29일
1

GPU를 이용한 pytorch를 사용하다보면 이런 에러를 만날 수도 있다.

이 에러는 tensor값을 boolen 값으로 비교하려고 할 떄 발생하는 에러이다.

예를 들면 다음과 같은 상황이다.

a = [[1, 2],[3, 4]]
a = torch.tensor(a)
if a == 1:
	print('a is 1')
else:
	print('a is not 1')

위의 코드와 같이 tensor값을 비교하면 다음과 같은 에러가 발생한다.

이를 해결하기 위한 방법으로는 여러가지가 있겠지만, 나는 이 방법을 사용했다.

  1. tensor값을 cpu값으로 바꿔준다.
a = torch.tensor(a)

if a.cpu().detach().numpy() == 1:
	...

이와 같이 tensor값을 다시 cpu에서 사용가능한 값으로 바꿔주면 된다.

cpu를 먼저 선언해주고, detach후 numpy배열로 바꿔야 에러가 안나고 잘 바뀐다.

예외) 나는 값보다 타입으로 비교가 가능하여

if type(a) == int:
	...

이처럼 비교가 가능한 타입으로 바꿔주면 에러없이 실행이 가능하다.

profile
안녕하세요 ㅎㅎ

1개의 댓글

comment-user-thumbnail
2022년 4월 27일

도움이 되었습니다. 감사합니다.

답글 달기