Architecture Table
Code
import torch.nn as nn
class LeNet5(nn.Module):
    """LeNet-5 style convolutional classifier.

    Layer stack (tanh activations, 2x2 average pooling with stride 2):
        conv1 (5x5, 6 maps)   -> tanh -> avgpool
        conv2 (5x5, 16 maps)  -> tanh -> avgpool
        conv3 (5x5, 120 maps) -> tanh -> flatten
        fc1 (120 -> 84) -> tanh -> fc2 (84 -> num_classes) -> optional softmax

    The input must be ``(batch, in_channels, 32, 32)``. With 5x5 kernels
    (no padding) and two 2x2 pools the spatial size goes
    32 -> 28 -> 14 -> 10 -> 5 -> 1, so only a 32x32 input leaves conv3 with a
    1x1 map and exactly 120 flattened features for ``fc1``. For example,
    28x28 images must be padded to 32x32 first.

    Args:
        in_channels: Number of input image channels (default 1).
        num_classes: Number of output classes (default 10).
        apply_softmax: If True (default, the original behavior), ``forward``
            returns per-class probabilities. Set it to False to get raw
            logits. ``nn.CrossEntropyLoss`` expects unnormalized logits, and
            feeding it softmax output would apply softmax twice.
    """

    def __init__(
        self,
        in_channels: int = 1,
        num_classes: int = 10,
        *,
        apply_softmax: bool = True,
    ):
        super().__init__()
        self.apply_softmax = apply_softmax

        # Attribute names are unchanged from the original so that existing
        # state_dict checkpoints (conv1.weight, fc2.bias, ...) still load.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=5)
        self.conv1_act = nn.Tanh()
        self.conv1_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.conv2_act = nn.Tanh()
        self.conv2_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.conv3_act = nn.Tanh()

        self.fc1 = nn.Linear(in_features=120, out_features=84)
        self.fc1_act = nn.Tanh()
        self.fc2 = nn.Linear(in_features=84, out_features=num_classes)

        # Softmax over dim=1, i.e. the class axis of the (batch, num_classes) output.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, 32, 32)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``. It holds softmax
            probabilities when ``apply_softmax`` is True, otherwise raw logits.
        """
        x = self.conv1_pool(self.conv1_act(self.conv1(x)))  # (B, 6, 14, 14)
        x = self.conv2_pool(self.conv2_act(self.conv2(x)))  # (B, 16, 5, 5)
        x = self.conv3_act(self.conv3(x))                   # (B, 120, 1, 1)
        # flatten(1) keeps the batch axis and, unlike view(), also works on
        # non-contiguous tensors and on an empty batch.
        x = x.flatten(1)                                    # (B, 120)
        x = self.fc1_act(self.fc1(x))                       # (B, 84)
        x = self.fc2(x)                                     # (B, num_classes)
        # getattr default keeps whole-model pickles saved from the old class
        # (which lack this attribute) working with the original behavior.
        if getattr(self, "apply_softmax", True):
            x = self.softmax(x)
        return x
정말 좋은 글 감사합니다!