# Basic form
class CNN(nn.Module):
    """Three-stage convolutional classifier.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` followed
    by a 2x2 max-pool, with channel widths 8, 16 and 32.  The pooled feature
    map is flattened and fed to a single linear layer that produces one score
    per class (no softmax is applied here).

    With the default arguments the network is the same as the original
    hard-coded version: 3 input channels, 32x32 inputs and 10 outputs.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of scores produced by the final linear layer.
        input_size: Height and width of the (square) input images.  Must be
            at least 8 so that a non-empty feature map survives the three
            2x2 poolings.

    Raises:
        ValueError: If ``in_channels`` or ``num_classes`` is smaller than 1,
            or ``input_size`` is smaller than 8.
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=32):
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        if input_size < 8:
            raise ValueError(f"input_size must be >= 8, got {input_size}")

        # Attribute names (including the capitalised ``Maxpool*``) and the
        # creation order are kept as-is so existing checkpoints and seeded
        # initialisation stay compatible.
        self.conv1 = self._conv_block(in_channels, 8)
        self.Maxpool1 = nn.MaxPool2d(2)
        self.conv2 = self._conv_block(8, 16)
        self.Maxpool2 = nn.MaxPool2d(2)
        self.conv3 = self._conv_block(16, 32)
        self.Maxpool3 = nn.MaxPool2d(2)

        # The padded 3x3 convolutions keep the spatial size; each of the
        # three poolings halves it (rounding down), hence the division by 8.
        feature_side = input_size // 8
        self.fc = nn.Linear(32 * feature_side * feature_side, num_classes)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build one Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU stack."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, input_size, input_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with the class scores.
        """
        x = self.conv1(x)
        x = self.Maxpool1(x)
        x = self.conv2(x)
        x = self.Maxpool2(x)
        x = self.conv3(x)
        x = self.Maxpool3(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        return x