>> x = torch.tensor([1, 2, 3])
>> torch.sum(x)
tensor(6)
1차원 텐서를 다룰 때는 직관적으로 이해할 수 있습니다. 그러나 그 이상을 넘어가면 이해하기 어려울 수 있습니다.
torch.sum
의 Description을 보면 각 주어진 dim의 row를 더한다고 되어 있지만 이 역시 바로 이해하기 힘듭니다.numpy.sum
의 axis
parameter도 torch.sum
의 dim
과 같은 역할을 합니다. >> x = torch.tensor([
[1, 2, 3],
[4, 5, 6]
])
>> x.shape
torch.Size([2, 3])
dim = 0 : row
, dim = 1 : column
인 것을 쉽게 알 수 있습니다.x.sum(dim = 0)
는 row-wise로 summation을 구하여 tensor([6, 15])
를 내놓을까요?>> x.sum(dim = 0)
tensor([5, 7, 9])
dim = 1
또한 반대의 결과를 가져옵니다.>> x.sum(dim = 1)
tensor([6, 15])
The way to understand the “axis” of numpy sum is that it collapses the specified axis. So when it collapses the axis 0 (the row), it becomes just one row (it sums column-wise).
numpy.sum의 axis를 이해하려면, 그것(axis parameter)이 특정 axis를 접는다고 생각하면 됩니다. axis 0(row)를 접는다면, 그것은 1개의 row가 되고 column-wise한 sum을 계산합니다.
numpy axis == torch dim
>> x = torch.tensor([
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
],
[
[1, 2, 3],
[4, 5, 6]
]
])
>> x.shape
torch.Size([3, 2, 3])
dim = 0
)dim = 0
은 3개의 2d tensor를 포함합니다.[[1,2,3], [4,5,6]]
>> x.sum(dim = 0)
tensor([[ 3, 6, 9],
[12, 15, 18]])
dim = 1
으로 summation하면 3 x 2 x 3 중 2가 사라져 3 x 3이 나옵니다>> x.sum(dim = 1)
tensor([[5, 7, 9],
[5, 7, 9],
[5, 7, 9]])
해당 포스트는 https://towardsdatascience.com/understanding-dimensions-in-pytorch-6edf9972d3be 를 참고하여 작성하였습니다.