헷갈렸던 torch scatter 함수 이해를 위해 정리해본다.
torch 공식 문서에는 scatter 에 대해 아래와 같이 나와있다.
Parameters:
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
- src (Tensor or float) – the source element(s) to scatter.
- reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'.
파라마터에 대해 간단히 정리하자면 dim
은 값을 채울 차원 기준 방향을 의미하며, index
는 값을 채울 위치, src
는 실제로 채워질 값(source)를 뜻한다.
Example:
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
[0, 2, 0, 0, 0],
[0, 0, 3, 0, 0]])
dim이 0이므로 행(row) 방향으로 각 index에 따라 값을 순서대로 넣는다. 즉, 처음 index가 0이고 값이 1이므로 행 방향 [0, 0, 0]에서 [1, 0, 0]이 되고, 다음 index는 1이고 값이 2이므로 [0, 0, 0]에서 [0, 2, 0]이 된다. 유의할 점은 index의 tensor 값이 3이상이 되면 안된다. 당연하겠지만 dim 0 기준으로 index의 최대값은 2가 되기 때문.
>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
[6, 7, 0, 0, 8],
[0, 0, 0, 0, 0]])
이번에는 dim이 1이므로 열(column) 방향으로 각 index에 따라 값을 순서대로 넣는다. 즉, 처음 index가 [0, 1, 2]이므로 열 방향 [0, 0, 0, 0, 0]에서 [1, 2 , 3, 0, 0]이 된다. 또 다음 index는 [0, 1, 4]이므로 [0, 0, 0, 0, 0]에서 [6, 7, 0, 0, 8]이 된다.
여기서 index의 shape를 잘 고려해야 한다. 채우고자 하는 주체의 shape는 (3, 5)이고 scatter 방향은 1번 차원(dim) 기준이다. 그러므로 index의 0번 차원은 아무리 커도 3보다 크면 안 된다. 쉽게 말하면 채울 그릇보다 채우고자하는 번호 리스트(?)가 더 크면 안된다는 말이다.
예시로 코드를 바꿔보면,
src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
>>tensor([[1, 2, 3, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]])
index shape를 (1,3)으로 변경해도 문제가 없다. 채울 그릇의 크기인 3보다 번호 리스트 크기(1)가 더 작기 때문이다. 만약 index shape를 더 키워본다면,
src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2], [0, 1, 4], [2, 3, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
이렇게 하면 에러가 난다. 여기서의 문제는 번호 리스트는 있는데 값이 없기 때문이다. 다시 말해, index shape는 (3,3)이지만 src shape는 (2,5)이기 때문이다.
src = torch.arange(1, 13).view(3,-1)
index = torch.tensor([[0, 1, 2], [0, 1, 4], [2, 3, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
>>tensor([[ 1, 2, 3, 0, 0],
[ 5, 6, 0, 0, 7],
[ 0, 0, 9, 10, 11]])
이번에는 번호 리스트의 크기에 맞게 값도 맞춰주었다. 채울 그릇의 크기와 번호 리스트 크기가 3으로 같으므로 문제가 없다.
src = torch.arange(1, 13).view(3,-1)
index = torch.tensor([[0, 1, 2], [0, 1, 4], [2,3,4], [1, 1, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
만약 위와 같이 index shape가 (4, 3)이라면 에러가 발생한다. 이번에는 채울 그릇의 크기(3)보다 번호 리스트의 크기(4)가 더 크기 때문이다.
반면, index 의 1번 차원의 크기는 제한이 없다. 여기서는 3으로 동일하게 진행했지만 그 이상으로 얼마든지 키워도 된다(단, src 와 구조를 맞춰야함). 즉, index shape가 (3, 3)이든 (3,5) 또는 (3,6) 등 상관없다. 그러나, 테스트를 해보면 굉장히 이상하게(?) 채워지는 것을 알 수 있다.
뭔가 복잡한 것 같지만, 단순하게 번호 리스트(index)의 크기는 채울 그릇과 채울 값(src) 보다 더 크면 안된다. (더 작은 건 상관 없음)
3차원 tensor에서도 마찬가지다.
x = torch.zeros(2, 3, 4)
index = torch.tensor([[1,1,2],[0,0,2]])
x.scatter_(1, index.unsqueeze(1), 1.5)
>>tensor([[[0.0000, 0.0000, 0.0000, 0.0000],
[1.5000, 1.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.5000, 0.0000]],
[[1.5000, 1.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 1.5000, 0.0000]]])
다만, 3차원 구조는 맞춰줘야 하기때문에 채우고자 하는 방향(1번 차원)에 맞게 index를 unsqueeze 해서 shape를 (2,3) 에서 (2,1,3)으로 변경해준다.