3차원 텐서에서 특정한 규칙으로 추출하고자 하는 요소들을 2차원 텐서로 제작하고 싶은 그런 경우가 종종 있습니다(?) (특히 빡센 dataloader 제작할 때나 모델 제작할 때?)
그럴 때 빠르게 추출할 수 있는 torch.gather을 이용한 방식을 사용할 수 있습니다!!
toch.gather 에 대한 PyTorch documentation documentation을 보면, 사실 잘 이해가 안 갑니다.
저는 이렇게 정성적으로 이해해보는게 쉬운 것 같습니다:
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
torch.gather(추출하려는 텐서, 추출하려는 텐서에서 어떤 차원을 참고할 것인지?, 해당 차원에서 몇 번째 index 요소들을 참고할 것인지)
이해를 돕기 위한 예시:
우선 input, torch.gather을 통한 결과, 그리고 output 결과부터 확인해봅니다!
input = torch.tensor(
[[[5,4,3],
[7,6,2]],
[[3,2,1],
[1,2,8]]])
print(input.size())
>>>torch.Size([2, 2, 3])
gather_result = torch.tensor(
[[[5],[6]],
[[3],[2]]]
)
print(gather_result.size())
>>>torch.Size([2, 2, 1])
output = torch.tensor(
[[5,6],
[3,2]])
print(output.size())
>>>torch.Size([2, 2])
3차원 tensor의 대각 요소들을 추출하기 위한 과정들을 step by step으로 알아보겠습니다~
A = torch.tensor(
[[[5,4,3],
[7,6,2]],
[[3,2,1],
[1,2,8]]])
C,H,W = A.size()
print(C,H,W)
>>>2 2 3
A = torch.tensor(
[[[5,4,3],
[7,6,2]],
[[3,2,1],
[1,2,8]]])
C,H,W = A.size()
# diag_size: 대각 요소들의 길이
diag_size = min(H,W)
print(diag_size)
>>>2
A = torch.tensor(
[[[5,4,3],
[7,6,2]],
[[3,2,1],
[1,2,8]]])
C,H,W = A.size()
# diag_size: 대각 요소들의 길이
diag_size = min(H,W)
# rng는 대각선 요소들의 index를 저장한 1차원 tensor
rng = torch.arange(diag_size)
print(rng)
>>>tensor([0, 1])
gather_index = rng.view(len(rng),-1)
print(gather_index)
>>>tensor([[0],
[1]])
# 아래와 같이, dim=2에 해당하는 부분을 1로 지정했습니다.
# diag_size는 H, W 중 하나입니다. 대각 요소들을 추출하기 위함입니다.
gather_index = gather_index.expand(C,diag_size,1)
print(gather_index)
>>>tensor([[[0],
[1]],
[[0],
[1]]])
# dim=2로 둬서, 3차원 텐서인 A의 각 요소들에 접근할 수 있도록 합니다.
# index는 3차원 텐서인 A에서 접근하고자하는 요소들의 위치를 다음 gather_index로 지정합니다.
output = torch.gather(input=A,dim=2,index=gather_index)
print(output)
>>>tensor([[[5],
[6]],
[[3],
[2]]])
완성했습니다!
이 글은 3차원 텐서에서 대각 요소들을 추출하는 경우를 예시로 들었습니다.(부스트캠프 내용을 참고해서 만들었어요) 이 과정을 알고리즘 문제 해결하듯이 직접 알고리즘이랑 텐서 변화 과정을 그려보면서 작성하다보면 잘 이해될 것 같습니다.
또한 나중에 3차원 텐서에서 요소를 추출할 때 해당 글에서와 같은 사고과정으로 요소를 추출하는 방법에 대해 생각하고 알고리즘을 설계하면 원하는 값들을 자유자재로 추출할 수 있게 될 것 같습니다!