torch.nn.Sequential을 사용하여 간단한 모델을 순차적으로 정의할 수 있음import torch.nn as nn
import torchsummary as ts
# Sequential API
model = nn.Sequential(
nn.Linear(784, 64),
nn.ReLU(),
nn.Linear(64, 10),
nn.Sigmoid(),
)
ts.summary(model, (784,))
# ----------------------------------------------------------------
# Layer (type) Output Shape Param #
# ================================================================
# Linear-1 [-1, 64] 50,240
# ReLU-2 [-1, 64] 0
# Linear-3 [-1, 10] 650
# Sigmoid-4 [-1, 10] 0
# ================================================================
# Total params: 50,890
# Trainable params: 50,890
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.00
# Forward/backward pass size (MB): 0.00
# Params size (MB): 0.19
# Estimated Total Size (MB): 0.20
# ----------------------------------------------------------------
torch.nn.functional 모듈과 forward()를 사용해 torch.nn.Module을 상속받은 클래스에 정의함으로써 보다 유연한 모델링이 가능함import torch
import torch.nn as nn
import torch.nn.functional as F
# Subclassing API
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(32 * 14 * 14, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, kernel_size=2, stride=2)
x = torch.flatten(x, 1)
x = self.fc1(x)
return F.log_softmax(x, dim=1)
model = Model()
ts.summary(model, (1, 28, 28))
# ----------------------------------------------------------------
# Layer (type) Output Shape Param #
# ================================================================
# Conv2d-1 [-1, 32, 28, 28] 320
# Linear-2 [-1, 10] 62,730
# ================================================================
# Total params: 63,050
# Trainable params: 63,050
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.00
# Forward/backward pass size (MB): 0.19
# Params size (MB): 0.24
# Estimated Total Size (MB): 0.43
# ----------------------------------------------------------------
*이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다.