
Lucid에서 nn.Module은 모델 정의의 단위이자 파라미터/버퍼/서브모듈을 조직하는 뼈대다. PyTorch의 nn.Module을 벤치마킹해 동일한 사용성을 목표로 했고, 파이프라인 전체(autograd, state_dict, 디바이스 이동, hooks 등)가 자연스럽게 이어지도록 설계했다. 이 글은 그런 목표 아래 lucid/nn/module.py를 어떻게 구현했고, 메서드별 시그니처와 핵심 로직을 어떤 선택으로 구성했는지 상세히 기록한다.
nn.Module과 역할PyTorch에서 nn.Module은
nn.Parameter)와 버퍼(러닝 스탯 등)를 자동 등록/관리하고,forward 정의만 제공하면 __call__로 hooks, autograd 연결을 투명하게 처리한다.Lucid도 동일한 철학을 따르며, 특히 파라미터/버퍼/서브모듈의 자동 등록, state_dict/load_state_dict, train/eval 전파, 디바이스 이동, hooks를 모사했다.
class Module:
def __init__(self) -> None:
object.__setattr__(self, "_parameters", OrderedDict())
object.__setattr__(self, "_buffers", OrderedDict())
object.__setattr__(self, "_modules", OrderedDict())
self.training = True
self.device: _DeviceType = "cpu"
self._forward_hooks: list[_ForwardHookType] = []
self._backward_hooks: list[_BackwardHookType] = []
self._state_dict_pass_attr = set()
_parameters, _buffers, _modules를 OrderedDict로 분리해 등록/순회/저장 순서를 보장.training/device 기본값을 설정해 train()/eval()/to() 호출 시 일관성 유지.__setattr__에서 자동 등록이 이루어지므로, 생성자에서만 object.__setattr__를 사용해 초기 상태를 보장해야 했다. 이후 일반 속성 대입은 커스텀 로직을 거친다.__setattr__ 오버라이드def __setattr__(self, name: str, value: Any) -> None:
registry_map = {nn.Parameter: self._parameters, nn.Buffer: self._buffers, Module: self._modules}
target_registry = None
for cls, registry in registry_map.items():
if isinstance(value, cls):
target_registry = registry
break
if target_registry is not None:
for registry in registry_map.values():
if registry is not target_registry and name in registry:
del registry[name]
target_registry[name] = value
else:
for registry in registry_map.values():
if name in registry:
del registry[name]
super().__setattr__(name, value)
register_parameter, register_buffer, add_module 없이도 직접 속성 할당으로 등록 가능.state_dict 충돌이 발생할 수 있어, 모든 레지스트리를 스캔해 클린업한다.setattr_raw는 내부 초기화나 메타정보 조작 시 자동 등록을 우회하는 escape hatch다.
def add_module(self, name: str, module: Self) -> None:
if not isinstance(module, Module) and module is not None:
raise TypeError(...)
self.__setattr__(name, module)
def register_parameter(self, name: str, param: nn.Parameter | None) -> None:
if not isinstance(param, nn.Parameter) and param is not None:
raise TypeError(...)
self.__setattr__(name, param)
def register_buffer(self, name: str, buffer: nn.Buffer | _ArrayOrScalar | None, dtype=None) -> None:
if buffer is not None and not isinstance(buffer, nn.Buffer):
buffer = nn.Buffer(buffer, dtype=dtype, device=self.device)
self.__setattr__(name, buffer)
None을 허용해 조건부 등록 패턴을 지원.device로 생성되며, dtype도 지정 가능.파라미터가 아닌 일반 Tensor를 실수로 등록하면 autograd 경로에 올라가지만 state_dict에는 빠지는 문제가 생길 수 있다. 이를 막기 위해 타입 체크 후 강제 변환 또는 에러를 던지도록 했다.
__call__과 hooksdef __call__(self, *args, **kwargs):
output = self.forward(*args, **kwargs)
for hook in self._forward_hooks:
hook(self, args, output)
if isinstance(output, Tensor) and self._backward_hooks:
for hook in self._backward_hooks:
output.register_hook(hook)
return output
forward 호출 전후로 hook 실행을 삽입. 출력이 Tensor 하나일 때만 backward hook을 붙여 PyTorch와 동일한 제약을 따른다.register_hook으로 연결되어 autograd 그래프에 포함된다.def forward(self) -> Tensor | tuple[Tensor, ...]:
raise NotImplementedError("The forward method must be implemented by the subclass.")
서브클래스가 필수로 구현해야 하며, 나머지 호출/후처리는 __call__이 담당.
def train(self, mode: bool = True) -> Self:
self.training = mode
for module in self._modules.values():
module.train(mode)
return self
def eval(self) -> Self:
return self.train(mode=False)
def to(self, device: _DeviceType) -> Self:
if device == self.device:
return self
self.device = device
for param in self.parameters(recurse=False):
param.to(device)
for buffer in self.buffers(recurse=False):
buffer.to(device)
for module in self.modules():
module.to(device)
return self
train/eval 플래그를 재귀적으로 설정.modules() 순회로 재귀 이동.초기 구현에서 modules() 재귀 호출이 중복 이동을 일으켜 성능 저하가 있었다. parameters(buffers)는 recurse=False로 자기 자신만 이동하고, 서브모듈 이동은 따로 처리해 한 번씩만 이동하도록 정리했다.
parameters/buffers/modules/childrendef parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
for _, param in self._parameters.items(): yield param
if recurse:
for module in self._modules.values(): yield from module.parameters(recurse=recurse)
def buffers(self, recurse: bool = True) -> Iterator[nn.Buffer]:
for buffer in self._buffers.values(): yield buffer
if recurse:
for module in self._modules.values(): yield from module.buffers(recurse=recurse)
def modules(self) -> Iterator[Self]:
yield self
for module in self._modules.values(): yield from module.modules()
def children(self) -> Iterator[Self]:
return iter(self._modules.values())
recurse 플래그로 자기 자신만 혹은 트리 전체를 순회.count_parameters, parameter_size가 위 순회를 재사용해 총 파라미터 수를 계산한다.state_dict와 load_state_dictdef state_dict(self, destination=None, prefix="", keep_vars=False):
destination = OrderedDict() if destination is None else destination
for name, param in self._parameters.items():
destination[prefix + name] = param if keep_vars else param.numpy()
for name, buffer in self._buffers.items():
destination[prefix + name] = buffer if keep_vars else buffer.numpy()
for name, module in self._modules.items():
module.state_dict(destination, prefix + name + ".", keep_vars)
for key in list(destination.keys()):
if key in self._state_dict_pass_attr: del destination[key]
return destination
def load_state_dict(self, state_dict: OrderedDict, strict: bool = True) -> None:
own_state = self.state_dict(keep_vars=True)
missing = set(own_state.keys()) - set(state_dict.keys())
unexpected = set(state_dict.keys()) - set(own_state.keys())
if strict and (missing or unexpected): ... # 에러 구성
for key, value in state_dict.items():
if key in own_state:
attr = own_state[key]
if isinstance(attr, (nn.Parameter, nn.Buffer)):
value_t = Tensor(value, device=self.device)
attr.data = value_t.data
else:
setattr(self, key, value)
elif strict:
raise KeyError(...)
_state_dict_pass_attr는 저장에서 제외할 속성 집합.Tensor로 감싸 device를 맞춘다.초기에는 state_dict가 numpy 뷰를 반환해 이후 값 변경 시 원본이 변하는 문제가 있었다. .numpy()로 명시적 복사본을 반환하도록 수정. strict 검증 메시지도 파이토치와 유사하게 묶어 가독성 있게 만들었다.
__repr__, extra_repr, hooksdef __repr__(self) -> str:
extra = self.extra_repr()
child_lines = [...]
main_str = self._get_name() + "("
if extra:
main_str += extra
if child_lines:
if extra: main_str += "\n"
main_str += "\n " + "\n ".join(child_lines) + "\n"
main_str += ")"
return main_str
def extra_repr(self) -> str:
exclude = {"training", "device"}
attrs = [...]
return ", ".join(attrs)
register_forward_hook, register_backward_hook은 리스트에 hook을 쌓고 제거 lambda를 반환해 PyTorch와 동일한 사용성을 제공.auto_repr 데코레이터는 특정 속성을 extra_repr에 노출하도록 클래스 단위로 설정. set_state_dict_pass_attr는 특정 속성을 state_dict에서 제외하도록 클래스 단위로 설정.Sequential / ModuleList / ModuleDictSequential
class Sequential(Module):
def __init__(self, *args: Module | OrderedDict[str, Module]) -> None:
super().__init__()
...
def forward(self, input: Tensor) -> Tensor:
for module in self._modules.values():
input = module(input)
return input
ModuleList
class ModuleList(Module):
def __init__(self, modules: list[Module] | None = None) -> None:
super().__init__()
if modules is not None: self.extend(modules)
def __getitem__(self, idx: int | slice) -> Module | Self: ...
def append(self, module: Module) -> None: self.add_module(str(len(self._modules)), module)
ModuleDict
class ModuleDict(Module):
def __init__(self, modules: dict[str, Module] | None = None) -> None:
super().__init__()
if modules is not None: self.update(modules)
def __getitem__(self, key: str) -> Module: return self._modules[key]
def __setitem__(self, key: str, module: Module) -> None: self.add_module(key, module)
nn.Module.state_dict/load_state_dict 인터페이스, train/eval/to 전파, hooks, 컨테이너(Sequential/ModuleList/ModuleDict) 사용성.device를 문자열로 관리하고, 파라미터/버퍼를 얕은 구조로 유지해 구현 단순성을 우선. backward hook은 단일 Tensor 출력에만 허용해 PyTorch 제약을 명시적으로 반영.__setattr__에서 모든 레지스트리를 스캔해 클린업. recurse=False로 자기 파라미터만 이동 후 modules() 재귀 호출로 분리. .numpy() 복사 반환. Sequential/ModuleList 슬라이스가 참조/복사 중 어디까지인지 정의 → PyTorch와 동일하게 새 컨테이너 생성.Lucid의 nn.Module은 파이토치의 사용성·철학을 최대한 모사하면서 다양한 백엔드에 맞게 단순화된 구현을 제공한다. 자동 등록과 state_dict, 모드/디바이스 전파, hook 지원, 컨테이너까지 포함한 이 뼈대 위에서 이후의 모든 모델 정의가 이루어진다. 이 글의 시그니처와 스니펫을 따라가면 Lucid와 유사한 nn.Module을 재현하거나, 필요한 부분만 가져와 다른 프로젝트에 맞게 변형할 수 있을 것이다.