[Pytorch]torch.cat()

ma-kjh·2023년 8월 24일
0

Pytorch

목록 보기
2/20

torch.cat(tensors, dim=0, *, out=None) → Tensor

torch.cat은 tensor를 concatenate해주는 역할을 한다. 모든 텐서는 concatenating을 진행하는 dimension을 제외한 모두 같은 shape를 지녀야 연산이 가능하다.

Example:

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790, 0.1497]])

>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497]])

>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497]])

x의 shape는 (2,3) 이지만, dim=0으로 cat하면 (4,3)이 된다.
x의 shape는 (2,3) 이지만, dim=1으로 cat하면 (2,6)이 된다.

profile
거인의 어깨에 올라서서 더 넓은 세상을 바라보라 - 아이작 뉴턴

0개의 댓글