[torch] cat vs stack

SeungHyun·2024년 4월 17일
0

pytorch

목록 보기
8/8
post-thumbnail

입력 tensor의 차원

tensor들 간의 차원이 다를 경우

batch1 shape: torch.Size([4, 4])
batch2 shape: torch.Size([4, 3])

cat

  • dim=0: 오류 발생
    (입력 차원 외 크기가 다를 경우 오류)
  • dim=1: cat_result shape: torch.Size([4, 7])

stack

  • dim=0: 오류 발생
    (크기가 다를 경우 오류)
  • dim=1: 오류 발생

tensor들 간의 차원이 같을 경우

batch1 shape: torch.Size([4, 3])
batch2 shape: torch.Size([4, 3])

cat

  • dim=0: cat_result shape: torch.Size([8, 3])
  • dim=1: cat_result shape: torch.Size([4, 6])

stack

  • dim=0: stack_result shape: torch.Size([2, 4, 3])
  • dim=1: stack_result shape: torch.Size([4, 2, 3])


예시

cat

  • dim = k 일때 k차원을 이어붙힌다고 생각할것.
    • k차원을 이어붙히기 때문에 k차원은 두 텐서의 k차원 크기의 합이며 나머지는 동일함.
    • 위 예시의 경우 dim = 1이므로 column에 이어붙힘.
  • k차원을 이어붙혀야 하기 때문에 k 외 나머지 차원은 크기가 동일해야함!

stack

dim = 1

dim = 2

  • dim = k 일때 k차원에 서로 겹쳐서 쌓는다고 생각할것.
    • k차원에 겹쳐서 쌓기 때문에 tensor들 간의 shape이 전부 동일해야함.


결론

cat

  • 차원을 유지하면서 텐서를 이어 붙임.
  • 차원은 동일하며 차원의 크기가 증가함.
  • 단순히 tensor와 tensor를 차원에 맞게 이어붙힘.(이어붙힐 차원 외 다른 차원의 크기는 반드시 동일해야함.)

    ex)
    a.shape = (2, 3)
    b.shape = (2, 3) 일때


    dim = 0일 경우: (4, 3) = (2 + 2, 3)
    dim = 1일 경우: (2, 6) = (2 , 3 + 3)

stack

  • 새로운 차원을 추가하여 텐서를 쌓음.
  • 크기는 동일하며 차원이 증가함.
  • tensor와 tensor를 차원에 맞게 겹침.(겹치기 때문에 새로운 차원이 생성됨)

ref

profile
어디로 가야하오

0개의 댓글

관련 채용 정보