모든 neural network modules의 기본 class
- 모듈은 다른 모듈을 포함할 수 있고, 트리 구조로 형성할 수 있다.
- 새로운 모델을 만들 때 상속 받아서 사용한다.
nn.Module
을 상속받아서nn.Module
의 기본적인 기능들을 사용할 수 있다.super().__init__()
: nn.Module의 생성자 호출nn.Module
상속 후__init__()
과forward()
를 override해서 자신만의 모델을 만들 수 있다!__init__()
: 모델에서 사용될 module( ex)nn.Linear
,nn.Conv2d
), activation function 등을 정의forward()
: 모델에서 실행되어야하는 계산을 정의
(nn.Module
을 상속받은 클래스의 객체는forward()
함수를 호출하지 않아도 model 객체에 파라미터로 전달하면 바로 forward()가 시작된다. )
import torch
import torch.nn as nn
class Model(nn.Module): # 상속 => __init__, forward를 override
def __init__(self):
super().__init__() # nn.Module(부모 class)의 생성자 호출 ( -> 자식 class에서 사용 가능)
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
x1 = torch.tensor([1])
x2 = torch.tensor([2])
add = Model()
output = add(x1, x2)
print(output)