3차원 텐서 요소 추출 후 2차원 텐서 변형 시키기(torch.gather 이용)

JunHyuk Kwon(권준혁)·2022년 10월 4일
0
post-thumbnail

3차원 텐서에서 특정한 규칙으로 추출하고자 하는 요소들을 2차원 텐서로 제작하고 싶은 그런 경우가 종종 있습니다(?) (특히 빡센 dataloader 제작할 때나 모델 제작할 때?)
그럴 때 빠르게 추출할 수 있는 torch.gather을 이용한 방식을 사용할 수 있습니다!!

torch.gather 이해하기

toch.gather 에 대한 PyTorch documentation documentation을 보면, 사실 잘 이해가 안 갑니다.
저는 이렇게 정성적으로 이해해보는게 쉬운 것 같습니다:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
torch.gather(추출하려는 텐서, 추출하려는 텐서에서 어떤 차원을 참고할 것인지?, 해당 차원에서 몇 번째 index 요소들을 참고할 것인지)

주의할 것:

  1. torch.gather은 텐서의 추출하고자하는 요소들 하나하나를 직접 접근해야한다는 것을 인지해야합니다.
  2. [추출하려는 텐서에서 어떤 차원을 참고할 것인지?]
    -> 2번째 매개변수인 dim은 추출하고자하는 원소들의 차원을 입력해주시면 됩니다. (= 3차원 텐서이면 C H W 중에 W 부분)
    -> 3차원 텐서의 요소를 추출하려면 dim=2
    -> 2차원 텐서의 요소를 추출하려면 dim=1
  3. [해당 차원에서 몇 번째 index 요소들을 참고할 것인지]
    -> 3번째 매개변수인 index는 torch.tensor() 형식으로 추출하려는 요소들을 묶어줘야 합니다.
    -> index 부분이 제일 핵심이고 어렵습니다!!

이해를 돕기 위한 예시:

3차원 텐서의 대각선 원소들을 담은 2차원 텐서 만들기

tensor의 변화 과정을 살펴봅니다!

우선 input, torch.gather을 통한 결과, 그리고 output 결과부터 확인해봅니다!

input 예시:

input = torch.tensor(
        [[[5,4,3],
          [7,6,2]],
         [[3,2,1],
          [1,2,8]]])
print(input.size())
>>>torch.Size([2, 2, 3])

torch.gather 후 결과:

gather_result = torch.tensor(
    [[[5],[6]],
     [[3],[2]]]
)
print(gather_result.size())
>>>torch.Size([2, 2, 1])

output 결과:

output = torch.tensor(
        [[5,6],
         [3,2]])
print(output.size())
>>>torch.Size([2, 2])

위 예시를 통해 파악할 것

  1. 위에서 이야기한 것처럼, torch.gather는 텐서의 각 원소들을 하나하나씩 가져온다고 했습니다. 실제로 gather_result 를 보시면 3차원 텐서의 결과로 각 원소들이 하나하나 별개의 크기가 1인 1차원 텐서로 구분된 것을 확인할 수 있습니다.
    -> 따라서 torch.size()를 찍어보면, dim=2일 때의 size가 1인 것을 확인할 수 있죠!
  2. torch.gather 결과는 원소 하나하나가 각각 별개의 크기가 1인 1차원 텐서로 구분되어있기 때문에 view나 reshape나 squeeze를 활용해서 차원을 맞춰줘야 합니다

3차원 텐서 -> 2차원 텐서로 변화해가는 과정

3차원 tensor의 대각 요소들을 추출하기 위한 과정들을 step by step으로 알아보겠습니다~

  1. 3차원 텐서의 C, H, W 추출하기
A = torch.tensor(
        [[[5,4,3],
          [7,6,2]],
         [[3,2,1],
          [1,2,8]]])
C,H,W = A.size()
print(C,H,W)
>>>2 2 3
  1. 3차원 텐서의 대각 요소들의 길이 구하기 (가로, 세로 길이가 다른 경우를 위해)
A = torch.tensor(
        [[[5,4,3],
          [7,6,2]],
         [[3,2,1],
          [1,2,8]]])
C,H,W = A.size()
# diag_size: 대각 요소들의 길이
diag_size = min(H,W)
print(diag_size)
>>>2
  1. 대각 요소들의 인덱스를 저장하는 1차원 텐서 만들기(gather의 인덱스로 사용될 부분)
A = torch.tensor(
        [[[5,4,3],
          [7,6,2]],
         [[3,2,1],
          [1,2,8]]])
C,H,W = A.size()
# diag_size: 대각 요소들의 길이
diag_size = min(H,W)
# rng는 대각선 요소들의 index를 저장한 1차원 tensor
rng = torch.arange(diag_size)
print(rng)
>>>tensor([0, 1])
  1. 대각 요소들의 인덱스들을 별개의 크기가 1인 1차원 텐서 형태로 나누기
gather_index = rng.view(len(rng),-1)
print(gather_index)
>>>tensor([[0],
        [1]])
  1. 추출하고자 하는 3차원 텐서의 Channel 값(위에서 구한 C)만큼 gather_index 반복하기 (torch.expand) 사용
    torch.Tensor.expand: 텐서를 어떤 형태로 확장할 것인가에 대한 함수
    -> 주의할 것: 아래 코드의 torch.expand의 마지막 매개변수 부분은, 각 원소(대각 요소들의 인덱스)를 별개로 저장하기 때문에 이를 고려한 부분입니다. 코드에서 다시 설명했습니다.
# 아래와 같이, dim=2에 해당하는 부분을 1로 지정했습니다.
# diag_size는 H, W 중 하나입니다. 대각 요소들을 추출하기 위함입니다.
gather_index = gather_index.expand(C,diag_size,1)
print(gather_index)
>>>tensor([[[0],
         [1]],

        [[0],
         [1]]])
  1. 1단계~5단계를 통해, torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor 해당 부분에서 index에 해당하는 부분을 드디어 완성했습니다. 이제 output을 생성하겠습니다!
# dim=2로 둬서, 3차원 텐서인 A의 각 요소들에 접근할 수 있도록 합니다.
# index는 3차원 텐서인 A에서 접근하고자하는 요소들의 위치를 다음 gather_index로 지정합니다.
output = torch.gather(input=A,dim=2,index=gather_index)
print(output)
>>>tensor([[[5],
         [6]],

        [[3],
         [2]]])

완성했습니다!

의의

이 글은 3차원 텐서에서 대각 요소들을 추출하는 경우를 예시로 들었습니다.(부스트캠프 내용을 참고해서 만들었어요) 이 과정을 알고리즘 문제 해결하듯이 직접 알고리즘이랑 텐서 변화 과정을 그려보면서 작성하다보면 잘 이해될 것 같습니다.
또한 나중에 3차원 텐서에서 요소를 추출할 때 해당 글에서와 같은 사고과정으로 요소를 추출하는 방법에 대해 생각하고 알고리즘을 설계하면 원하는 값들을 자유자재로 추출할 수 있게 될 것 같습니다!

profile
LinkedIn: https://www.linkedin.com/in/junhyuk-kwon-8578b5247/ (1촌 환영해요) (블로그 글은 나중에 시간되면 회고 쓰는걸로....)

0개의 댓글

관련 채용 정보