[Lucid] weakref 활용을 통한 메모리 최적화

안암동컴맹·2026년 1월 2일

Lucid Development

목록 보기
19/20
post-thumbnail

💿 weakref 활용을 통한 메모리 최적화

이번 일지는 weakref를 통해 메모리 누수를 줄인 과정을 정리한다. 주제 자체가 low-level이고, Python의 내부 메모리 관리와 연결되어 있어 기술적으로 깊게 다룬다. 특히 Lucid의 자동미분 그래프가 커졌을 때 발생한 메모리 증가 문제를 어떻게 분석했고, 그 원인이 참조 그래프의 구조에 있음을 어떻게 확인했는지, 그리고 어떤 방식으로 weakref를 적용했는지까지 상세히 기록한다.


📉 문제: 메모리 사용량의 비정상적 증가

가장 먼저 체감했던 문제는 iteration이 진행될수록 메모리 사용량이 누적된다는 점이었다.

단순한 forward/backward를 반복했는데도 RSS가 감소하지 않았다. 특히 retain_graph=False가 기본인 상황에서도 그래프가 남아 있었다. 결과적으로 실험이 길어질수록 메모리가 상승했고, 일정 시점 이후에는 시스템이 swap memory을 쓰기 시작했다.

증상은 크게 세 가지로 정리할 수 있었다.

  • forward/backward가 끝난 후에도 TensorOperation의 개수가 줄지 않는다.
  • 특정 연산을 많이 호출하는 그래프에서 메모리 증가가 더 빨라진다.
  • gc.collect()를 호출해도 기대만큼 회수되지 않는다.

Lucid의 자동미분 구조는 상대적으로 얇지만, 그만큼 참조 구조가 단순해 보이는 착각이 있었다. 하지만 실제로는 작은 사이클들이 많이 생겨난다는 사실이 드러났다.

🧪 재현을 위한 최소 시나리오

문제는 아래와 같은 간단한 스크립트에서도 나타났다. 실제 학습 루프에서는 더 빠르게 악화되었다.

import lucid

for i in range(1000):
    x = lucid.random.randn(1024, 1024, requires_grad=True)
    y = (x @ x).sum()
    y.backward()

이 루프에서 기대했던 동작은 다음과 같다.

  1. 각 iteration에서 생성된 그래프는 backward 이후 해제된다.
  2. iteration 간에 살아남는 텐서는 입력 텐서 외에는 거의 없다.
  3. 따라서 메모리는 일정 범위 내에서 진동한다.

하지만 실제로는 iteration이 지나도 그래프가 계속 남는 현상이 관찰되었다.

📊 메모리 프로파일링

조사를 위해 다음 도구를 사용했다.

  • tracemalloc으로 큰 객체의 위치를 추적
  • gc.get_objects()로 Tensor/Operation 수를 측정
  • gc.get_referrers()로 참조 그래프를 확인
  • 간단한 id() 기반 스냅샷 비교로 객체 생존 기간 추적

가장 눈에 띄는 점은 다음과 같았다.

  • Operation 객체가 줄어들지 않는다.
  • Tensor._op 경로로 계속 살아있는 텐서가 있다.

결국 핵심은 참조 구조가 해제되기 어렵게 되어 있었다는 점이다.

🧭 원인 탐색 과정

Lucid의 그래프 구조는 크게 아래처럼 요약된다.

  • 각 연산은 Operation 객체로 캡슐화된다.
  • 결과 Tensor는 Tensor._op에 연산을 기록한다.
  • 결과 Tensor는 Tensor._prev에 입력 텐서를 저장한다.
  • backward를 위해 BackwardOperation이 생성된다.

문제는 이 구조가 순환 참조를 쉽게 만든다는 점이다. 예를 들어 가장 단순한 연산만 보더라도 다음 순환이 생성된다.

Tensor -> _op (Operation)
Operation -> result (Tensor)

이것만으로도 사이클이 발생한다. 또한 BackwardOperation이 입력 텐서들을 강하게 참조하면, 그 텐서들이 다시 _op를 통해 다른 객체를 붙잡게 된다. 즉, 그래프 전체는 작은 사이클들의 집합이 된다. 그리고 Python의 GC는 이를 즉시 회수하지 않는다.

🧠 Python의 메모리 관리 모델

문제를 제대로 이해하려면 CPython의 메모리 모델을 이해해야 했다. 핵심은 두 가지다.

  1. Reference counting
  2. Cyclic GC

1️⃣ Reference Counting

CPython은 객체마다 ob_refcnt를 갖는다. 강한 참조가 늘면 refcount가 증가하고, 강한 참조가 사라지면 refcount가 감소한다. refcount가 0이 되는 순간 객체는 즉시 해제된다.

이 메커니즘은 빠르고 예측 가능하지만, 순환 참조(cycle)에는 취약하다.

2️⃣ Cyclic GC

순환 참조는 refcount가 0이 되지 않기 때문에, 별도의 cyclic GC가 이를 처리한다. CPython은 generational GC를 사용한다.

  • 0세대는 자주 검사한다.
  • 1세대는 덜 자주 검사한다.
  • 2세대는 가장 덜 자주 검사한다.

객체가 여러 번 GC를 통과하면 더 높은 세대로 승격된다. 그래프가 커질수록 high generation에 잔류하는 객체가 많아진다. 결국 GC 주기가 길어지고, 메모리 회수 시점이 뒤로 밀린다.

이 상황에서 큰 그래프를 계속 만들면, 메모리는 실제로 누수처럼 보이는 현상을 만든다.

🌊 왜 weakref가 필요한가

기본 구조는 다음처럼 요약된다.

  • 연산 결과 텐서가 연산 객체를 참조한다.
  • 연산 객체가 결과 텐서를 참조한다.
  • backward를 위해 텐서들이 다시 참조된다.

이 구조에서 한 쪽을 약하게 만들어야 사이클이 끊어진다. weakref는 이 역할에 적합하다. weakref는 참조 카운트를 증가시키지 않는다. 따라서 약한 참조만 남은 객체는 바로 해제될 수 있다. 그리고 해제된 후에는 약한 참조가 None을 반환한다.

즉, weakref는 다음을 가능하게 한다.

  • 사이클을 깨서 즉시 해제
  • 객체의 생존 여부를 안전하게 체크
  • 구조적 연결은 유지하되, 생존 강제는 막음

이 전략은 autodiff 그래프에 아주 잘 맞는다. 그래프는 필요할 때만 유지하면 되고, iteration이 끝나면 해제되어야 한다.

🧩 weakref의 이론적 차이

강한 참조와 약한 참조의 차이는 단순하지만 매우 중요하다.

💪🏻 강한 참조

  • obj = other 형태
  • 참조 카운트를 증가시킨다
  • 객체 생존을 보장한다

👶🏻 약한 참조

  • weakref.ref(obj) 형태
  • 참조 카운트를 증가시키지 않는다
  • 객체 생존을 보장하지 않는다
  • 참조 대상이 해제되면 None을 반환한다

이 차이는 그래프의 수명 설계에 직접 영향을 준다. 강한 참조는 구조의 안정성을 보장하지만, 약한 참조는 구조의 연결만 유지하면서 수명을 조절한다.

🧰 Lucid에서 weakref를 적용한 구체적 지점

Lucid 코드에서 실제로 약한 참조를 도입한 부분은 크게 두 군데였다.

  1. BackwardOperation이 입력 텐서를 강하게 붙잡는 구조
  2. Operation이 결과 텐서를 강하게 붙잡는 구조

특히 1번이 큰 병목이었다.

1️⃣ BackwardOperation에서의 weakref

기존에는 backward에 필요한 입력 텐서가 강한 참조로 저장되었다. 이는 backward가 끝난 뒤에도 텐서를 살려두는 문제가 있었다. 그래서 다음과 같이 변경했다.

# before
self.tensor_refs = tensors

# after
self.tensor_refs = tuple(weakref.ref(t) for t in tensors)

이렇게 하면 BackwardOperation은 텐서의 생존을 강제하지 않는다. backward가 실행될 때만 텐서가 살아있다면 충분하다. 이미 해제되었다면 ref()None을 반환하고, 그 경우 backward는 조용히 종료한다.

이 구조는 실제로 Lucid의 core.py에 반영되어 있다.

tensor_refs = tuple(weakref.ref(t) for t in tensors)
live_tensors = tuple(ref() for ref in self.tensor_refs)
if any(t is None for t in live_tensors): 
    return

이 패턴은 weakref 사용의 핵심이다.

2️⃣ forward op 참조의 약화

BackwardOperation은 내부적으로 forward_op_ref를 가진다. 이 또한 강한 참조일 필요가 없다. 이미 forward 결과가 사라졌다면, 그 op 역시 해제 가능한 것이 자연스럽다. 그래서 forward_op_ref도 weakref로 관리했다.

forward_op_ref = weakref.ref(op_self)

이렇게 하면 backward는 필요한 최소한의 정보만 유지한다.

🔍 참조 그래프의 변화

weakref 도입 전의 그래프는 아래에 가깝다.

Tensor -> _op -> Operation -> result -> Tensor
Tensor -> _backward_op -> tensor_refs -> Tensor

weakref 도입 후에는 다음과 같이 바뀐다.

Tensor -> _op -> Operation -> result -> Tensor
Tensor -> _backward_op -> tensor_refs (weak) -> Tensor

첫 번째 사이클은 여전히 존재한다. 하지만 두 번째 경로가 약해졌기 때문에, 전체 그래프가 훨씬 쉽게 회수된다. 이는 중요한 설계 포인트다. weakref는 모든 참조를 약하게 만들 필요는 없다. 핵심 경로만 약하게 만들어도 그래프의 수명이 크게 줄어든다.

🧯 모든 참조를 weakref로 바꾸지 않은 이유

weakref는 유용하지만 만능이 아니다. 다음과 같은 이유로 무분별하게 쓰면 문제가 생긴다.

  • 약한 참조는 생존을 보장하지 않는다.
  • 예상보다 빨리 해제되면 backward가 불완전해질 수 있다.
  • 유효성 체크를 늘려야 하고 코드가 복잡해진다.

따라서 weakref는 생존이 필수적이지 않은 경로에만 도입하는 것이 중요하다.

Lucid에서는 다음 기준을 사용했다.

  • backward 실행 시점에만 살아있으면 되는 객체는 weakref.
  • 그래프 구조 자체를 유지해야 하는 객체는 strong ref.
  • 디버깅/검증 목적의 임시 참조는 상황에 따라 제거.

이 기준을 세우고 적용한 것이 이번 최적화의 핵심이다.

🔄 참조 카운팅과 closure 문제

또 하나 중요한 원인은 grad 함수의 closure였다. Python의 closure는 캡처한 변수를 강하게 참조한다. 그리고 grad 함수가 BackwardOperation에 저장되면, 그 closure가 텐서를 붙잡고 있을 수 있다.

예를 들어 다음과 같은 구조다.

def grad_func():
    return a.grad_rule(b, c)

여기서 a, b, c가 Tensor라면, grad_func는 해당 텐서를 강하게 참조한다. 그 결과 BackwardOperation은 텐서를 붙잡게 된다.

이 문제를 해결하기 위해, 가능한 경우 grad 함수는 직접 텐서를 캡처하지 않고 참조를 간접화했다. Lucid에서는 이를 위해 tensor_refs를 weakref로 유지하고, grad_func는 필요할 때 그 참조를 다시 불러오도록 설계했다.

이 구조는 다음과 같은 장점이 있다.

  • closure가 텐서를 강하게 붙잡지 않는다.
  • backward 시점에 살아있는 텐서만 접근한다.
  • 메모리 그래프가 빠르게 회수된다.

⚙️ Python 메모리 할당자 관점에서 본 문제

Python 객체가 해제되더라도 OS에 바로 반환되지 않는 경우가 많다. 이 때문에 RSS가 쉽게 줄지 않는다. 이 현상을 이해하려면 CPython의 allocator 구조를 알아야 한다.

  • 작은 객체는 pymalloc이 관리한다.
  • 큰 객체는 시스템 malloc으로 간다.
  • 해제된 메모리는 arena 또는 pool에 남는다.
  • OS는 이를 바로 회수하지 않을 수 있다.

즉, 객체가 해제되었어도 RSS가 줄지 않는 현상은 흔하다. 하지만 이번 문제는 그보다 더 심각했다. 객체 자체가 해제되지 않고 계속 살아 있었다. weakref는 이 구조에서 객체 생존 자체를 줄이는 역할을 한다. 따라서 allocator의 특성에 상관없이 메모리 증가를 늦춘다.


🔧 실제 코드 레벨 적용 예시

Lucid의 핵심 변경을 요약하면 아래와 같다.

# core.py
import weakref

# forward path
tensor_refs = tuple(weakref.ref(t) for t in tensors)

result._backward_op = BackwardOperation(
    forward_op_ref=weakref.ref(op_self),
    grad_func=grad_func,
    tensor_refs=tensor_refs,
    device=device,
)
# BackwardOperation.__call__
live_tensors = tuple(ref() for ref in self.tensor_refs)
if any(t is None for t in live_tensors):
    return

이 구조는 다음의 의미를 가진다.

  • backward는 텐서를 강하게 붙잡지 않는다.
  • 필요 시점에만 접근한다.
  • 이미 해제된 텐서가 있으면 스킵한다.

즉, 그래프의 수명은 더 짧아지고, 메모리 회수 타이밍은 빨라진다.

🧩 텐서 그래프 관점에서 본 변화

weakref 적용 전후의 그래프 성질은 다음과 같이 정리된다.

◀️ 적용 전

  • 많은 텐서가 backward 이후에도 살아있음
  • Operation 객체가 결과를 강하게 붙잡음
  • grad closure가 입력 텐서를 강하게 붙잡음
  • GC 주기에 따라 회수 시점이 결정됨

▶️ 적용 후

  • backward 종료 직후 텐서가 빠르게 해제됨
  • Operation 객체만 남고, 결과는 빠르게 해제됨
  • tensor_refs가 약해져 cycle이 완화됨
  • GC가 늦게 돌아도 메모리 증가가 덜함

이 차이가 체감 성능과 안정성을 크게 바꿨다.

🔬 내부 동작: weakref의 생애주기

weakref는 객체의 생애주기에 깊게 관여한다. 흐름을 요약하면 다음과 같다.

  1. 객체 A가 생성된다.
  2. weakref.ref(A)가 생성된다.
  3. A의 내부 weakref 리스트에 등록된다.
  4. A의 refcount가 0이 되면 A가 해제된다.
  5. A가 해제되는 순간 weakref는 None으로 무효화된다.

이 과정에서 중요한 사실은, weakref 자체는 A의 수명을 연장하지 않는다는 점이다. 따라서 weakref는 생존이 보장되지 않는 참조다.

🧵 weakref 적용 이후의 코드 가이드라인

이번 작업 이후, 다음과 같은 기준을 내부 가이드로 정했다.

  • backward 경로에서 텐서를 강하게 붙잡지 않는다.
  • 결과를 저장하는 캐시는 필요 시점에만 강하게 유지한다.
  • 그래프를 유지해야 하는 경우에는 명시적으로 옵션을 둔다.
  • 디버깅 시에는 keep_grad를 활용한다.

이는 향후 다른 모듈을 확장할 때도 유용한 기준이 된다.


✅ 마무리

weakref는 흔히 피해야 하는 것처럼 보이기도 하지만, 그래프 기반 시스템에서는 오히려 필수에 가까운 도구였다. 이번 최적화는 Lucid의 내부 구조를 더 선명하게 만들었다.

이후의 개발에서는 다음을 기준으로 삼을 생각이다.

  • 객체의 생존 이유를 항상 명시한다.
  • 강한 참조는 의도적이어야 한다.
  • 필요 없는 참조는 제거하거나 약하게 만든다.

이런 기준이 유지된다면, Lucid는 더 큰 모델과 더 긴 학습에도 안정적으로 대응할 수 있을 것이다.

profile
Korea Univ. Computer Science & Engineering

0개의 댓글