import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
# super를 통해서 nn.Module의 init 메소드에 저장되어있는 정보를 상속 받아야
# forward를 통한 다양한 클래스의 기능을 수행할 수 있다.
super(Model, self).__init__()
# 그 이외의 내가 만든 모델에 필요한 레이어 또는 가중치, 값과 같은 forward 연산에
# 필요한 다양한 것들을 해당 초기화 메소드에 추가해준다.
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20,20, 5)
# 해당 모델의 계산영역(순전파) 구조를 담당하는 함수이다.
# 해당 함수의 구조대로 계산이 일어나고 결과값이 도출된다.
# __call__ 역할을 수행한다고 생각하면 된다.
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
간단한 설명은 주석에 적힌 것과 같다.
__setattr__
특수 메소드가 호출된다.__setattr__
함수에서는__init__
에 다양한 정보가 있으므로, forward 연산이 가능하도록 만들기 위해서는 초기화 함수를 절대 오버라이딩을 하면 안되고, 상속 받은 것에서 덧붙여서 써야한다.reference