torch.cat() VS torch.stack()

정강민·2022년 1월 27일
0

PytorchMaster

목록 보기
3/3

torch.cat()은 주어진 차원을 기준으로 주어진 텐서들을 붙입(concatenate)니다.
torch.stack()은 새로운 차원으로 주어진 텐서들을 붙입니다.
따라서, (3, 4)의 크기(shape)를 갖는 2개의 텐서 A와 B를 붙이는 경우,
torch.cat([A, B], dim=0)의 결과는 (6, 4)의 크기(shape)를 갖고,
torch.stack([A, B], dim=0)의 결과는 (2, 3, 4)의 크기를 갖습니다.

  • 두 개의 텐서 t1, t2를 예시로 선언
t1 = torch.tensor([[1, 2],
                   [3, 4]])
t2 = torch.tensor([[5, 6],
                   [7, 8]])
>>> torch.cat((t1, t2), dim=0) # dim=0인 경우
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
  • 이 때, torch.cat()의 동작은 다음과 같습니다.
>>> torch.cat((t1, t2), dim=1) # dim=1인 경우
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
  • torch.stack()은 다음과 같습니다.
>>> torch.stack((t1, t2))
tensor([[[1, 2],
         [3, 4]],
 
        [[5, 6],
         [7, 8]]])
profile
DA/DA/AE

0개의 댓글