transpose(input, dim0, dim1) | t() |
---|---|
dim1차원과 dim0차원을 transpose하여 return | 2D 이하의 tensor를 input으로 받아 0차원과 1차원을 transpose |
a = torch.randint(0, 10, size = (3, 2))
a_trans = torch.transpose(a, 0, 1)
a_t = a.t()
print("a", a)
print("a.shape", a.shape)
print("after transpose", a_trans)
print("after t()", a_t)
print("after transpose, contiguous?", a_trans.is_contiguous())
print("after t(), contiguous?", a_t.is_contiguous())
transpose 연산을 통해 0차원과 1차원이 바뀐 것을 확인할 수 있습니다; 3 by 2 2 by 3
또한, transpose와 t()를 수행한 이후의 값인 a_trans와 a_t tensor의 값들은 메모리에 연속적으로 할당되어 있지 않음을 확인할 수 있습니다.
view와 reshape은 공통적으로 tensor의 shape을 바꾸는 작업을 수행합니다. 그러나 view와 reshape은 input의 contiguity에 따른 차이가 있습니다.
view | reshape |
---|---|
contiguous (x) 수행 X | contiguous (x) copy 하여 reshape 수행 |
view: input tensor의 뷰를 생성한다는 개념으로, view 작업을 통해 얻은 새로운 tensor는 원래 tensor와 데이터를 공유합니다. 따라서 view를 적용하기 위해서는 contiguity 제약 조건을 만족해야 합니다.
reshape: input tensor가 contiguity 제약 조건을 만족하지 않는다면, copy 후 작업을 수행합니다.
참고문헌
https://inmoonlight.github.io/2021/03/03/PyTorch-view-transpose-reshape/