nn.Module init시 super().__init__() 해야하는 이유

mincheol2·2022년 2월 2일
0

Pytorch

목록 보기
6/6
post-custom-banner

super().__init__()

모듈 클래스 생성시 보통 다음과 같은 초기화 함수를 작성하게 된다.

class MyModule(nn.Module):
	def __init__(self,x):
    	super().__init__()
        self.x = x

이때 super().__init__()은 무엇을 상속 받는 것일까?

Pytorch nn.Module 문서에서 init 메소드를 살펴보자

    def __init__(self) -> None:
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
        self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
        self._non_persistent_buffers_set: Set[str] = set()
        self._backward_hooks: Dict[int, Callable] = OrderedDict()
        self._is_full_backward_hook = None
        self._forward_hooks: Dict[int, Callable] = OrderedDict()
        self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
        self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._modules: Dict[str, Optional['Module']] = OrderedDict()

다양한 미리 변수들이 선언되어 있다.
대부분 직접적으로 건드리지 말라고 앞에 언더바를 붙여 놓은 것을 볼 수 있는데
parameters, buffers, hook 등 모두 nn.Module클래스 내 함수로 접근하는 변수들임을 볼 수 있다.

정상적인 모듈을 만들기 위해서는 이러한 변수들을 상속 받아야 하기 때문에 직접 만드는 모듈은 super().__init__()을 통해 위 변수들을 상속 받는다.

super().__init__() vs super(MyClass,self).__init__()

class MyModule(nn.Module):
	def __init__(self,x):
    	super(MyModule,self).__init__()
        self.x = x
class MyModule(nn.Module):
	def __init__(self,x):
    	super().__init__()
        self.x = x

이 두코드는 형태가 조금 다르다.

  • super(MyModule,self).__init__()
  • super().__init__()

super 안에 현재 클래스를 명시해준 것과 아닌 것으로 나눌 수 있는데

이는 기능적으론 아무런 차이가 없다

파생클래스와 self를 넣어서 현재 클래스가 어떤 클래스인지 명확하게 표시 해주는 용도이다.
super(파생클래스, self).__init__()

profile
옹오옹오오오옹ㅇㅇ
post-custom-banner

0개의 댓글