[CLASS] torch.nn.Module
공식문서에 따르면
torch.nn.Module 은 PyTorch의 모든 Neural Network의 Base Class이다. 모듈은 다른 모듈을 포함할 수 있고, 트리 구조로 형성할 수 있다.
공식문서에 예제를 코딩해보면서 감을 잡아보자.
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
MyModel을 만들 때, nn.Module을 상속받아서 기본적인 기능들을 사용할 수 있게 만들어준다.
이것을 응용해서 더하기 모델을 만들어 보자.
import torch
from torch import nn
class Add(nn.Module):
def __init__(self):
super().__init__() # 반드시 Add class의 부모 클래스인 nn.Module을 super()을 사용해서 초기화 시켜줘야 한다.
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
x1 = torch.tensor([1])
x2 = torch.tensor([2])
add = Add()
output = add(x1, x2)
print(output)
여기서 의문점이 들었던 것이 forward() 함수를 호출하지 않았는데 add객체에 파라미터로 전달하면 바로 forward function이 시작된다. 이것은 nn.Module을 상속받은 클래스의 특징이라고 볼 수 있다.