[TIL] TORCH.TENSOR.SCATTER_

wandajeong·2023년 5월 12일
0

TIL with chatGPT

목록 보기
5/5

헷갈렸던 torch scatter 함수 이해를 위해 정리해본다.
torch 공식 문서에는 scatter
에 대해 아래와 같이 나와있다.

Parameters:

  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged.
  • src (Tensor or float) – the source element(s) to scatter.
  • reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'.

파라마터에 대해 간단히 정리하자면 dim은 값을 채울 차원 기준 방향을 의미하며, index는 값을 채울 위치, src는 실제로 채워질 값(source)를 뜻한다.

Example:

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

dim이 0이므로 행(row) 방향으로 각 index에 따라 값을 순서대로 넣는다. 즉, 처음 index가 0이고 값이 1이므로 행 방향 [0, 0, 0]에서 [1, 0, 0]이 되고, 다음 index는 1이고 값이 2이므로 [0, 0, 0]에서 [0, 2, 0]이 된다. 유의할 점은 index의 tensor 값이 3이상이 되면 안된다. 당연하겠지만 dim 0 기준으로 index의 최대값은 2가 되기 때문.

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

이번에는 dim이 1이므로 열(column) 방향으로 각 index에 따라 값을 순서대로 넣는다. 즉, 처음 index가 [0, 1, 2]이므로 열 방향 [0, 0, 0, 0, 0]에서 [1, 2 , 3, 0, 0]이 된다. 또 다음 index는 [0, 1, 4]이므로 [0, 0, 0, 0, 0]에서 [6, 7, 0, 0, 8]이 된다.

여기서 index의 shape를 잘 고려해야 한다. 채우고자 하는 주체의 shape는 (3, 5)이고 scatter 방향은 1번 차원(dim) 기준이다. 그러므로 index의 0번 차원은 아무리 커도 3보다 크면 안 된다. 쉽게 말하면 채울 그릇보다 채우고자하는 번호 리스트(?)가 더 크면 안된다는 말이다.
예시로 코드를 바꿔보면,

src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
>>tensor([[1, 2, 3, 0, 0],
          [0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0]])

index shape를 (1,3)으로 변경해도 문제가 없다. 채울 그릇의 크기인 3보다 번호 리스트 크기(1)가 더 작기 때문이다. 만약 index shape를 더 키워본다면,

src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2], [0, 1, 4], [2, 3, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)

이렇게 하면 에러가 난다. 여기서의 문제는 번호 리스트는 있는데 값이 없기 때문이다. 다시 말해, index shape는 (3,3)이지만 src shape는 (2,5)이기 때문이다.

src = torch.arange(1, 13).view(3,-1)
index = torch.tensor([[0, 1, 2], [0, 1, 4], [2, 3, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
>>tensor([[ 1,  2,  3,  0,  0],
          [ 5,  6,  0,  0,  7],
          [ 0,  0,  9, 10, 11]])

이번에는 번호 리스트의 크기에 맞게 값도 맞춰주었다. 채울 그릇의 크기와 번호 리스트 크기가 3으로 같으므로 문제가 없다.

src = torch.arange(1, 13).view(3,-1)
index = torch.tensor([[0, 1, 2], [0, 1, 4], [2,3,4], [1, 1, 2]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)

만약 위와 같이 index shape가 (4, 3)이라면 에러가 발생한다. 이번에는 채울 그릇의 크기(3)보다 번호 리스트의 크기(4)가 더 크기 때문이다.
반면, index 의 1번 차원의 크기는 제한이 없다. 여기서는 3으로 동일하게 진행했지만 그 이상으로 얼마든지 키워도 된다(단, src 와 구조를 맞춰야함). 즉, index shape가 (3, 3)이든 (3,5) 또는 (3,6) 등 상관없다. 그러나, 테스트를 해보면 굉장히 이상하게(?) 채워지는 것을 알 수 있다.

뭔가 복잡한 것 같지만, 단순하게 번호 리스트(index)의 크기는 채울 그릇과 채울 값(src) 보다 더 크면 안된다. (더 작은 건 상관 없음)

3차원 tensor에서도 마찬가지다.

x = torch.zeros(2, 3, 4)
index = torch.tensor([[1,1,2],[0,0,2]])
x.scatter_(1, index.unsqueeze(1), 1.5)
>>tensor([[[0.0000, 0.0000, 0.0000, 0.0000],
           [1.5000, 1.5000, 0.0000, 0.0000],
           [0.0000, 0.0000, 1.5000, 0.0000]],

          [[1.5000, 1.5000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 1.5000, 0.0000]]])

다만, 3차원 구조는 맞춰줘야 하기때문에 채우고자 하는 방향(1번 차원)에 맞게 index를 unsqueeze 해서 shape를 (2,3) 에서 (2,1,3)으로 변경해준다.

profile
ML/DL swimmer

0개의 댓글