PyTorch의 연산 그래프와 효율적인 메모리 관리

SeongGyun Hong·2024년 12월 26일
post-thumbnail

연산 그래프란 PyTorch에서 텐서 간의 연산 관계를 추적하여, 역전파(backpropagation)를 수행할 때 각 파라미터의 기울기(gradient)를 자동으로 계산할 수 있도록 지원하는 구조를 의미한다.
대표적으로 loss.backward()를 호출했을 때, 연산 그래프를 통해 손실(loss)로부터 각 파라미터의 기울기가 계산됨.

1. 기본 동작: 연산 그래프 생성

PyTorch는 텐서에 연산을 가하면 그 결과 텐서를 단순히 반환하지 않고, 해당 텐서를 연산 그래프에 연결된 상태로 남겨둔다..
왜냐하면 추후에 미분 및 역전파를 지원해야 하기 때문이다.

1.1 연산 그래프가 불필요한 경우

코드를 작성하다 보면, 모든 연산에 대해 역전파가 필요하지는 않은 경우가 많다.
예를 들어, 손실 값(loss)을 기록하거나 단순히 텐서 값을 출력하는 경우에는 연산 그래프가 필요하지 않다.
그럼에도 불구하고 그래프가 생성되면 불필요한 추가 연산메모리 사용량 증가를 초래할 수 있기에 이는 최적화하는 것이 좋다.

1.2 연산 그래프를 끊는 방법

이러한 불필요한 연산 그래프 생성을 방지하려면 detach() 메서드를 사용해 텐서를 그래프에서 분리하거나, 필요 시 데이터만 가져오도록 처리하는 작업이 필요하다.

detach():

  • 해당 텐서를 연산 그래프에서 분리해주는 메서드

  • 데이터는 GPU에 그대로 유지되며, 추가 연산이 필요 없는 경우 적합

    detached_tensor = tensor.detach()  # 그래프에서 분리된 텐서 반환

.item():

  • 텐서의 값을 Python의 float 타입으로 변환하며, 데이터를 GPU에서 CPU로 이동시킨다.

  • 반복적으로 호출하거나 큰 데이터에서 사용하면 성능 저하를 초래할 수 있다.

    scalar_value = tensor.item()  # CPU로 데이터를 이동 후 값 반환

1.3 성능 비교: .item() vs detach()

  • .item()은 물론 텐서그래프에서 해당 텐서를 detach하는 역할을 하기도 하지만, 이는 값을 반환하기 위해 GPU에서 CPU로 데이터를 이동하는 동작에서 부산물 처럼 얻어지는 결과이고 GPU -> CPU로 이동하는 해당 작업에서 추가 비용이 발생된다.
  • 반면, detach()는 연산 그래프만 제거하고 데이터는 GPU에 남기므로 배치단위 연산을 하는 경우에 .item()과 잘 엮어 활용하면 배치단위에서는 detach()로 연산하고 해당 epoch가 끝났을 때 item()을 사용하여 연산하면 메모리 효율을 노릴 수 있다.

2. detach(), item()을 사용했을 때 연산 그래프 분리 여부 확인 실습

import torch

# GPU 또는 CPU에서 실행할 수 있도록 장치 설정
device = "cuda" if torch.cuda.is_available() else "cpu"

# requires_grad=True로 연산 그래프 생성
x = torch.tensor([2.0], requires_grad=True, device=device)

# 연산 수행 (x를 기반으로 새로운 텐서 생성)
y = x * 3  # y = 3 * x
z = y ** 2  # z = y^2 = (3 * x)^2

# 원래 텐서(x)의 기울기 계산 여부 확인
z.backward()  # z는 연산 그래프에 연결된 상태
print("x.grad:", x.grad)  # 기울기 확인


# 1. .item() 사용하여 값 추출
scalar_value = z.item()

# 2. .item() 호출 후 z로 backward() 시도
try:
    z.backward()
except RuntimeError as e:
    print(f"RuntimeError 발생: {e}")

# 3. .detach()로 연산 그래프를 끊고 값 확인
detached_z = z.detach()
print("detach() 사용 후 값:", detached_z)

try:
    z.backward()
except RuntimeError as e:
    print(f"RuntimeError 발생: {e}")

출력값

x.grad: tensor([36.], device='cuda:0')
RuntimeError 발생: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
detach() 사용 후 값: tensor([36.], device='cuda:0')
RuntimeError 발생: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

3. 요약

연산 그래프가 필요하지 않은 경우에는 detach()를 적극 활용하고, 반드시 값만 필요할 때 .item()을 사용하자!

profile
헤매는 만큼 자기 땅이다.

0개의 댓글