Comment :
tensor.long()
이 뭔지 찾아보다가 PyTorch의 Tensor의 타입에 대해서 간단히 정리함
참고 : [PyTorch]TENSOR ATTRIBUTES PyTorch 공식Docu, 생각보다 종류가 많음...
PyTorch의 Tensor 데이터 타입은 위와 같으며 보통 실수 계산을 하기 위해서는 FloatTensor, 정수를 사용하기 위해서는 LongTensor를 사용하며 Boolean, 즉 True/False 사용 시 ByteTensor를 사용한다.
Tensor.type(dtype=None, non_blocking=False)
str
orTensor
Returns the type if dtype is not provided, else casts this object to the specified type.
dtype을 설정하지 않고 사용하면 해당 Tensor의 타입을 출력하고, dtype을 설정하면 해당 dtype으로 Tensor를 변환하여 돌려줌
t_tensor = torch.tensor(1)
t_long = t_tensor.type('torch.LongTensor')
t_float = t_tensor.type('torch.FloatTensor')
t_byte = t_tensor.type('torch.ByteTensor')
print(t_long.type())
print(t_float.type())
print(t_byte.type())
>>>
torch.LongTensor
torch.FloatTensor
torch.ByteTensor