torch.cat
은 주어진 텐서들을 주어진 차원에 맞춰 합쳐주는 함수입니다.
코드로 예시를 살펴보면 다음과 같습니다.
t1 = torch.ones(2, 2)
t2 = torch.zeros(2, 2)
t3 = torch.cat([t1, t2], dim=0)
t4 = torch.cat([t1, t2], dim=1)
print(t3)
>>> tensor([[1., 1.],
[1., 1.],
[0., 0.],
[0., 0.]])
print(t3.shape)
>>> torch.Size([4, 2])
print(t4)
>>> tensor([[1., 1., 0., 0.],
[1., 1., 0., 0.]])
print(t4.shape)
>>> torch.Size([2, 4])
dim
옵션은 입력 텐서의 차원을 따라갑니다. 예를 들어, 2차원 텐서를 입력하면 최대 dim=1, 3차원 텐서를 입력하면 최대 dim=2입니다.
dim
옵션에 따라 concat 되는 방향도 원본 텐서의 차원을 따라갑니다. 예를 들어 위 코드에서 t1, t2는 모두 2차원 텐서이므로, dim=0이면 1차원, 즉 행 방향으로 두 텐서를 합칩니다. 반대로 dim=1라면 열 방향으로 텐서를 합칩니다. 엑셀 시트의 행 추가, 열 추가와 비슷합니다.
공식 문서에는 "주어진 텐서들을 새로운 차원으로 합친다"고 적혀 있습니다. 즉, 다음과 같이 작동합니다.
기본적인 작동 방식은 torch.cat
과 같습니다.
t1 = torch.zeros(2, 2)
t2 = torch.ones(2, 2)
t3 = torch.stack([t1, t2], dim=0)
t4 = torch.stack([t1, t2], dim=1)
print(t3)
>>> tensor([[[1., 1.],
[1., 1.]],
[[0., 0.],
[0., 0.]]])
print(t3.shape)
>>> torch.Size([2, 2, 2])
print(t4)
>>> tensor([[[1., 1.],
[0., 0.]],
[[1., 1.],
[0., 0.]]])
print(t4.shape)
>>> torch.Size([2, 2, 2])
위 코드를 보면 torch.cat
과 마찬가지로 입력 텐서의 차원을 따라가고 있는 걸 확인할 수 있습니다. 다만 torch.cat
과 다른 점은 차원이 하나 늘어서 3차원 텐서가 되었다는 점입니다. 공식 문서의 설명처럼 새로운 차원이 늘어난 거죠. (그래서 함수명이 stack)
참고문헌