torch에서 모든 Module이 제대로 기능을 수행하기 위해서 초기화 메소드의 상속이 필수적인 이유

Gangtaro·2022년 2월 2일
0
post-custom-banner
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
		def __init__(self):
				# super를 통해서 nn.Module의 init 메소드에 저장되어있는 정보를 상속 받아야
				# forward를 통한 다양한 클래스의 기능을 수행할 수 있다. 
				super(Model, self).__init__()

				# 그 이외의 내가 만든 모델에 필요한 레이어 또는 가중치, 값과 같은 forward 연산에
				# 필요한 다양한 것들을 해당 초기화 메소드에 추가해준다.
				self.conv1 = nn.Conv2d(1, 20, 5)
				self.conv2 = nn.Conv2d(20,20, 5)

		# 해당 모델의 계산영역(순전파) 구조를 담당하는 함수이다.
		# 해당 함수의 구조대로 계산이 일어나고 결과값이 도출된다. 
		# __call__ 역할을 수행한다고 생각하면 된다.
		def forward(self, x):
				x = F.relu(self.conv1(x))
				return F.relu(self.conv2(x))

간단한 설명은 주석에 적힌 것과 같다.

  • nn.Module 클래스 내부에서 새로운 변수를 만들 때,
    ”변수 = 값” 형태의 코드를 적으면 __setattr__ 특수 메소드가 호출된다.
  • nn.Module 클래스의 __setattr__ 함수에서는
    새로운 값을 만들 때 사용한 이 ”값”이 nn.Module의 인스턴스인지 아닌지를 체크하는 과정이 있다.
  • 이 값이 만약에 nn.Module의 인스턴스라면 전체 모듈의 submodule로 취급하고, 아니라면 그냥 무시하게 된다.
  • 따라서 모듈이 토치 모듈에서 잘 연계되어서 작동하기 위해서는 초기화 메소드의 상속 함수를 사용해주는 것이 필수적이다.
    - 또한 __init__에 다양한 정보가 있으므로, forward 연산이 가능하도록 만들기 위해서는 초기화 함수를 절대 오버라이딩을 하면 안되고, 상속 받은 것에서 덧붙여서 써야한다.

reference

post-custom-banner

0개의 댓글