[Pytorch]torch.stack

ma-kjh·2024년 12월 23일
0

Pytorch

목록 보기
22/25

torch.stack리스트 또는 배열 형태의 텐서들을 하나의 torch.Tensor로 결합해주는 함수. 입력 텐서들을 새로운 차원으로 쌓는 역할.


1. torch.stack

  • torch.stack여러 텐서를 주어진 새 축(dim)을 따라 연결.
  • 즉, 입력 텐서 리스트의 요소들은 동일한 크기를 가져야 하며, 이를 새로운 차원으로 쌓아서 하나의 텐서로 만든다.

기본 예제:

import torch

# 텐서 리스트
tensor_list = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6])]

# torch.stack으로 쌓기
result = torch.stack(tensor_list)

print(result)
# 출력:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

print(result.shape)
# 출력: torch.Size([3, 2])  # (새로운 축, 기존 차원)

설명:

  • tensor_list는 크기 [2]를 가진 텐서 3개로 구성됨.
  • torch.stack은 리스트의 요소를 새로운 첫 번째 차원(디폴트 dim=0)으로 쌓아 크기 [3, 2]의 텐서를 생성함.

2. torch.stack vs torch.cat

둘 다 텐서를 연결하지만, 동작 방식이 다르다.

  • torch.stack: 새로운 차원을 생성해서 텐서를 쌓음.
  • torch.cat: 기존 차원에서 텐서를 이어붙임.

비교 예제:

import torch

tensor_list = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6])]

# torch.stack
stacked = torch.stack(tensor_list)
print(stacked)
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

# torch.cat
concatenated = torch.cat(tensor_list)
print(concatenated)
# tensor([1, 2, 3, 4, 5, 6])
  • torch.stack: 차원이 추가되어 [3, 2].
  • torch.cat: 차원이 추가되지 않고, 결과 크기가 [6].

3. torch.stack의 축 변경

새로운 축(dim)의 위치를 지정할 수 있습니다.

예제:

tensor_list = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6])]

# Default (dim=0)
result1 = torch.stack(tensor_list, dim=0)
print(result1.shape)  # torch.Size([3, 2])

# dim=1
result2 = torch.stack(tensor_list, dim=1)
print(result2)
# tensor([[1, 3, 5],
#         [2, 4, 6]])
print(result2.shape)  # torch.Size([2, 3])

4. 주의 사항

  1. 모든 텐서의 크기가 같아야 함:

    • torch.stack은 새로운 차원을 생성하므로, 리스트의 모든 텐서가 같은 크기를 가져야 함.
    tensors = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]
    torch.stack(tensors)  # 오류 발생: RuntimeError
  2. 데이터 타입도 동일해야 함:

    • 리스트의 텐서가 다른 데이터 타입을 가지면 오류가 발생.

5. 단순히 listtorch.Tensor로 바꾸는 경우

torch.tensor를 사용하면 간단히 리스트를 텐서로 변환할 수 있습니다.

예제:

# 리스트를 텐서로 변환
my_list = [[1, 2], [3, 4]]
tensor = torch.tensor(my_list)

print(tensor)
# tensor([[1, 2],
#         [3, 4]])

차이점:

  • torch.tensor는 리스트를 텐서로 직접 변환.
  • torch.stack은 텐서 리스트를 새로운 차원으로 쌓아서 하나의 텐서를 만듦.

요약

  • torch.stack은 텐서 리스트를 새로운 축(dim)으로 쌓아주는 역할.
  • 입력 텐서들은 크기와 데이터 타입이 동일해야 함.
  • 단순히 리스트를 텐서로 변환하려면 torch.tensor를 사용하고, 여러 텐서를 하나로 합칠 때는 torch.stack 또는 torch.cat을 적절히 선택.
profile
거인의 어깨에 올라서서 더 넓은 세상을 바라보라 - 아이작 뉴턴

0개의 댓글