딥러닝 - PyTorch: Sequential API, Subclassing API

dumbbelldore·2025년 1월 15일
0

zero-base 33기

목록 보기
75/97

1. Sequential API

  • torch.nn.Sequential을 사용하여 간단한 모델을 순차적으로 정의할 수 있음
import torch.nn as nn
import torchsummary as ts

# Sequential API
model = nn.Sequential(
    nn.Linear(784, 64), 
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.Sigmoid(),
)

ts.summary(model, (784,))
# ----------------------------------------------------------------
#         Layer (type)               Output Shape         Param #
# ================================================================
#             Linear-1                   [-1, 64]          50,240
#               ReLU-2                   [-1, 64]               0
#             Linear-3                   [-1, 10]             650
#            Sigmoid-4                   [-1, 10]               0
# ================================================================
# Total params: 50,890
# Trainable params: 50,890
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.00
# Forward/backward pass size (MB): 0.00
# Params size (MB): 0.19
# Estimated Total Size (MB): 0.20
# ----------------------------------------------------------------

2. Subclassing API

  • torch.nn.functional 모듈과 forward()를 사용해 torch.nn.Module을 상속받은 클래스에 정의함으로써 보다 유연한 모델링이 가능함
import torch
import torch.nn as nn
import torch.nn.functional as F

# Subclassing API
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(32 * 14 * 14, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=1)

model = Model()

ts.summary(model, (1, 28, 28))
# ----------------------------------------------------------------
#         Layer (type)               Output Shape         Param #
# ================================================================
#             Conv2d-1           [-1, 32, 28, 28]             320
#             Linear-2                   [-1, 10]          62,730
# ================================================================
# Total params: 63,050
# Trainable params: 63,050
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.00
# Forward/backward pass size (MB): 0.19
# Params size (MB): 0.24
# Estimated Total Size (MB): 0.43
# ----------------------------------------------------------------

*이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다.

profile
데이터 분석, 데이터 사이언스 학습 저장소

0개의 댓글