[Pytorch] 3차원 tensor torch.cat(dim=0~2)

김예신·2022년 6월 21일
0

3D tensor

torch.cat을 하다보면 종종 차원이 헷갈리는 경우가 있어 정리한다.
NLP의 경우에는 대개 (batch_size, len, dim)의 3차원 tensor를 다루는 일이 많으므로 아래의 예제들을 토대로 연산되는 방식을 숙지하면 좋을것이다.

실험을 하기 위해 x와 y를 랜덤한 텐서로 초기화시켜준다.
각각의 텐서를 아래에서 확인할 수 있다.

>>> x = torch.rand(2, 3, 2)
>>> y = torch.rand(2, 3, 2)

>>> x
tensor([[[0.5288, 0.0804],
         [0.2054, 0.4510],
         [0.6637, 0.2970]],

        [[0.9760, 0.0178],
         [0.8031, 0.9524],
         [0.6094, 0.5132]]])
         
>>> y
tensor([[[0.2779, 0.3232],
         [0.5462, 0.4260],
         [0.9379, 0.9738]],

        [[0.7008, 0.4694],
         [0.9818, 0.5061],
         [0.0285, 0.3208]]])

dim=0 (default)

>>> torch.cat([x, y], dim=0)
tensor([[[0.5288, 0.0804],	
         [0.2054, 0.4510],
         [0.6637, 0.2970]],

        [[0.9760, 0.0178],
         [0.8031, 0.9524],
         [0.6094, 0.5132]],		# x

        [[0.2779, 0.3232],
         [0.5462, 0.4260],
         [0.9379, 0.9738]],

        [[0.7008, 0.4694],
         [0.9818, 0.5061],
         [0.0285, 0.3208]]])	# y

x와 y의 각각 차원이 (2, 3, 2)였는데 첫번째 차원을 기준으로 더해짐
torch.Size([4, 3, 2])가 된다.

dim=1

>>> torch.cat([x, y], dim=1)
tensor([[[0.5288, 0.0804],
         [0.2054, 0.4510],
         [0.6637, 0.2970],		# x[0]
         [0.2779, 0.3232],
         [0.5462, 0.4260],
         [0.9379, 0.9738]],		# y[0]

        [[0.9760, 0.0178],
         [0.8031, 0.9524],
         [0.6094, 0.5132],		# x[1]
         [0.7008, 0.4694],
         [0.9818, 0.5061],
         [0.0285, 0.3208]]])	# y[1]

torch.Size([2, 6, 2]) 마찬가지로 형태를 확인해보면 두번째 차원이 늘어난 것을 알 수 있다.

dim=2

>>> torch.cat([x, y], dim=2)
tensor([[[0.5288, 0.0804, 0.2779, 0.3232],		# -> torch.cat((x[0][0], y[0][0]))
         [0.2054, 0.4510, 0.5462, 0.4260],		# -> x[0][1]+y[0][1]
         [0.6637, 0.2970, 0.9379, 0.9738]],	

        [[0.9760, 0.0178, 0.7008, 0.4694],
         [0.8031, 0.9524, 0.9818, 0.5061],
         [0.6094, 0.5132, 0.0285, 0.3208]]]) 	# x[1]과 y[1]이 열기준으로 합쳐진다.
         

torch.Size([2, 3, 4])
크기는 위와 같다.
다음에는 헷갈리지 말아야지!

profile
life is dancing

0개의 댓글