모듈 클래스 생성시 보통 다음과 같은 초기화 함수를 작성하게 된다.
class MyModule(nn.Module):
def __init__(self,x):
super().__init__()
self.x = x
이때 super().__init__()
은 무엇을 상속 받는 것일까?
Pytorch nn.Module 문서에서 init 메소드를 살펴보자
def __init__(self) -> None:
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters: Dict[str, Optional[Parameter]] = OrderedDict()
self._buffers: Dict[str, Optional[Tensor]] = OrderedDict()
self._non_persistent_buffers_set: Set[str] = set()
self._backward_hooks: Dict[int, Callable] = OrderedDict()
self._is_full_backward_hook = None
self._forward_hooks: Dict[int, Callable] = OrderedDict()
self._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
self._state_dict_hooks: Dict[int, Callable] = OrderedDict()
self._load_state_dict_pre_hooks: Dict[int, Callable] = OrderedDict()
self._modules: Dict[str, Optional['Module']] = OrderedDict()
다양한 미리 변수들이 선언되어 있다.
대부분 직접적으로 건드리지 말라고 앞에 언더바를 붙여 놓은 것을 볼 수 있는데
parameters
, buffers
, hook
등 모두 nn.Module
클래스 내 함수로 접근하는 변수들임을 볼 수 있다.
정상적인 모듈을 만들기 위해서는 이러한 변수들을 상속 받아야 하기 때문에 직접 만드는 모듈은 super().__init__()
을 통해 위 변수들을 상속 받는다.
class MyModule(nn.Module):
def __init__(self,x):
super(MyModule,self).__init__()
self.x = x
class MyModule(nn.Module):
def __init__(self,x):
super().__init__()
self.x = x
이 두코드는 형태가 조금 다르다.
super(MyModule,self).__init__()
super().__init__()
super 안에 현재 클래스를 명시해준 것과 아닌 것으로 나눌 수 있는데
이는 기능적으론 아무런 차이가 없다
파생클래스와 self를 넣어서 현재 클래스가 어떤 클래스인지 명확하게 표시 해주는 용도이다.
super(파생클래스, self).__init__()