[Python/Pytorch] torch.Tensor()와 torch.tensor() 의 차이

Jaewon Lim·2024년 6월 18일
0

torch.Tensor()

  • Class
  • int 입력 -> float 출력
  • torch 데이터 입력시 입력 받은 데이터의 메모리 공간을 활용
  • list, numpy 데이터 입력시 입력 받은 데이터를 복사하여 새롭게 torch.Tensor를 만든 후 사용(call by value)
  • Tensor 객체로 데이터 입력 시 메모리 공간을 그대로 사용(call by reference)
  • 입력값(data)이 없어도 실행이 가능

torch.tensor()

  • Function
  • int 입력 -> int 출력 / float 입력 -> float 출력
  • 입력 받은 데이터를 새로운 메모리 공간으로 복사 후 사용
  • 입력값(data)이 없으면 실행이 불가능
  • torch.Tensor() = torch.tensor([]) 동일

예를 들어 설명해보자

data_int = [1,2,3,4]
data_float = [1.,2.,3.,4.]
data_boolean = [True, False]

print(torch.Tensor(data_int), torch.Tensor(data_int).dtype) #tensor([1., 2., 3., 4.]) torch.float32
print(torch.tensor(data_int), torch.tensor(data_int).dtype) #tensor([1, 2, 3, 4]) torch.int64

print(torch.Tensor(data_float), torch.Tensor(data_float).dtype) #tensor([1., 2., 3., 4.]) torch.float32
print(torch.tensor(data_float), torch.tensor(data_float).dtype) #tensor([1., 2., 3., 4.]) torch.float32

print(torch.Tensor(data_boolean), torch.Tensor(data_boolean).dtype) #tensor([1., 0.]) torch.float32
print(torch.tensor(data_boolean), torch.tensor(data_boolean).dtype) #tensor([ True, False]) torch.bool

스칼라(Scala)를 넣었을 때 차이???

torch.Tensor([1,2]) 리스트로 값을 넣었지만, 스칼라 값 한 개만 넣게 된다면 리스트 안에 n개의 데이터가 랜덤으로 들어간다.

torch.Tensor(3)
#tensor([2.4990e-36, 0.0000e+00, 0.0000e+00]
torch.Tensor(3).dtype
#torch.float32

하지만 torch.tensor(3)의 경우 스칼라 값도 하나의 데이터로 인식한다.

torch.tensor(3)
#tensor(3)
torch.tensor(3).dtype
#torch.int64

requires_grad 차이

자동 미분을 이용하기 위한 requires_grad 함수를 사용한다.

torch.Tensor([2.,3.]), requires_grad = True)
#Error, Tensor안에는 requires_grad 파라미터 존재하지 않기 때문

torch.tensor([2.,3.]), requires_grad = True)
#tensor([2.,3.]), requires_grad = True, tensor안에는 requires_grad 파라미터 존재

결론

  • torch.tensor() 는 torch.Torch() 를 기반으로 만들어 진다. 결국 둘 다 torch.Tensor이지만 torch.tensor()는 여러 규칙을 준수하여 만들 수 있도록 제공하는 단순한 함수이다.
  • 리스트를 넣어줬을 때 Tensor()는 float 타입의 텐서 출력, tensor()는 int 타입의 텐서 출력.
  • 스칼라 값 하나만 넣어줬을 때, Tensor()는 숫자만큼의 랜덤 데이터 입력, tensor()는 숫자 한 개의 데이터 출력.

0개의 댓글