[PyTorch] torch.gather 설명

hj choi·2022년 9월 26일
0

AI

목록 보기
8/27

torch.gather

torch.gather(input, dim, index, *, sparse_grad = False, out = None)

parameters

  • input(Tensor) : input으로 받는 텐서

  • dim(int) : 어떤 axis(축)으로 변경할건지. dim을 기준으로 텐서의 shape이 바뀐다. 이것을 이해하려면 1D array, 2D array, 3D array,... 차원이 변함에 따라 axis의 기준이 어떻게 바뀌는지 이해해야한다.

    • 1D array 일때 :
      1차원이므로 행/열/차원의 개념이 없다.
    • 2D array 일때
      2차원으로, 행/열이 존재한다.
      • axis = 0 : 행
      • axis = 1 : 열
    • 3D array 일 때 :
      • axis = 0 : 차원
      • axis = 1 : 행
      • axis = 2 : 열

    새로운 차원이 더해지면 새로운 차원에 해당하는 축은 0이되고, 나머지는 1씩 더해진다고 생각하면 4차원이상 높아져도 헷갈리지 않을 수 있다.

  • index(LongTensor) : index에 넣어준 텐서의 구성요소별로 새로운 텐서를 구성한다. index의 텐서는 반드시 input으로 받은 텐서와 차원이 같아야함.

예시

###임의의 크기의 3D tensor에서 대각선 요소 모으기

0개의 댓글