[Lucid] nn.Module 구현

안암동컴맹·2025년 12월 11일

Lucid Development

목록 보기
8/17
post-thumbnail

🧩 nn.Module 구현

Lucid에서 nn.Module은 모델 정의의 단위이자 파라미터/버퍼/서브모듈을 조직하는 뼈대다. PyTorch의 nn.Module을 벤치마킹해 동일한 사용성을 목표로 했고, 파이프라인 전체(autograd, state_dict, 디바이스 이동, hooks 등)가 자연스럽게 이어지도록 설계했다. 이 글은 그런 목표 아래 lucid/nn/module.py를 어떻게 구현했고, 메서드별 시그니처와 핵심 로직을 어떤 선택으로 구성했는지 상세히 기록한다.


🧭 배경: PyTorch nn.Module과 역할

PyTorch에서 nn.Module

  • 파라미터(nn.Parameter)와 버퍼(러닝 스탯 등)를 자동 등록/관리하고,
  • 서브모듈을 트리 구조로 묶어 순회와 저장/로드를 단순화하며,
  • forward 정의만 제공하면 __call__로 hooks, autograd 연결을 투명하게 처리한다.

Lucid도 동일한 철학을 따르며, 특히 파라미터/버퍼/서브모듈의 자동 등록, state_dict/load_state_dict, train/eval 전파, 디바이스 이동, hooks를 모사했다.

🧱 초기화와 내부 레지스트리

class Module:
    def __init__(self) -> None:
        object.__setattr__(self, "_parameters", OrderedDict())
        object.__setattr__(self, "_buffers", OrderedDict())
        object.__setattr__(self, "_modules", OrderedDict())

        self.training = True
        self.device: _DeviceType = "cpu"

        self._forward_hooks: list[_ForwardHookType] = []
        self._backward_hooks: list[_BackwardHookType] = []
        self._state_dict_pass_attr = set()
  • _parameters, _buffers, _modulesOrderedDict로 분리해 등록/순회/저장 순서를 보장.
  • training/device 기본값을 설정해 train()/eval()/to() 호출 시 일관성 유지.
  • hook, state_dict 필터링용 내부 리스트/셋 초기화.

구현상의 고민

  • __setattr__에서 자동 등록이 이루어지므로, 생성자에서만 object.__setattr__를 사용해 초기 상태를 보장해야 했다. 이후 일반 속성 대입은 커스텀 로직을 거친다.

🔗 자동 등록: __setattr__ 오버라이드

def __setattr__(self, name: str, value: Any) -> None:
    registry_map = {nn.Parameter: self._parameters, nn.Buffer: self._buffers, Module: self._modules}
    target_registry = None

    for cls, registry in registry_map.items():
        if isinstance(value, cls):
            target_registry = registry
            break

    if target_registry is not None:
        for registry in registry_map.values():
            if registry is not target_registry and name in registry:
                del registry[name]
        target_registry[name] = value

    else:
        for registry in registry_map.values():
            if name in registry:
                del registry[name]

    super().__setattr__(name, value)
  • 역할: 파라미터/버퍼/서브모듈을 속성 대입만으로 자동 등록. 동일 이름에 중복 등록 시 이전 레지스트리에서 제거해 일관성을 유지.
  • 파이토치 모사: register_parameter, register_buffer, add_module 없이도 직접 속성 할당으로 등록 가능.
  • 실패 사례 방지: 다른 레지스트리에 남아 있던 동일 이름을 제거하지 않으면 state_dict 충돌이 발생할 수 있어, 모든 레지스트리를 스캔해 클린업한다.

Raw setattr

setattr_raw는 내부 초기화나 메타정보 조작 시 자동 등록을 우회하는 escape hatch다.

🧾 등록 API: 파라미터/버퍼/모듈

def add_module(self, name: str, module: Self) -> None:
    if not isinstance(module, Module) and module is not None:
        raise TypeError(...)
    self.__setattr__(name, module)

def register_parameter(self, name: str, param: nn.Parameter | None) -> None:
    if not isinstance(param, nn.Parameter) and param is not None:
        raise TypeError(...)
    self.__setattr__(name, param)

def register_buffer(self, name: str, buffer: nn.Buffer | _ArrayOrScalar | None, dtype=None) -> None:
    if buffer is not None and not isinstance(buffer, nn.Buffer):
        buffer = nn.Buffer(buffer, dtype=dtype, device=self.device)
    self.__setattr__(name, buffer)
  • 디자인: PyTorch와 동일하게 None을 허용해 조건부 등록 패턴을 지원.
  • 장치/ dtype 전파: 버퍼는 현재 모듈 device로 생성되며, dtype도 지정 가능.

난관

파라미터가 아닌 일반 Tensor를 실수로 등록하면 autograd 경로에 올라가지만 state_dict에는 빠지는 문제가 생길 수 있다. 이를 막기 위해 타입 체크 후 강제 변환 또는 에러를 던지도록 했다.

🧭 실행 흐름: __call__과 hooks

def __call__(self, *args, **kwargs):
    output = self.forward(*args, **kwargs)
    for hook in self._forward_hooks:
        hook(self, args, output)

    if isinstance(output, Tensor) and self._backward_hooks:
        for hook in self._backward_hooks:
            output.register_hook(hook)

    return output
  • 역할: forward 호출 전후로 hook 실행을 삽입. 출력이 Tensor 하나일 때만 backward hook을 붙여 PyTorch와 동일한 제약을 따른다.
  • 전파: backward hook은 Tensor의 register_hook으로 연결되어 autograd 그래프에 포함된다.

forward 정의

def forward(self) -> Tensor | tuple[Tensor, ...]:
    raise NotImplementedError("The forward method must be implemented by the subclass.")

서브클래스가 필수로 구현해야 하며, 나머지 호출/후처리는 __call__이 담당.

🔄 학습/평가 모드와 디바이스 이동

def train(self, mode: bool = True) -> Self:
    self.training = mode
    for module in self._modules.values():
        module.train(mode)
    return self

def eval(self) -> Self:
    return self.train(mode=False)

def to(self, device: _DeviceType) -> Self:
    if device == self.device: 
        return self
    self.device = device

    for param in self.parameters(recurse=False): 
        param.to(device)
    for buffer in self.buffers(recurse=False): 
        buffer.to(device)
    for module in self.modules(): 
        module.to(device)

    return self
  • 모드 전파: 하위 모듈까지 train/eval 플래그를 재귀적으로 설정.
  • 디바이스 전파: 파라미터/버퍼는 현재 모듈에서 직접 이동, 서브모듈은 modules() 순회로 재귀 이동.

교훈

초기 구현에서 modules() 재귀 호출이 중복 이동을 일으켜 성능 저하가 있었다. parameters(buffers)recurse=False로 자기 자신만 이동하고, 서브모듈 이동은 따로 처리해 한 번씩만 이동하도록 정리했다.

📦 순회와 집계: parameters/buffers/modules/children

def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
    for _, param in self._parameters.items(): yield param
    if recurse:
        for module in self._modules.values(): yield from module.parameters(recurse=recurse)

def buffers(self, recurse: bool = True) -> Iterator[nn.Buffer]:
    for buffer in self._buffers.values(): yield buffer
    if recurse:
        for module in self._modules.values(): yield from module.buffers(recurse=recurse)

def modules(self) -> Iterator[Self]:
    yield self
    for module in self._modules.values(): yield from module.modules()

def children(self) -> Iterator[Self]:
    return iter(self._modules.values())
  • 재귀 옵션: PyTorch와 동일하게 recurse 플래그로 자기 자신만 혹은 트리 전체를 순회.
  • 파생 유틸: count_parameters, parameter_size가 위 순회를 재사용해 총 파라미터 수를 계산한다.

💾 저장/로드: state_dictload_state_dict

def state_dict(self, destination=None, prefix="", keep_vars=False):
    destination = OrderedDict() if destination is None else destination

    for name, param in self._parameters.items():
        destination[prefix + name] = param if keep_vars else param.numpy()

    for name, buffer in self._buffers.items():
        destination[prefix + name] = buffer if keep_vars else buffer.numpy()

    for name, module in self._modules.items():
        module.state_dict(destination, prefix + name + ".", keep_vars)

    for key in list(destination.keys()):
        if key in self._state_dict_pass_attr: del destination[key]

    return destination
def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
    own_state = self.state_dict(keep_vars=True)
    missing = set(own_state.keys()) - set(state_dict.keys())
    unexpected = set(state_dict.keys()) - set(own_state.keys())
    
    if strict and (missing or unexpected): ...  # 에러 구성
    for key, value in state_dict.items():
        if key in own_state:
            attr = own_state[key]
            if isinstance(attr, (nn.Parameter, nn.Buffer)):
                value_t = Tensor(value, device=self.device)
                attr.data = value_t.data
            else:
                setattr(self, key, value)

        elif strict:
            raise KeyError(...)
  • 동작: 파라미터/버퍼를 prefix 기반 키로 평탄화하여 저장, 트리 구조는 키 이름으로 표현. _state_dict_pass_attr는 저장에서 제외할 속성 집합.
  • 로드: strict 모드에서 missing/unexpected 키를 검증. 저장 시 numpy로 바꿨다면 로드 시 Tensor로 감싸 device를 맞춘다.

문제와 해결

초기에는 state_dict가 numpy 뷰를 반환해 이후 값 변경 시 원본이 변하는 문제가 있었다. .numpy()로 명시적 복사본을 반환하도록 수정. strict 검증 메시지도 파이토치와 유사하게 묶어 가독성 있게 만들었다.

🛠 표현과 후처리: __repr__, extra_repr, hooks

def __repr__(self) -> str:
    extra = self.extra_repr()
    child_lines = [...]
    main_str = self._get_name() + "("
    if extra: 
        main_str += extra

    if child_lines:
        if extra: main_str += "\n"
        main_str += "\n  " + "\n  ".join(child_lines) + "\n"
    main_str += ")"

    return main_str

def extra_repr(self) -> str:
    exclude = {"training", "device"}
    attrs = [...]
    return ", ".join(attrs)
  • 역할: 모듈 트리와 주요 속성을 사람이 읽기 좋은 형태로 출력. PyTorch의 모듈 문자열 표현을 모사.
  • hook 등록: register_forward_hook, register_backward_hook은 리스트에 hook을 쌓고 제거 lambda를 반환해 PyTorch와 동일한 사용성을 제공.

auto_repr / pass_attr

  • auto_repr 데코레이터는 특정 속성을 extra_repr에 노출하도록 클래스 단위로 설정.
  • set_state_dict_pass_attr는 특정 속성을 state_dict에서 제외하도록 클래스 단위로 설정.

📚 컨테이너: Sequential / ModuleList / ModuleDict

Sequential

class Sequential(Module):
    def __init__(self, *args: Module | OrderedDict[str, Module]) -> None:
        super().__init__()
        ...
    def forward(self, input: Tensor) -> Tensor:
        for module in self._modules.values():
            input = module(input)
        return input
  • OrderedDict 또는 가변 인자로 모듈을 받아 차례로 적용. 슬라이싱/인덱싱도 지원.

ModuleList

class ModuleList(Module):
    def __init__(self, modules: list[Module] | None = None) -> None:
        super().__init__()
        if modules is not None: self.extend(modules)

    def __getitem__(self, idx: int | slice) -> Module | Self: ...

    def append(self, module: Module) -> None: self.add_module(str(len(self._modules)), module)
  • 리스트 인터페이스를 모사해 가변 길이의 서브모듈 컬렉션을 다룬다. 인덱스·슬라이스로 부분 리스트 생성 가능.

ModuleDict

class ModuleDict(Module):
    def __init__(self, modules: dict[str, Module] | None = None) -> None:
        super().__init__()
        if modules is not None: self.update(modules)

    def __getitem__(self, key: str) -> Module: return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None: self.add_module(key, module)
  • 이름 기반 조회/갱신이 필요한 설정(멀티 브랜치, 조건부 경로)에서 사용.

🔍 벤치마킹과 차이점

  • 벤치마킹 대상: PyTorch nn.Module.
  • 모사한 부분: 자동 등록, state_dict/load_state_dict 인터페이스, train/eval/to 전파, hooks, 컨테이너(Sequential/ModuleList/ModuleDict) 사용성.
  • 다른 점: Lucid는 NumPy/MLX 백엔드에 맞춰 device를 문자열로 관리하고, 파라미터/버퍼를 얕은 구조로 유지해 구현 단순성을 우선. backward hook은 단일 Tensor 출력에만 허용해 PyTorch 제약을 명시적으로 반영.

구현 중 직면한 어려움과 해결

  1. 자동 등록 충돌: 동일 이름을 다른 레지스트리에 할당해 생기는 state_dict 충돌 → __setattr__에서 모든 레지스트리를 스캔해 클린업.
  2. 디바이스 이동 중복: 재귀 이동 과정에서 중복 to 호출 → recurse=False로 자기 파라미터만 이동 후 modules() 재귀 호출로 분리.
  3. state_dict 뷰 문제: numpy 뷰를 반환해 원본이 변하는 버그 → .numpy() 복사 반환.
  4. hooks 범위: backward hook을 다중 출력에 붙일지 여부 → PyTorch 호환성을 위해 단일 Tensor 출력일 때만 부착.
  5. 컨테이너 슬라이싱: Sequential/ModuleList 슬라이스가 참조/복사 중 어디까지인지 정의 → PyTorch와 동일하게 새 컨테이너 생성.

✅ 정리

Lucid의 nn.Module은 파이토치의 사용성·철학을 최대한 모사하면서 다양한 백엔드에 맞게 단순화된 구현을 제공한다. 자동 등록과 state_dict, 모드/디바이스 전파, hook 지원, 컨테이너까지 포함한 이 뼈대 위에서 이후의 모든 모델 정의가 이루어진다. 이 글의 시그니처와 스니펫을 따라가면 Lucid와 유사한 nn.Module을 재현하거나, 필요한 부분만 가져와 다른 프로젝트에 맞게 변형할 수 있을 것이다.

profile
Korea Univ. Computer Science & Engineering

0개의 댓글