view는 PyTorch에서 텐서의 차원을 재구조화하는 데 사용되는 함수이다. view를 사용하면 텐서의 데이터를 재배열하지 않고, 단순히 차원만 변경할 수 있는데, 이는 매우 효율적이며 메모리 사용을 최소화하기에 많이들 사용한다.
view 함수 사용법view 함수는 새로운 차원을 지정하여 텐서를 재구조화한다. 예를 들어, 1차원 텐서를 2차원 텐서로 변환할 수 있다.
import torch
# 1차원 텐서 생성
x = torch.arange(12)
print(x) # tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
# 2차원 텐서로 변환
x_view = x.view(3, 4)
print(x_view)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
view와 reshape의 차이view와 reshape는 비슷한 기능을 하지만, 몇 가지 차이점있다:
view: 텐서가 메모리에서 연속적(contiguous)일 때만 사용 가능하다.view를 사용하면 오류가 발생한다.reshape: 연속적이지 않은 텐서에도 사용할 수 있으며, 필요에 따라 새로운 메모리 공간을 할당하여 데이터를 복사한다.# 연속적이지 않은 텐서 생성
y = torch.ones(3, 4).transpose(0, 1)
print(y.is_contiguous()) # False
# view 사용 시 오류 발생
try:
y_view = y.view(2, 6)
except RuntimeError as e:
print(e) # RuntimeError: view size is not compatible with input tensor's size and stride
# reshape 사용 시 정상 작동
y_reshape = y.reshape(2, 6)
print(y_reshape)
PyTorch에서 텐서가 contiguous하다는 것은 메모리에서 데이터가 연속적으로 저장되어 있다는 의미이다.
이는 텐서의 데이터가 메모리에서 순차적으로 배치되어 있어, 특정 연산을 수행할 때 더 효율적으로 접근할 수 있음을 의미하는데, 때로 이 contiguous한 텐서가 함수를 거치며 또는 특정 작업을 거치며 uncontiguous하게 되어 tensor 연산시에 에러가 날 때가 있다.
Stride는 텐서의 각 차원에서 다음 요소로 이동하기 위해 건너뛰어야 하는 메모리 위치의 수를 나타낸다.
예를 들어, 2차원 텐서에서 stride는 행과 열을 따라 이동할 때의 메모리 간격을 나타낸다.
PyTorch에서 텐서의 contiguous 여부와 stride를 설명하는 예시:
import torch
# 2차원 텐서 생성
a = torch.randn(3, 4)
print("Original Tensor:")
print(a)
# 텐서의 stride 확인
print("Original Stride:", a.stride())
# 텐서의 contiguous 여부 확인
print("Is Contiguous:", a.is_contiguous())
# 텐서의 transpose 연산
b = a.transpose(0, 1)
print("Transposed Tensor:")
print(b)
# Transposed 텐서의 stride 확인
print("Transposed Stride:", b.stride())
# Transposed 텐서의 contiguous 여부 확인
print("Is Contiguous:", b.is_contiguous())
# Transposed 텐서를 contiguous 상태로 변경
b_contiguous = b.contiguous()
print("Contiguous Transposed Tensor:")
print(b_contiguous)
# Contiguous Transposed 텐서의 stride 확인
print("Contiguous Transposed Stride:", b_contiguous.stride())
# Contiguous Transposed 텐서의 contiguous 여부 확인
print("Is Contiguous:", b_contiguous.is_contiguous())
출력
Original Tensor:
tensor([[ 0.1234, -0.5678, 1.2345, -1.6789],
[ 0.9876, -0.5432, 1.0987, -1.2345],
[ 0.8765, -0.4321, 1.9876, -1.5432]])
Original Stride: (4, 1)
Is Contiguous: TrueTransposed Tensor:
tensor([[ 0.1234, 0.9876, 0.8765],
[-0.5678, -0.5432, -0.4321],
[ 1.2345, 1.0987, 1.9876],
[-1.6789, -1.2345, -1.5432]])
Transposed Stride: (1, 4)
Is Contiguous: FalseContiguous Transposed Tensor:
tensor([[ 0.1234, 0.9876, 0.8765],
[-0.5678, -0.5432, -0.4321],
[ 1.2345, 1.0987, 1.9876],
[-1.6789, -1.2345, -1.5432]])
Contiguous Transposed Stride: (3, 1)
Is Contiguous: True
Original Tensor:
a 텐서는 (3, 4) 크기의 텐서이다.a.stride()는 (4, 1)을 반환하는데, 이는 행을 따라 이동할 때 4개의 메모리 위치를 건너뛰고, 열을 따라 이동할 때 1개의 메모리 위치를 건너뛴다는 의미.a.is_contiguous()는 True를 반환.Transposed Tensor:
b 텐서는 a의 전치(transpose)된 텐서로, 크기는 (4, 3)b.stride()는 (1, 4)를 반환. 이는 행을 따라 아래로 이동할 때 1개의 메모리 위치를 건너뛰고, 열을 오른쪽으로 따라 이동할 때 4개의 메모리 위치를 건너뛰는걸 의미함.b.is_contiguous()는 False를 반환b가 메모리에서 연속적으로 저장되지 않음을 의미함.Contiguous Transposed Tensor:
b_contiguous는 b를 contiguous 상태로 변경한 텐서b_contiguous.stride()는 (3, 1)을 반환하는데, 이는 새로운 메모리 배치에서 행을 따라 이동할 때 3개의 메모리 위치를 건너뛰고, 열을 따라 이동할 때 1개의 메모리 위치를 건너뛰는걸 의미.b_contiguous.is_contiguous()는 True를 반환contiguous() 메소드를 사용하여 메모리에서 연속적으로 저장되도록 변경할 수 있다.추가적으로,
view도, reshape는 메모리 블록이 연속적이냐 의 가부로 작동여부가 갈리지만, 둘 모두 데이터의 순서 자체는 변경하지 않는다.