자주 사용되는 함수인데, 머릿 속에 잘 정리가 안되어서 하나하나 살펴보았다.
파이토치 공식 문서에는 위와 같이 적혀있다.
인자로는 input Tensor가 있고 dimension이 있다.
첫번째 줄 설명을 보면, '특정 위치에 dimension of size one을 insert한 새로운 Tensor를 반환한다'라고 되어있다.
여기에서 특정 위치는 파라미터 중 dim에 의해서 결정된다.
그러니까 dim으로 지정한 위치에 사이즈 1짜리 차원을 하나 넣어준다는 뜻인데...
예시를 통해 살펴보자.
import torch
x = torch.Tensor([[1,2,3,4], [1,2,3,4]])
요렇게 생긴 텐서를 만들어준다.
x.unsqueeze(0)
x.unsqueeze(0).shape
dim = 0
에 대해서 unsqueeze해주면 첫번째 차원 ([1, 2, 4]) 에 사이즈가 1인 dimension을 insert했다.
x.unsqueeze(1)
x.unsqueeze(1).shape
dim = 1
에 대해서 unsqueeze해주면 두번째 차원 ([2, 1, 4]) 에 사이즈가 1인 dimension을 insert했다.
x.unsqueeze(2)
x.unsqueeze(2).shape
dim = 2
일 때도 마찬가지로 되었다. ([2, 4, 1])
dimension을 insert한다는 개념이 조금 낯설지만, 위 예제를 통해서 직관적으로 이해할 수 있다.
unsqueeze라는 뜻이 squeeze를 un 한다는 뜻이면, 1차원으로([1,2,3]) 쥐어짜져(squeeze)있던 것을 2차원으로 unsqueeze([[1,2,3]])이렇게 풀어준다.
그냥 내가 직관적으로 생각하기에는 쥐어짜져 있던 걸 푸니까 약간 입체화 되는 느낌으로...
되는 게 아닐까 생각해봤다.
뭔가 그림으로 더 이해하고 싶어서 다른 예제를 가져와서 그림으로 그려보았다. 그런데 내 마음대로 그린 거라 맞을 지는 모르겠다. 요렇게 생기지 않았을까... 하는 추측이다.
x = torch.Tensor([2,4,6])
요런 1차원 텐서를 가져와서 unsqueeze 하면 어떤 모양이 되는 지 살펴보겠다.
x.unsqueeze(0)
x.unsqueeze(0).shape
이렇게 되고 이걸 그림으로 이해해보고 싶었다.
x.unsqueeze(1)
아니면 이건가...?
잘 모르겠다. 혹시 이 글을 보시는 분이 아신다면 피드백을 부탁드립니다.. 잘 모르겠다... 근데 궁금하다........