torch.cat을 하다보면 종종 차원이 헷갈리는 경우가 있어 정리한다.
NLP의 경우에는 대개 (batch_size, len, dim)의 3차원 tensor를 다루는 일이 많으므로 아래의 예제들을 토대로 연산되는 방식을 숙지하면 좋을것이다.
실험을 하기 위해 x와 y를 랜덤한 텐서로 초기화시켜준다.
각각의 텐서를 아래에서 확인할 수 있다.
>>> x = torch.rand(2, 3, 2)
>>> y = torch.rand(2, 3, 2)
>>> x
tensor([[[0.5288, 0.0804],
[0.2054, 0.4510],
[0.6637, 0.2970]],
[[0.9760, 0.0178],
[0.8031, 0.9524],
[0.6094, 0.5132]]])
>>> y
tensor([[[0.2779, 0.3232],
[0.5462, 0.4260],
[0.9379, 0.9738]],
[[0.7008, 0.4694],
[0.9818, 0.5061],
[0.0285, 0.3208]]])
>>> torch.cat([x, y], dim=0)
tensor([[[0.5288, 0.0804],
[0.2054, 0.4510],
[0.6637, 0.2970]],
[[0.9760, 0.0178],
[0.8031, 0.9524],
[0.6094, 0.5132]], # x
[[0.2779, 0.3232],
[0.5462, 0.4260],
[0.9379, 0.9738]],
[[0.7008, 0.4694],
[0.9818, 0.5061],
[0.0285, 0.3208]]]) # y
x와 y의 각각 차원이 (2, 3, 2)였는데 첫번째 차원을 기준으로 더해짐
torch.Size([4, 3, 2])
가 된다.
>>> torch.cat([x, y], dim=1)
tensor([[[0.5288, 0.0804],
[0.2054, 0.4510],
[0.6637, 0.2970], # x[0]
[0.2779, 0.3232],
[0.5462, 0.4260],
[0.9379, 0.9738]], # y[0]
[[0.9760, 0.0178],
[0.8031, 0.9524],
[0.6094, 0.5132], # x[1]
[0.7008, 0.4694],
[0.9818, 0.5061],
[0.0285, 0.3208]]]) # y[1]
torch.Size([2, 6, 2])
마찬가지로 형태를 확인해보면 두번째 차원이 늘어난 것을 알 수 있다.
>>> torch.cat([x, y], dim=2)
tensor([[[0.5288, 0.0804, 0.2779, 0.3232], # -> torch.cat((x[0][0], y[0][0]))
[0.2054, 0.4510, 0.5462, 0.4260], # -> x[0][1]+y[0][1]
[0.6637, 0.2970, 0.9379, 0.9738]],
[[0.9760, 0.0178, 0.7008, 0.4694],
[0.8031, 0.9524, 0.9818, 0.5061],
[0.6094, 0.5132, 0.0285, 0.3208]]]) # x[1]과 y[1]이 열기준으로 합쳐진다.
torch.Size([2, 3, 4])
크기는 위와 같다.
다음에는 헷갈리지 말아야지!