contiguous()
는 PyTorch에서 텐서의 메모리 레이아웃을 연속적(contiguous)으로 변환하는 역할을 함. 이 함수는 텐서가 연속적인 메모리 레이아웃을 가지지 않는 경우 이를 새로운 텐서로 만들어 반환하는 역할.
PyTorch 텐서는 기본적으로 데이터를 행우선(row-major) 순서로 저장합니다. 하지만, 슬라이싱, 전치(transpose) 등의 연산은 텐서의 메모리 레이아웃을 변경하여, 원래의 연속적 메모리에서 비연속적(non-contiguous)인 상태로 만들 수 있다.
import torch
# 연속적 텐서
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x.is_contiguous()) # True
# 전치 연산 후 비연속적 텐서
x_t = x.transpose(0, 1)
print(x_t.is_contiguous()) # False
contiguous()
를 사용하는가?PyTorch의 대부분의 연산(특히 C/C++로 구현된 저수준 연산)은 텐서가 연속적 메모리 레이아웃을 가질 때만 효율적으로 작동.
contiguous()
를 호출하여 연속적인 메모리 레이아웃으로 변환해야 한다.import torch
import torch.nn.functional as F
# 비연속적 텐서 생성
x = torch.randn(2, 3, 4)
x_t = x.transpose(1, 2) # 메모리 레이아웃 변경
print(x_t.is_contiguous()) # False
# 연산 시 문제 발생
try:
F.softmax(x_t, dim=-1)
except RuntimeError as e:
print(e)
# contiguous로 해결
x_t_contig = x_t.contiguous()
F.softmax(x_t_contig, dim=-1) # 정상 작동
get_batch_loss
함수에서의 역할contiguous()
가 get_batch_loss
함수에서 사용된 이유는, 슬라이싱(labels[..., 1:]
, output[..., :-1, :]
) 이후 텐서가 비연속적일 가능성이 있기 때문임. 이를 보장하지 않으면 아래 연산에서 문제가 발생할 수 있음:
shifted_labels
shifted_labels = labels[..., 1:].contiguous()
labels[..., 1:]
)으로 인해 텐서가 비연속적으로 변환될 가능성이 있음.contiguous()
를 호출하여, 슬라이싱 이후의 텐서를 연속적인 메모리 레이아웃으로 변환.output
output = output[..., :-1, :].contiguous()
output[..., :-1, :]
)으로 인해 비연속적일 가능성이 있음.contiguous()
를 호출하여, 이후의 연산(CrossEntropyLoss)에서 문제가 발생하지 않도록 함.슬라이싱 연산은 텐서를 복사하지 않고, 원래의 텐서에서 뷰(view)를 생성. 이 과정에서 메모리 레이아웃이 변경될 수 있어, 비연속적 상태가 될 가능성이 높습니다.
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x[:, 1:] # 슬라이싱
print(y.is_contiguous()) # False
contiguous()
는 텐서의 메모리 레이아웃을 연속적으로 만들어, PyTorch 연산이 정상적으로 작동하도록 보장한다.[..., :-1]
, [..., 1:]
)이나 전치(transpose
) 이후 텐서는 비연속적일 가능성이 있으므로, 이를 해결하기 위해 contiguous()
를 사용.get_batch_loss
함수에서는 CrossEntropyLoss 연산에서 메모리 레이아웃 문제를 방지하기 위해 사용됨.