[PyTorch] torch.cat과 torch.stack의 차이

Ethan·2022년 12월 28일
1

파이토치

목록 보기
2/3

torch.cat

torch.cat 은 주어진 텐서들을 주어진 차원에 맞춰 합쳐주는 함수입니다.

코드로 예시를 살펴보면 다음과 같습니다.

t1 = torch.ones(2, 2)
t2 = torch.zeros(2, 2)

t3 = torch.cat([t1, t2], dim=0)
t4 = torch.cat([t1, t2], dim=1)

print(t3)
>>> tensor([[1., 1.],
    [1., 1.],
    [0., 0.],
    [0., 0.]])
    
print(t3.shape)
>>> torch.Size([4, 2])

print(t4)
>>> tensor([[1., 1., 0., 0.],
    [1., 1., 0., 0.]])
    
print(t4.shape)
>>> torch.Size([2, 4])

dim 옵션은 입력 텐서의 차원을 따라갑니다. 예를 들어, 2차원 텐서를 입력하면 최대 dim=1, 3차원 텐서를 입력하면 최대 dim=2입니다.

dim 옵션에 따라 concat 되는 방향도 원본 텐서의 차원을 따라갑니다. 예를 들어 위 코드에서 t1, t2는 모두 2차원 텐서이므로, dim=0이면 1차원, 즉 행 방향으로 두 텐서를 합칩니다. 반대로 dim=1라면 열 방향으로 텐서를 합칩니다. 엑셀 시트의 행 추가, 열 추가와 비슷합니다.

torch.stack

공식 문서에는 "주어진 텐서들을 새로운 차원으로 합친다"고 적혀 있습니다. 즉, 다음과 같이 작동합니다.

기본적인 작동 방식은 torch.cat과 같습니다.

t1 = torch.zeros(2, 2)
t2 = torch.ones(2, 2)

t3 = torch.stack([t1, t2], dim=0)
t4 = torch.stack([t1, t2], dim=1)

print(t3)
>>> tensor([[[1., 1.],
     [1., 1.]],

    [[0., 0.],
     [0., 0.]]])

print(t3.shape)
>>> torch.Size([2, 2, 2])

print(t4)
>>> tensor([[[1., 1.],
     [0., 0.]],

    [[1., 1.],
     [0., 0.]]])

print(t4.shape)
>>> torch.Size([2, 2, 2])

위 코드를 보면 torch.cat과 마찬가지로 입력 텐서의 차원을 따라가고 있는 걸 확인할 수 있습니다. 다만 torch.cat과 다른 점은 차원이 하나 늘어서 3차원 텐서가 되었다는 점입니다. 공식 문서의 설명처럼 새로운 차원이 늘어난 거죠. (그래서 함수명이 stack)


참고문헌

  1. [PyTorch] Tensor 합치기: cat(), stack()
profile
재미있게 살고 싶은 대학원생

0개의 댓글

관련 채용 정보