torch.nn.Module은 모든 뉴럴 네트워크 모듈의 기본 클래스이다. 일반적인 모델들은 이 클래스를 상속받아야한다. 모듈들은 다른 모듈을 또 포함할 수 있다.
__init__() 메소드에는 신경망 레이어의 구성요소들을 정의하고, __forward__에서는 호출 될 때 수행되는 연산을 정의한다. torch.nn.Module을 상속받는 모든 클래스에서 override되어야 한다. 일반적으로 이 두가지 메소드는 반드시 정의한다.
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
torch.nn.Module에는 많은 메소드들이 있지만 모두 소개할 수는 없어 가장 중요한 내용만 소개하였다.