[PyTorch] nn.module 상속하는 방법과 이유

Junsoo-Kim·2025년 3월 24일

PyTorch

목록 보기
1/2
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)  
        x = self.fc2(x)  
        return x

PyTorch 코드에서 모델을 설계할 때, 항상 클래스 인자에 nn.Module이 들어가고 __init__ method에 super().__init__()가 들어가는 것을 확인할 수 있다.

나는 파이썬을 깊게 공부해보지는 않아서 잘 몰랐는데, 이참에 정리해보고자 한다.

Class Inheritance

클래스 상속이란, 부모 클래스(parent, super class)가 자식 클래스(sub, child class)에게 속성과 method를 물려주는 것을 말한다.

그런데 상속만 하면 바로 attribute와 method를 사용 가능한 줄 알았는데, 하나의 과정이 더 필요하다.

  • super().__init__() : 여기서 super는 부모 클래스를 의미하는데, 이건 한마디로 부모 클래스의 __init__ 메소드를 불러와 초기화시키는 것이다.

왜 이 과정이 필요하냐면, 부모 클래스에서 정의한 method에 있는 인자들이 제대로 초기화가 되어있지 않으면 부모의 method를 쓰고 싶어도 호출할 때 AttributeError가 발생할 것이기 때문이다.

왜 AttributeError가 발생하는지 한번 살펴보자.

import torch.nn as nn

class MyModel(nn.Module):  # nn.Module을 상속
    def __init__(self):
        # super().__init__() 만약 없다면
        self.fc1 = nn.Linear(10, 5) 

model = MyModel() # AttributeError 발생

파이썬에서는 속성을 할당할 때 자동으로 __setattr__()가 실행된다. 지금 우리는 nn.Module 클래스를 상속받았기 때문에 부모 클래스의 __setattr__() 메소드가 실행될 것이다. 한번 보자.

# nn.Module의 __setattr__ method
def __setattr__(self, name, value):
    if isinstance(value, nn.Module):  
        self._modules[name] = value  # nn.Module이면 _modules에 자동 등록
    else:
        object.__setattr__(self, name, value)  # 일반 속성은 기본 방식대로 저장

만약 속성을 할당한 것이 nn.Module의 것이라면 올바르게 딕셔너리 형태로 값이 저장될 것이다. 따라서 무조건 super().__init__()을 해서 속성값들을 초기화 시켜줘야 관련된 것들을 사용할 수 있다는 얘기이다.

nn.Module의 __init__() method는 아래와 같다.

 def __init__(self):
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._non_persistent_buffers_set = set()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

저기서 self._modules에 우리가 할당한 layer의 속성이 들어간다.


번외로 super(Subclass name, self).__init__()로 초기화하는 방법도 있다고 했는데 의미하는 바는 똑같고 버전에 따라 조금 차이가 있다고 들었다.

profile
CS Undergraduate 2nd, University of Seoul

0개의 댓글