[Pytorch]torch.contiguous()

ma-kjh·2024년 12월 23일
0

Pytorch

목록 보기
23/25

contiguous()는 PyTorch에서 텐서의 메모리 레이아웃을 연속적(contiguous)으로 변환하는 역할을 함. 이 함수는 텐서가 연속적인 메모리 레이아웃을 가지지 않는 경우 이를 새로운 텐서로 만들어 반환하는 역할.

1. 연속적 메모리(contiguous memory)란?

PyTorch 텐서는 기본적으로 데이터를 행우선(row-major) 순서로 저장합니다. 하지만, 슬라이싱, 전치(transpose) 등의 연산은 텐서의 메모리 레이아웃을 변경하여, 원래의 연속적 메모리에서 비연속적(non-contiguous)인 상태로 만들 수 있다.

  • 연속적 텐서: 메모리 상에서 값들이 순차적으로 저장됨.
  • 비연속적 텐서: 메모리 상에서 값들이 순차적이지 않고, 내부적으로 인덱싱을 통해 접근.

예시:

import torch

# 연속적 텐서
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x.is_contiguous())  # True

# 전치 연산 후 비연속적 텐서
x_t = x.transpose(0, 1)
print(x_t.is_contiguous())  # False

2. contiguous()를 사용하는가?

PyTorch의 대부분의 연산(특히 C/C++로 구현된 저수준 연산)은 텐서가 연속적 메모리 레이아웃을 가질 때만 효율적으로 작동.

  • 비연속적인 텐서를 사용할 경우, contiguous()를 호출하여 연속적인 메모리 레이아웃으로 변환해야 한다.
  • 변환하지 않으면 연산 중에 RuntimeError가 발생하거나 성능이 저하될 수 있음.

예제: 비연속적 텐서의 문제와 해결

import torch
import torch.nn.functional as F

# 비연속적 텐서 생성
x = torch.randn(2, 3, 4)
x_t = x.transpose(1, 2)  # 메모리 레이아웃 변경
print(x_t.is_contiguous())  # False

# 연산 시 문제 발생
try:
    F.softmax(x_t, dim=-1)
except RuntimeError as e:
    print(e)

# contiguous로 해결
x_t_contig = x_t.contiguous()
F.softmax(x_t_contig, dim=-1)  # 정상 작동

3. get_batch_loss 함수에서의 역할

contiguous()get_batch_loss 함수에서 사용된 이유는, 슬라이싱(labels[..., 1:], output[..., :-1, :]) 이후 텐서가 비연속적일 가능성이 있기 때문임. 이를 보장하지 않으면 아래 연산에서 문제가 발생할 수 있음:

(1) shifted_labels

shifted_labels = labels[..., 1:].contiguous()
  • 슬라이싱(labels[..., 1:])으로 인해 텐서가 비연속적으로 변환될 가능성이 있음.
  • contiguous()를 호출하여, 슬라이싱 이후의 텐서를 연속적인 메모리 레이아웃으로 변환.

(2) output

output = output[..., :-1, :].contiguous()
  • 출력 텐서도 슬라이싱(output[..., :-1, :])으로 인해 비연속적일 가능성이 있음.
  • contiguous()를 호출하여, 이후의 연산(CrossEntropyLoss)에서 문제가 발생하지 않도록 함.

4. 왜 슬라이싱이 비연속적일 가능성이 있는가?

슬라이싱 연산은 텐서를 복사하지 않고, 원래의 텐서에서 뷰(view)를 생성. 이 과정에서 메모리 레이아웃이 변경될 수 있어, 비연속적 상태가 될 가능성이 높습니다.

예시:

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = x[:, 1:]  # 슬라이싱
print(y.is_contiguous())  # False

5. 요약

  • contiguous()는 텐서의 메모리 레이아웃을 연속적으로 만들어, PyTorch 연산이 정상적으로 작동하도록 보장한다.
  • 슬라이싱([..., :-1], [..., 1:])이나 전치(transpose) 이후 텐서는 비연속적일 가능성이 있으므로, 이를 해결하기 위해 contiguous()를 사용.
  • get_batch_loss 함수에서는 CrossEntropyLoss 연산에서 메모리 레이아웃 문제를 방지하기 위해 사용됨.
profile
거인의 어깨에 올라서서 더 넓은 세상을 바라보라 - 아이작 뉴턴

0개의 댓글