In this practice session, we will cover the implementation of Convolutional Neural Networks (CNNs) using the PyTorch library. In particular, in part 1, we will see how to define and use convolutional, batch normalization, and dropout layers to build a simple CNN that classifies MNIST. In part 2, we will implement ResNet for the CIFAR-10 classification task and see the effectiveness of residual connections.
Let's import the required libraries and datasets. We will use MNIST and CIFAR-10 to train the simple CNN and ResNet, respectively.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10
from IPython.display import Image
# MNIST
mnist_train = MNIST("./", train=True, transform=transforms.ToTensor(), target_transform=None, download=True)
mnist_test = MNIST("./", train=False, transform=transforms.ToTensor(), target_transform=None, download=True)
mnist_train, mnist_val = torch.utils.data.random_split(mnist_train, [50000, 10000])
dataloaders = {}
dataloaders['train'] = DataLoader(mnist_train, batch_size=128, shuffle=True)
dataloaders['val'] = DataLoader(mnist_val, batch_size=128, shuffle=False)
dataloaders['test'] = DataLoader(mnist_test, batch_size=128, shuffle=False)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
For CIFAR-10, we use the conventional augmentation transforms and normalization.
For more examples, please refer to
https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
transforms_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transforms_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
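transforms.Normalize standardizes each channel as (x - mean) / std, here with per-channel statistics of the CIFAR-10 training set. A tiny illustration with made-up numbers (not part of the original cells):
norm = transforms.Normalize((0.5,), (0.25,))
t = torch.full((1, 2, 2), 0.75)  # a single-channel "image" filled with 0.75
print(norm(t))  # (0.75 - 0.5) / 0.25 = 1.0 everywhere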
cifar_train = CIFAR10(root='./', train=True,
download=True, transform=transforms_train)
cifar_test = CIFAR10(root='./', train=False,
download=True, transform=transforms_test)
cifar_loader = {}
cifar_loader['train'] = DataLoader(cifar_train, batch_size=128,
shuffle=True, num_workers=4)
cifar_loader['test'] = DataLoader(cifar_test, batch_size=128,
shuffle=False, num_workers=4)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz
Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified
In PyTorch, the 2-dimensional convolutional layer is provided by the torch.nn.Conv2d class. In this section, we will learn the basic usage of the PyTorch convolutional layer with some example code and practice.
As in our lecture session, we specify a convolution by the number of input and output channels, the kernel size, the stride, and the padding. In the PyTorch torch.nn.Conv2d class, these traits are given as class parameters. Detailed explanations and default values are listed below and on the official site (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html).
in_channels (int) – Number of channels in the input image
out_channels (int) – Number of channels produced by the convolution
kernel_size (int or tuple) – Size of the convolving kernel
stride (int or tuple, optional) – Stride of the convolution. Default: 1
padding (int or tuple, optional) – Zero-padding added to both sides of the input. Default: 0
Now, let's define our convolutional layer and practice. As in our lecture note, let's suppose we have a 32x32 image with 3 channels.
"""
Note: Since PyTorch Conv2d receives a 4-dimensional input (i.e., a batch of images),
we define the input x with batch size 1 as the first argument.
"""
x = torch.randn(1, 3, 32, 32)
Then, the number of input channels is 3, which matches the number of channels of the input image (the second argument of the randn call above).
How about the number of output channels, the kernel size, and the stride? Following the figure in our lecture note, we can easily see that the number of output channels should be 1. You can regard the kernel size as the spatial size of the filter in the lecture note. Thus, the kernel size should be (5, 5) and the stride should be 1. In practice, since the filter is usually square (i.e., width = height), we typically specify the kernel size with a single integer (in our case, 5).
# Fill out the ?? below
conv_layer = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=5, stride=1)
# You should check that the output size of the convolution layer is [1, 1, 28, 28].
conv_layer(x).size()
torch.Size([1, 1, 28, 28])
Also check the other example from the lecture note below.
# Fill out the ?? below
x = torch.randn(1, 3, 32, 32)
conv_layer = torch.nn.Conv2d(in_channels=3, out_channels=10, kernel_size=5, stride=1, padding=2)
print(conv_layer(x).size())
torch.Size([1, 10, 32, 32])
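Both output sizes follow the standard convolution size formula: out = floor((in + 2 * padding - kernel_size) / stride) + 1. A minimal helper to verify the two examples above (conv_out_size is a hypothetical name, not part of the lecture code):
def conv_out_size(in_size, kernel_size, stride=1, padding=0):
    # floor((in + 2*padding - kernel) / stride) + 1
    return (in_size + 2 * padding - kernel_size) // stride + 1

print(conv_out_size(32, kernel_size=5, stride=1, padding=0))  # 28, matching [1, 1, 28, 28]
print(conv_out_size(32, kernel_size=5, stride=1, padding=2))  # 32, matching [1, 10, 32, 32]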
In a similar manner, batch normalization and dropout layers are available as torch.nn.BatchNorm2d and torch.nn.Dropout2d.
x = torch.randn(1, 3, 32, 32)
bn = torch.nn.BatchNorm2d(num_features=3) # num_features must equal the number of input channels
print(bn(x).size()) # check that batch norm does not change the size of the input
dropout = torch.nn.Dropout2d(p=0.5) # p is the probability of a channel being zeroed
print(dropout(bn(x)).size()) # check that dropout does not change the size of the input
torch.Size([1, 3, 32, 32])
torch.Size([1, 3, 32, 32])
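Note that both layers behave differently in training and evaluation mode (toggled by .train() / .eval()). A minimal sketch of the Dropout2d behavior, using torch.ones so the effect is easy to see:
drop = torch.nn.Dropout2d(p=0.5)
x = torch.ones(1, 3, 4, 4)
drop.train()  # training mode: each channel is zeroed with probability p; survivors are scaled by 1/(1-p)
print(drop(x)[0, :, 0, 0])  # e.g., tensor([2., 0., 2.]) -- the exact zero pattern is random
drop.eval()   # evaluation mode: dropout is the identity
print(drop(x)[0, :, 0, 0])  # tensor([1., 1., 1.])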
We can combine a convolution layer, a batch norm layer, and an activation function (e.g., ReLU) to construct a functional unit. In this case, we can use torch.nn.Sequential to define a block of sequential layers.
The network below should map an input x of size (128, 1, 28, 28) to an output of size (128, 10).
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 1 input image channel, 32 output channels, 7x7 square convolution, 1 stride
self.layer1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=32, kernel_size=7, stride=1),
nn.BatchNorm2d(num_features=32),
nn.ReLU(),
)
# 32 input image channel, 64 output channels, 7x7 square convolution, 1 stride
self.layer2 = nn.Sequential(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1),
nn.BatchNorm2d(num_features=64),
nn.ReLU(),
)
self.fc = nn.Linear(64*16*16, 10)
def forward(self, x):
# x's shape : (128,1,28,28)
out = self.layer1(x)
# out's shape : (128,32,22,22)
out = self.layer2(out)
# out's shape : (128,64,16,16)
out = torch.flatten(out, 1)
# out's shape : (128,64*16*16)
out = self.fc(out)
# out's shape : (128,10)
return out
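Before training, we can verify the commented shapes by pushing a random MNIST-sized batch through the network (a quick sanity check, not part of the original cells):
net_check = Net()
dummy = torch.randn(128, 1, 28, 28)
print(net_check(dummy).size())  # expected: torch.Size([128, 10])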
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net = Net().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
cuda:0
criterion = torch.nn.CrossEntropyLoss()
for _ in range(20):
for x, y in dataloaders['train']:
x, y = x.to(device), y.to(device)
out = net(x)
loss = criterion(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
net.eval()
correct = 0.0
with torch.no_grad():  # no gradients needed for evaluation
    for x, y in dataloaders['test']:
        x, y = x.to(device), y.to(device)
        out = net(x)
        correct += (out.argmax(1) == y).float().sum().item()
print(100. * correct / len(dataloaders['test'].dataset))
99.09
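We created a validation loader earlier but never used it above. As a sketch, you could monitor validation accuracy after each epoch for model selection (the accuracy helper below is a hypothetical addition, not part of the original code):
def accuracy(model, loader):
    model.eval()
    correct = 0.0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            correct += (model(x).argmax(1) == y).float().sum().item()
    model.train()
    return 100. * correct / len(loader.dataset)

# e.g., print(accuracy(net, dataloaders['val'])) at the end of each epoch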
In this section, we will implement ResNet and see the effectiveness of the residual connection in terms of test performance.
The overall structure of ResNet is shown below.
n can be chosen from {3, 5, 7, 9, 18}, which correspond to ResNet-20, -32, -44, -56, and -110, respectively: each of the three stages stacks n residual blocks of two convolution layers each, and together with the initial convolution and the final fully-connected layer this gives 6n + 2 layers in total.
A residual block consists of 2 convolution layers with 3x3 kernels and a ReLU activation function (see the figure below). Let's implement the ResidualBlock class below with 2 convolutional layers and a residual connection.
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, down_sample=False):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.down_sample = down_sample
self.stride = stride
self.in_channels = in_channels
self.out_channels = out_channels
def down_sampling(self, x):
out = F.pad(x, (0, 0, 0, 0, 0, self.out_channels - self.in_channels))
out = nn.MaxPool2d(2, stride=self.stride)(out)
return out
def forward(self, x):
shortcut = x # this will be used to build residual connection.
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.down_sample:
shortcut=self.down_sampling(x)
out += shortcut # residual connection
out = self.relu(out)
return out
The ResidualBlock class extends torch.nn.Module. It receives in_channels, out_channels, stride, and down_sample as arguments.
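A quick shape check with default arguments (illustration only): without down-sampling, the block preserves the input shape, so the residual addition is well-defined.
block = ResidualBlock(in_channels=16, out_channels=16)
x = torch.randn(1, 16, 32, 32)
print(block(x).size())  # torch.Size([1, 16, 32, 32])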
In ResNet, there are residual blocks that double the number of output channels (16 to 32, 32 to 64). The stride argument for ResidualBlock is set to 2 in such residual blocks to down-sample (reduce the spatial dimension) while increasing the channels.
However, the residual connection in such a block would then suffer a dimension mismatch, since the other path (through the convolutional layers) changes the dimensions of the input with stride=2. Thus, the residual block must also support down-sampling along the shortcut on demand.
We support this feature in the down_sampling method of the ResidualBlock class. It zero-pads to expand the channels and max-pools to shrink the spatial dimensions along the shortcut. We use down_sampling in the middle of the forward method to apply the down_sample condition to the residual connection.
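Step by step, the shortcut transformation looks like this (illustration only; the same operations appear inside down_sampling):
x = torch.randn(1, 16, 32, 32)
padded = F.pad(x, (0, 0, 0, 0, 0, 16))      # zero-pad the channel dimension: 16 -> 32 channels
pooled = nn.MaxPool2d(2, stride=2)(padded)  # halve the spatial size: 32x32 -> 16x16
print(pooled.size())  # torch.Size([1, 32, 16, 16]) -- now matches the stride-2 convolutional path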
Now let's implement the ResNet class. Assume the block argument will be the ResidualBlock we implemented above. Here are the required implementation details.
In the __init__ method, specify all details of the convolution and batch norm layers. In the get_layers method, set the down_sample boolean variable according to the stride information. Then, define a list of residual blocks (layer_list). Make sure the down-sampling only occurs at the first block of each stage, when needed.
class ResNet(nn.Module):
def __init__(self, num_layers, block, num_classes=10):
super(ResNet, self).__init__()
self.num_layers = num_layers
#input(channel:3) -> (conv 3x3) -> (bn) -> (relu) -> output(channel:16)
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=16,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
# feature map size = 16x32x32
self.layers_2n = self.get_layers(block, 16, 16, stride=1)
# feature map size = 32x16x16
self.layers_4n = self.get_layers(block, 16, 32, stride=2)
# feature map size = 64x8x8
self.layers_6n = self.get_layers(block, 32, 64, stride=2)
# output layers
self.avg_pool = nn.AvgPool2d(8, stride=1)
self.fc_out = nn.Linear(64, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def get_layers(self, block, in_channels, out_channels, stride):
if stride == 2:
down_sample = True
else:
down_sample = False
layer_list = nn.ModuleList([block(in_channels, out_channels, stride, down_sample)])
for _ in range(self.num_layers-1):
layer_list.append(block(out_channels, out_channels))
return nn.Sequential(*layer_list)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layers_2n(x)
x = self.layers_4n(x)
x = self.layers_6n(x)
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc_out(x)
return x
In this practice, we use a 20-layer ResNet (n=3) to train on CIFAR-10. Note that by the 6n + 2 rule above, ResNet(3, block) is ResNet-20 and ResNet(5, block) is ResNet-32, so we name the factory functions accordingly.
def resnet20():
    block = ResidualBlock
    model = ResNet(3, block)
    return model
def resnet32():
    block = ResidualBlock
    model = ResNet(5, block)
    return model
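As a sanity check (assuming the classes above), a CIFAR-sized batch should come out as 10 logits per image, and ResNet-20 should have roughly 0.27 million parameters:
model = resnet20()
x = torch.randn(2, 3, 32, 32)
print(model(x).size())  # expected: torch.Size([2, 10])
print(sum(p.numel() for p in model.parameters()))  # roughly 0.27M parameters for ResNet-20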
By replacing ResidualBlock
with plain Block
(without residual connection), we can compare the effectiveness of residual connection.
class Block(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, down_sample=False):
super(Block, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.stride = stride
self.in_channels = in_channels
self.out_channels = out_channels
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
return out
def cnn20():
    block = Block
    model = ResNet(3, block)
    return model
def cnn32():
    block = Block
    model = ResNet(5, block)
    return model
Training a ResNet is no different from other training schemes. We train for 64,000 batch steps with a batch size of 128. The learning rate starts at 0.1 and is decayed by a factor of 0.1 at 32,000 and 48,000 steps.
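For reference, with 50,000 training images and a batch size of 128 there are 391 steps per epoch, so 64,000 steps correspond to about 164 passes over the data (matching the epoch counts in the logs below):
steps_per_epoch = (50000 + 127) // 128  # ceil(50000 / 128) = 391
print(64000 / steps_per_epoch)          # ~163.7 epochs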
net = resnet20().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
decay_steps = [32000, 48000]  # milestones are measured in steps, not epochs
step_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_steps, gamma=0.1)
import time
start_time = time.time()
net.train()
step = 0
epochs = 0
losses = []
while step < 64000:
train_loss = 0.0
correct = 0.0
total = 0.0
for batch_idx, (x, y) in enumerate(cifar_loader['train']):
step += 1
        x, y = x.to(device), y.to(device)
        out = net(x)
        loss = criterion(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step_lr_scheduler.step()  # step the scheduler after the optimizer, so the first lr value is not skipped
correct += (out.argmax(1) == y).float().sum().item()
total += x.size(0)
train_loss += loss.item()
losses.append(train_loss)
epochs += 1
print("Epoch[{:d} ({:d}/64000) ({:.4f}sec)] loss: {:.2f} acc: {:.2f}".format(epochs, step, time.time()-start_time, train_loss, 100.*correct/total))
Epoch[1 (391/64000) (7.3776sec)] loss: 770.94 acc: 28.60
Epoch[2 (782/64000) (14.8934sec)] loss: 553.97 acc: 48.27
Epoch[3 (1173/64000) (22.6149sec)] loss: 423.91 acc: 61.23
Epoch[4 (1564/64000) (30.1059sec)] loss: 338.47 acc: 69.58
Epoch[5 (1955/64000) (37.5786sec)] loss: 292.68 acc: 73.91
Epoch[6 (2346/64000) (44.9162sec)] loss: 262.95 acc: 76.47
Epoch[7 (2737/64000) (52.4001sec)] loss: 240.67 acc: 78.48
Epoch[8 (3128/64000) (59.6620sec)] loss: 226.10 acc: 80.03
Epoch[9 (3519/64000) (66.7264sec)] loss: 211.61 acc: 81.14
Epoch[10 (3910/64000) (74.0309sec)] loss: 201.14 acc: 82.09
Epoch[11 (4301/64000) (81.4279sec)] loss: 192.44 acc: 82.89
Epoch[12 (4692/64000) (88.6449sec)] loss: 185.66 acc: 83.48
Epoch[13 (5083/64000) (95.6921sec)] loss: 179.04 acc: 84.08
Epoch[14 (5474/64000) (103.2531sec)] loss: 172.97 acc: 84.52
Epoch[15 (5865/64000) (110.4206sec)] loss: 168.45 acc: 85.00
Epoch[16 (6256/64000) (117.7995sec)] loss: 163.93 acc: 85.55
Epoch[17 (6647/64000) (125.0056sec)] loss: 159.06 acc: 85.89
Epoch[18 (7038/64000) (131.9813sec)] loss: 154.58 acc: 86.24
Epoch[19 (7429/64000) (139.2963sec)] loss: 149.25 acc: 86.76
Epoch[20 (7820/64000) (146.5095sec)] loss: 147.14 acc: 86.90
Epoch[21 (8211/64000) (153.8556sec)] loss: 145.00 acc: 86.98
Epoch[22 (8602/64000) (161.2942sec)] loss: 143.00 acc: 87.27
Epoch[23 (8993/64000) (168.7739sec)] loss: 140.13 acc: 87.60
Epoch[24 (9384/64000) (176.0373sec)] loss: 138.12 acc: 87.62
Epoch[25 (9775/64000) (183.3021sec)] loss: 134.82 acc: 88.06
Epoch[26 (10166/64000) (190.4438sec)] loss: 132.00 acc: 88.19
Epoch[27 (10557/64000) (198.0074sec)] loss: 129.47 acc: 88.56
Epoch[28 (10948/64000) (205.5802sec)] loss: 127.74 acc: 88.62
Epoch[29 (11339/64000) (212.6293sec)] loss: 127.16 acc: 88.73
Epoch[30 (11730/64000) (219.8726sec)] loss: 124.96 acc: 88.80
Epoch[31 (12121/64000) (227.2215sec)] loss: 123.96 acc: 88.99
Epoch[32 (12512/64000) (234.4533sec)] loss: 121.24 acc: 89.43
Epoch[33 (12903/64000) (241.7689sec)] loss: 121.30 acc: 89.14
Epoch[34 (13294/64000) (248.7863sec)] loss: 117.79 acc: 89.52
Epoch[35 (13685/64000) (256.2139sec)] loss: 118.10 acc: 89.44
Epoch[36 (14076/64000) (263.5176sec)] loss: 115.68 acc: 89.59
Epoch[37 (14467/64000) (270.9365sec)] loss: 115.37 acc: 89.69
Epoch[38 (14858/64000) (278.3734sec)] loss: 115.04 acc: 89.83
Epoch[39 (15249/64000) (285.9757sec)] loss: 112.16 acc: 89.97
Epoch[40 (15640/64000) (292.9311sec)] loss: 111.84 acc: 89.85
Epoch[41 (16031/64000) (300.4760sec)] loss: 109.70 acc: 90.10
Epoch[42 (16422/64000) (307.9932sec)] loss: 109.83 acc: 90.18
Epoch[43 (16813/64000) (315.1947sec)] loss: 107.85 acc: 90.25
Epoch[44 (17204/64000) (322.2472sec)] loss: 107.05 acc: 90.46
Epoch[45 (17595/64000) (329.4798sec)] loss: 105.78 acc: 90.47
Epoch[46 (17986/64000) (336.6178sec)] loss: 107.77 acc: 90.26
Epoch[47 (18377/64000) (343.9944sec)] loss: 107.88 acc: 90.28
Epoch[48 (18768/64000) (351.4488sec)] loss: 105.09 acc: 90.56
Epoch[49 (19159/64000) (359.2252sec)] loss: 105.22 acc: 90.65
Epoch[50 (19550/64000) (366.4349sec)] loss: 103.85 acc: 90.72
Epoch[51 (19941/64000) (373.9526sec)] loss: 103.95 acc: 90.60
Epoch[52 (20332/64000) (381.5875sec)] loss: 100.33 acc: 90.91
Epoch[53 (20723/64000) (389.3188sec)] loss: 101.65 acc: 90.89
Epoch[54 (21114/64000) (396.4953sec)] loss: 100.06 acc: 90.99
Epoch[55 (21505/64000) (403.9685sec)] loss: 101.06 acc: 90.89
Epoch[56 (21896/64000) (411.0765sec)] loss: 99.19 acc: 91.10
Epoch[57 (22287/64000) (418.5723sec)] loss: 100.30 acc: 90.98
Epoch[58 (22678/64000) (426.1539sec)] loss: 99.05 acc: 91.10
Epoch[59 (23069/64000) (433.3991sec)] loss: 98.29 acc: 91.00
Epoch[60 (23460/64000) (440.8175sec)] loss: 99.99 acc: 91.18
Epoch[61 (23851/64000) (448.1590sec)] loss: 97.27 acc: 91.34
Epoch[62 (24242/64000) (455.1867sec)] loss: 98.78 acc: 91.03
Epoch[63 (24633/64000) (462.4172sec)] loss: 97.50 acc: 91.19
Epoch[64 (25024/64000) (469.4981sec)] loss: 97.14 acc: 91.25
Epoch[65 (25415/64000) (477.0369sec)] loss: 94.74 acc: 91.54
Epoch[66 (25806/64000) (484.3379sec)] loss: 98.58 acc: 91.23
Epoch[67 (26197/64000) (491.7179sec)] loss: 96.46 acc: 91.20
Epoch[68 (26588/64000) (498.8555sec)] loss: 94.97 acc: 91.48
Epoch[69 (26979/64000) (506.3740sec)] loss: 95.55 acc: 91.46
Epoch[70 (27370/64000) (513.6973sec)] loss: 91.85 acc: 91.77
Epoch[71 (27761/64000) (521.2321sec)] loss: 91.67 acc: 91.58
Epoch[72 (28152/64000) (528.9102sec)] loss: 93.36 acc: 91.51
Epoch[73 (28543/64000) (536.1781sec)] loss: 93.85 acc: 91.56
Epoch[74 (28934/64000) (543.5638sec)] loss: 91.23 acc: 91.83
Epoch[75 (29325/64000) (550.4284sec)] loss: 94.02 acc: 91.56
Epoch[76 (29716/64000) (557.9773sec)] loss: 92.61 acc: 91.70
Epoch[77 (30107/64000) (565.2475sec)] loss: 92.84 acc: 91.59
Epoch[78 (30498/64000) (572.4811sec)] loss: 92.01 acc: 91.85
Epoch[79 (30889/64000) (580.0703sec)] loss: 89.81 acc: 91.93
Epoch[80 (31280/64000) (587.3492sec)] loss: 91.17 acc: 91.83
Epoch[81 (31671/64000) (594.7008sec)] loss: 90.31 acc: 91.84
Epoch[82 (32062/64000) (602.0518sec)] loss: 88.09 acc: 92.12
Epoch[83 (32453/64000) (609.5276sec)] loss: 55.12 acc: 95.01
Epoch[84 (32844/64000) (616.7906sec)] loss: 47.74 acc: 95.84
Epoch[85 (33235/64000) (624.1608sec)] loss: 42.93 acc: 96.36
Epoch[86 (33626/64000) (631.1735sec)] loss: 40.43 acc: 96.55
Epoch[87 (34017/64000) (638.5726sec)] loss: 39.21 acc: 96.62
Epoch[88 (34408/64000) (645.9556sec)] loss: 36.76 acc: 96.87
Epoch[89 (34799/64000) (653.2316sec)] loss: 36.14 acc: 96.96
Epoch[90 (35190/64000) (660.9627sec)] loss: 33.30 acc: 97.22
Epoch[91 (35581/64000) (668.3827sec)] loss: 31.57 acc: 97.17
Epoch[92 (35972/64000) (675.5018sec)] loss: 31.88 acc: 97.22
Epoch[93 (36363/64000) (682.8006sec)] loss: 30.31 acc: 97.43
Epoch[94 (36754/64000) (689.9211sec)] loss: 29.82 acc: 97.39
Epoch[95 (37145/64000) (697.3719sec)] loss: 28.64 acc: 97.57
Epoch[96 (37536/64000) (705.3345sec)] loss: 28.83 acc: 97.45
Epoch[97 (37927/64000) (712.7620sec)] loss: 27.33 acc: 97.71
Epoch[98 (38318/64000) (720.2508sec)] loss: 25.83 acc: 97.87
Epoch[99 (38709/64000) (727.6618sec)] loss: 25.94 acc: 97.68
Epoch[100 (39100/64000) (734.5958sec)] loss: 25.71 acc: 97.73
Epoch[101 (39491/64000) (742.2419sec)] loss: 24.99 acc: 97.81
Epoch[102 (39882/64000) (750.0543sec)] loss: 24.45 acc: 97.92
Epoch[103 (40273/64000) (757.1618sec)] loss: 24.02 acc: 97.92
Epoch[104 (40664/64000) (764.5942sec)] loss: 22.69 acc: 98.08
Epoch[105 (41055/64000) (771.7031sec)] loss: 22.70 acc: 98.06
Epoch[106 (41446/64000) (779.2763sec)] loss: 21.59 acc: 98.18
Epoch[107 (41837/64000) (786.7181sec)] loss: 22.31 acc: 98.04
Epoch[108 (42228/64000) (793.8452sec)] loss: 20.78 acc: 98.23
Epoch[109 (42619/64000) (801.4266sec)] loss: 20.54 acc: 98.20
Epoch[110 (43010/64000) (808.9340sec)] loss: 20.38 acc: 98.28
Epoch[111 (43401/64000) (816.5496sec)] loss: 20.61 acc: 98.23
Epoch[112 (43792/64000) (823.7053sec)] loss: 20.64 acc: 98.22
Epoch[113 (44183/64000) (831.0946sec)] loss: 19.89 acc: 98.25
Epoch[114 (44574/64000) (838.1194sec)] loss: 20.06 acc: 98.26
Epoch[115 (44965/64000) (845.7004sec)] loss: 19.35 acc: 98.36
Epoch[116 (45356/64000) (853.1254sec)] loss: 19.46 acc: 98.36
Epoch[117 (45747/64000) (860.7161sec)] loss: 19.38 acc: 98.33
Epoch[118 (46138/64000) (868.0838sec)] loss: 19.21 acc: 98.34
Epoch[119 (46529/64000) (875.4143sec)] loss: 18.16 acc: 98.43
Epoch[120 (46920/64000) (883.0658sec)] loss: 18.06 acc: 98.42
Epoch[121 (47311/64000) (890.5691sec)] loss: 17.26 acc: 98.52
Epoch[122 (47702/64000) (898.1425sec)] loss: 17.24 acc: 98.57
Epoch[123 (48093/64000) (905.6408sec)] loss: 17.67 acc: 98.54
Epoch[124 (48484/64000) (912.7591sec)] loss: 15.04 acc: 98.80
Epoch[125 (48875/64000) (919.8622sec)] loss: 14.24 acc: 98.87
Epoch[126 (49266/64000) (927.0275sec)] loss: 13.96 acc: 98.88
Epoch[127 (49657/64000) (934.2595sec)] loss: 13.83 acc: 98.91
Epoch[128 (50048/64000) (941.7233sec)] loss: 12.73 acc: 99.01
Epoch[129 (50439/64000) (949.4730sec)] loss: 13.47 acc: 98.94
Epoch[130 (50830/64000) (956.3760sec)] loss: 13.00 acc: 99.00
Epoch[131 (51221/64000) (964.1147sec)] loss: 12.73 acc: 99.03
Epoch[132 (51612/64000) (971.5320sec)] loss: 12.90 acc: 99.04
Epoch[133 (52003/64000) (978.7294sec)] loss: 12.80 acc: 99.01
Epoch[134 (52394/64000) (986.0192sec)] loss: 12.28 acc: 99.11
Epoch[135 (52785/64000) (993.1015sec)] loss: 12.77 acc: 99.02
Epoch[136 (53176/64000) (1000.2869sec)] loss: 12.77 acc: 98.99
Epoch[137 (53567/64000) (1007.5633sec)] loss: 12.44 acc: 99.09
Epoch[138 (53958/64000) (1014.4154sec)] loss: 12.59 acc: 99.00
Epoch[139 (54349/64000) (1021.9794sec)] loss: 12.80 acc: 99.01
Epoch[140 (54740/64000) (1029.1151sec)] loss: 12.11 acc: 99.06
Epoch[141 (55131/64000) (1037.0154sec)] loss: 11.83 acc: 99.10
Epoch[142 (55522/64000) (1044.2304sec)] loss: 12.21 acc: 99.06
Epoch[143 (55913/64000) (1052.0665sec)] loss: 12.36 acc: 99.02
Epoch[144 (56304/64000) (1059.4328sec)] loss: 12.20 acc: 99.03
Epoch[145 (56695/64000) (1067.1640sec)] loss: 12.19 acc: 99.05
Epoch[146 (57086/64000) (1074.0894sec)] loss: 11.74 acc: 99.10
Epoch[147 (57477/64000) (1081.3311sec)] loss: 11.91 acc: 99.09
Epoch[148 (57868/64000) (1089.0102sec)] loss: 12.08 acc: 99.11
Epoch[149 (58259/64000) (1096.2499sec)] loss: 12.17 acc: 99.03
Epoch[150 (58650/64000) (1103.7006sec)] loss: 11.59 acc: 99.16
Epoch[151 (59041/64000) (1111.1318sec)] loss: 11.78 acc: 99.08
Epoch[152 (59432/64000) (1118.4786sec)] loss: 11.68 acc: 99.08
Epoch[153 (59823/64000) (1125.7784sec)] loss: 11.71 acc: 99.10
Epoch[154 (60214/64000) (1133.2468sec)] loss: 11.25 acc: 99.14
Epoch[155 (60605/64000) (1140.5466sec)] loss: 11.38 acc: 99.12
Epoch[156 (60996/64000) (1148.2231sec)] loss: 11.42 acc: 99.14
Epoch[157 (61387/64000) (1155.5876sec)] loss: 11.47 acc: 99.16
Epoch[158 (61778/64000) (1163.0776sec)] loss: 11.76 acc: 99.14
Epoch[159 (62169/64000) (1170.4764sec)] loss: 11.48 acc: 99.16
Epoch[160 (62560/64000) (1177.9013sec)] loss: 11.55 acc: 99.16
Epoch[161 (62951/64000) (1185.1235sec)] loss: 11.17 acc: 99.22
Epoch[162 (63342/64000) (1192.7358sec)] loss: 11.11 acc: 99.22
Epoch[163 (63733/64000) (1199.8564sec)] loss: 11.27 acc: 99.15
Epoch[164 (64124/64000) (1207.2868sec)] loss: 10.85 acc: 99.20
Plot the training loss and compute the test performance.
import matplotlib
import matplotlib.pyplot as plt
plt.plot(losses)
[<matplotlib.lines.Line2D at 0x7fc6fc5d45b0>]
net.eval()
test_correct = 0.0
test_total = 0.0
with torch.no_grad():  # no gradients needed for evaluation
    for batch_idx, (x, y) in enumerate(cifar_loader['test']):
        x, y = x.to(device), y.to(device)
        out = net(x)
        test_correct += (out.argmax(1) == y).float().sum().item()
        test_total += x.size(0)
print(test_correct / test_total * 100.)
90.91
Train the plain CNN without residual connections.
start_time = time.time()
net_plain = cnn20().to(device)
optimizer = torch.optim.SGD(net_plain.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
decay_steps = [32000, 48000]
step_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decay_steps, gamma=0.1)
net_plain.train()
step = 0
epochs = 0
losses_plain = []
while step < 64000:
train_loss = 0.0
correct = 0.0
total = 0.0
for batch_idx, (x, y) in enumerate(cifar_loader['train']):
step += 1
        x, y = x.to(device), y.to(device)
        out = net_plain(x)
        loss = criterion(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step_lr_scheduler.step()  # step the scheduler after the optimizer, as above
correct += (out.argmax(1) == y).float().sum().item()
total += x.size(0)
train_loss += loss.item()
losses_plain.append(train_loss)
epochs += 1
print("Epoch[{:d} ({:d}/64000) ({:.4f}sec)] loss: {:.2f} acc: {:.2f}".format(epochs, step, time.time()-start_time, train_loss, 100.*correct/total))
Epoch[1 (391/64000) (7.2146sec)] loss: 689.96 acc: 32.64
Epoch[2 (782/64000) (14.0851sec)] loss: 534.03 acc: 49.99
Epoch[3 (1173/64000) (21.1075sec)] loss: 431.50 acc: 60.38
Epoch[4 (1564/64000) (27.8955sec)] loss: 365.12 acc: 67.19
Epoch[5 (1955/64000) (35.2117sec)] loss: 322.34 acc: 71.42
Epoch[6 (2346/64000) (42.2845sec)] loss: 294.47 acc: 73.99
Epoch[7 (2737/64000) (49.2290sec)] loss: 271.72 acc: 75.82
Epoch[8 (3128/64000) (56.3330sec)] loss: 254.65 acc: 77.43
Epoch[9 (3519/64000) (63.5271sec)] loss: 241.44 acc: 78.56
Epoch[10 (3910/64000) (70.4259sec)] loss: 230.29 acc: 79.66
Epoch[11 (4301/64000) (77.7810sec)] loss: 221.22 acc: 80.33
Epoch[12 (4692/64000) (84.4624sec)] loss: 213.80 acc: 80.99
Epoch[13 (5083/64000) (91.6520sec)] loss: 207.05 acc: 81.70
Epoch[14 (5474/64000) (98.5922sec)] loss: 202.37 acc: 82.03
Epoch[15 (5865/64000) (105.6009sec)] loss: 196.83 acc: 82.59
Epoch[16 (6256/64000) (112.5703sec)] loss: 190.15 acc: 83.11
Epoch[17 (6647/64000) (119.7448sec)] loss: 186.84 acc: 83.63
Epoch[18 (7038/64000) (126.8488sec)] loss: 183.72 acc: 83.64
Epoch[19 (7429/64000) (133.8145sec)] loss: 180.63 acc: 84.00
Epoch[20 (7820/64000) (141.3312sec)] loss: 173.23 acc: 84.69
Epoch[21 (8211/64000) (148.5181sec)] loss: 171.94 acc: 84.75
Epoch[22 (8602/64000) (155.6639sec)] loss: 172.15 acc: 84.83
Epoch[23 (8993/64000) (162.7647sec)] loss: 166.39 acc: 85.33
Epoch[24 (9384/64000) (169.8326sec)] loss: 164.23 acc: 85.60
Epoch[25 (9775/64000) (176.5753sec)] loss: 162.15 acc: 85.63
Epoch[26 (10166/64000) (184.2141sec)] loss: 159.41 acc: 85.83
Epoch[27 (10557/64000) (191.5325sec)] loss: 157.73 acc: 86.02
Epoch[28 (10948/64000) (198.5394sec)] loss: 154.05 acc: 86.34
Epoch[29 (11339/64000) (205.9462sec)] loss: 154.87 acc: 86.25
Epoch[30 (11730/64000) (213.1766sec)] loss: 151.16 acc: 86.58
Epoch[31 (12121/64000) (220.3204sec)] loss: 148.59 acc: 86.81
Epoch[32 (12512/64000) (227.7061sec)] loss: 149.22 acc: 86.74
Epoch[33 (12903/64000) (234.9168sec)] loss: 145.34 acc: 87.20
Epoch[34 (13294/64000) (242.4034sec)] loss: 145.30 acc: 87.09
Epoch[35 (13685/64000) (249.7116sec)] loss: 144.28 acc: 87.20
Epoch[36 (14076/64000) (256.6604sec)] loss: 145.93 acc: 87.09
Epoch[37 (14467/64000) (263.6526sec)] loss: 141.30 acc: 87.39
Epoch[38 (14858/64000) (270.8084sec)] loss: 141.42 acc: 87.42
Epoch[39 (15249/64000) (278.1218sec)] loss: 139.26 acc: 87.79
Epoch[40 (15640/64000) (285.5205sec)] loss: 138.97 acc: 87.62
Epoch[41 (16031/64000) (292.6338sec)] loss: 137.25 acc: 87.71
Epoch[42 (16422/64000) (299.8960sec)] loss: 137.89 acc: 87.68
Epoch[43 (16813/64000) (306.4615sec)] loss: 135.12 acc: 88.06
Epoch[44 (17204/64000) (313.7000sec)] loss: 134.36 acc: 88.04
Epoch[45 (17595/64000) (321.3406sec)] loss: 132.38 acc: 88.35
Epoch[46 (17986/64000) (328.3074sec)] loss: 131.92 acc: 88.36
Epoch[47 (18377/64000) (335.5251sec)] loss: 132.07 acc: 88.24
Epoch[48 (18768/64000) (342.7595sec)] loss: 130.68 acc: 88.37
Epoch[49 (19159/64000) (349.6263sec)] loss: 131.12 acc: 88.36
Epoch[50 (19550/64000) (357.0380sec)] loss: 129.32 acc: 88.59
Epoch[51 (19941/64000) (364.0892sec)] loss: 129.24 acc: 88.47
Epoch[52 (20332/64000) (371.3756sec)] loss: 125.61 acc: 88.80
Epoch[53 (20723/64000) (378.6191sec)] loss: 126.88 acc: 88.69
Epoch[54 (21114/64000) (386.2752sec)] loss: 127.91 acc: 88.60
Epoch[55 (21505/64000) (393.5364sec)] loss: 128.70 acc: 88.63
Epoch[56 (21896/64000) (401.0215sec)] loss: 125.44 acc: 88.89
Epoch[57 (22287/64000) (408.1538sec)] loss: 125.58 acc: 88.75
Epoch[58 (22678/64000) (415.5634sec)] loss: 125.72 acc: 88.87
Epoch[59 (23069/64000) (422.6407sec)] loss: 123.13 acc: 88.96
Epoch[60 (23460/64000) (430.0556sec)] loss: 122.40 acc: 89.13
Epoch[61 (23851/64000) (437.3062sec)] loss: 125.74 acc: 88.78
Epoch[62 (24242/64000) (444.5036sec)] loss: 122.89 acc: 89.15
Epoch[63 (24633/64000) (452.3249sec)] loss: 123.77 acc: 88.91
Epoch[64 (25024/64000) (459.5320sec)] loss: 122.94 acc: 89.25
Epoch[65 (25415/64000) (466.6736sec)] loss: 120.77 acc: 89.23
Epoch[66 (25806/64000) (473.8049sec)] loss: 120.41 acc: 89.33
Epoch[67 (26197/64000) (481.0442sec)] loss: 121.50 acc: 89.21
Epoch[68 (26588/64000) (488.1004sec)] loss: 119.03 acc: 89.32
Epoch[69 (26979/64000) (495.2542sec)] loss: 119.38 acc: 89.39
Epoch[70 (27370/64000) (502.5304sec)] loss: 118.89 acc: 89.41
Epoch[71 (27761/64000) (509.6415sec)] loss: 118.88 acc: 89.42
Epoch[72 (28152/64000) (517.0511sec)] loss: 119.19 acc: 89.33
Epoch[73 (28543/64000) (524.6494sec)] loss: 116.20 acc: 89.56
Epoch[74 (28934/64000) (532.1471sec)] loss: 117.11 acc: 89.50
Epoch[75 (29325/64000) (539.0524sec)] loss: 116.53 acc: 89.57
Epoch[76 (29716/64000) (546.3680sec)] loss: 116.04 acc: 89.69
Epoch[77 (30107/64000) (553.8518sec)] loss: 114.30 acc: 89.81
Epoch[78 (30498/64000) (560.7083sec)] loss: 117.23 acc: 89.61
Epoch[79 (30889/64000) (568.3386sec)] loss: 115.09 acc: 89.64
Epoch[80 (31280/64000) (575.4272sec)] loss: 115.39 acc: 89.71
Epoch[81 (31671/64000) (582.9671sec)] loss: 115.24 acc: 89.78
Epoch[82 (32062/64000) (590.4794sec)] loss: 111.38 acc: 90.12
Epoch[83 (32453/64000) (597.6994sec)] loss: 72.23 acc: 93.75
Epoch[84 (32844/64000) (604.8231sec)] loss: 62.88 acc: 94.50
Epoch[85 (33235/64000) (611.8527sec)] loss: 56.76 acc: 94.99
Epoch[86 (33626/64000) (619.0148sec)] loss: 54.20 acc: 95.25
Epoch[87 (34017/64000) (626.0946sec)] loss: 51.18 acc: 95.52
Epoch[88 (34408/64000) (633.2502sec)] loss: 48.66 acc: 95.59
Epoch[89 (34799/64000) (640.3661sec)] loss: 47.33 acc: 95.76
Epoch[90 (35190/64000) (647.5268sec)] loss: 45.99 acc: 96.00
Epoch[91 (35581/64000) (654.4824sec)] loss: 45.55 acc: 95.94
Epoch[92 (35972/64000) (661.5328sec)] loss: 43.50 acc: 96.10
Epoch[93 (36363/64000) (669.0501sec)] loss: 42.39 acc: 96.25
Epoch[94 (36754/64000) (675.8983sec)] loss: 41.10 acc: 96.30
Epoch[95 (37145/64000) (683.0879sec)] loss: 40.02 acc: 96.43
Epoch[96 (37536/64000) (689.9680sec)] loss: 39.87 acc: 96.43
Epoch[97 (37927/64000) (697.5684sec)] loss: 38.91 acc: 96.40
Epoch[98 (38318/64000) (705.0191sec)] loss: 37.55 acc: 96.64
Epoch[99 (38709/64000) (712.0232sec)] loss: 38.73 acc: 96.52
Epoch[100 (39100/64000) (719.0744sec)] loss: 37.01 acc: 96.65
Epoch[101 (39491/64000) (726.0455sec)] loss: 35.86 acc: 96.82
Epoch[102 (39882/64000) (733.0002sec)] loss: 34.39 acc: 96.92
Epoch[103 (40273/64000) (740.0890sec)] loss: 34.96 acc: 96.80
Epoch[104 (40664/64000) (747.1744sec)] loss: 34.97 acc: 96.85
Epoch[105 (41055/64000) (754.5613sec)] loss: 33.63 acc: 97.05
Epoch[106 (41446/64000) (761.6926sec)] loss: 33.05 acc: 96.99
Epoch[107 (41837/64000) (768.4621sec)] loss: 31.74 acc: 97.07
Epoch[108 (42228/64000) (775.7621sec)] loss: 31.21 acc: 97.12
Epoch[109 (42619/64000) (783.0051sec)] loss: 30.95 acc: 97.28
Epoch[110 (43010/64000) (790.3146sec)] loss: 30.28 acc: 97.30
Epoch[111 (43401/64000) (797.2794sec)] loss: 30.53 acc: 97.29
Epoch[112 (43792/64000) (804.7434sec)] loss: 28.15 acc: 97.52
Epoch[113 (44183/64000) (811.9753sec)] loss: 28.40 acc: 97.43
Epoch[114 (44574/64000) (819.5330sec)] loss: 29.56 acc: 97.38
Epoch[115 (44965/64000) (826.2002sec)] loss: 29.17 acc: 97.38
Epoch[116 (45356/64000) (833.0710sec)] loss: 27.47 acc: 97.51
Epoch[117 (45747/64000) (839.8714sec)] loss: 28.09 acc: 97.48
Epoch[118 (46138/64000) (847.1031sec)] loss: 27.38 acc: 97.54
Epoch[119 (46529/64000) (853.9757sec)] loss: 27.01 acc: 97.68
Epoch[120 (46920/64000) (861.0741sec)] loss: 27.05 acc: 97.49
Epoch[121 (47311/64000) (868.4948sec)] loss: 25.27 acc: 97.70
Epoch[122 (47702/64000) (875.9437sec)] loss: 26.23 acc: 97.68
Epoch[123 (48093/64000) (883.0364sec)] loss: 25.37 acc: 97.69
Epoch[124 (48484/64000) (890.1893sec)] loss: 21.02 acc: 98.17
Epoch[125 (48875/64000) (897.0977sec)] loss: 19.10 acc: 98.36
Epoch[126 (49266/64000) (904.3798sec)] loss: 19.64 acc: 98.38
Epoch[127 (49657/64000) (911.3010sec)] loss: 19.37 acc: 98.37
Epoch[128 (50048/64000) (918.2357sec)] loss: 18.63 acc: 98.46
Epoch[129 (50439/64000) (925.3542sec)] loss: 18.14 acc: 98.56
Epoch[130 (50830/64000) (932.3081sec)] loss: 18.32 acc: 98.48
Epoch[131 (51221/64000) (939.6098sec)] loss: 17.52 acc: 98.50
Epoch[132 (51612/64000) (946.9333sec)] loss: 18.32 acc: 98.46
Epoch[133 (52003/64000) (954.2119sec)] loss: 18.10 acc: 98.41
Epoch[134 (52394/64000) (961.7445sec)] loss: 17.30 acc: 98.60
Epoch[135 (52785/64000) (968.3853sec)] loss: 16.70 acc: 98.63
Epoch[136 (53176/64000) (975.5887sec)] loss: 17.46 acc: 98.57
Epoch[137 (53567/64000) (982.8676sec)] loss: 16.70 acc: 98.65
Epoch[138 (53958/64000) (989.9611sec)] loss: 16.97 acc: 98.61
Epoch[139 (54349/64000) (997.1548sec)] loss: 17.21 acc: 98.60
Epoch[140 (54740/64000) (1004.1162sec)] loss: 16.58 acc: 98.58
Epoch[141 (55131/64000) (1011.4397sec)] loss: 16.97 acc: 98.54
Epoch[142 (55522/64000) (1018.6986sec)] loss: 16.68 acc: 98.64
Epoch[143 (55913/64000) (1026.1204sec)] loss: 16.96 acc: 98.59
Epoch[144 (56304/64000) (1033.3482sec)] loss: 16.73 acc: 98.62
Epoch[145 (56695/64000) (1040.8932sec)] loss: 16.49 acc: 98.65
Epoch[146 (57086/64000) (1047.9436sec)] loss: 15.84 acc: 98.64
Epoch[147 (57477/64000) (1055.0756sec)] loss: 16.30 acc: 98.62
Epoch[148 (57868/64000) (1062.2017sec)] loss: 16.09 acc: 98.70
Epoch[149 (58259/64000) (1069.3741sec)] loss: 15.77 acc: 98.70
Epoch[150 (58650/64000) (1076.7737sec)] loss: 15.84 acc: 98.73
Epoch[151 (59041/64000) (1084.0027sec)] loss: 15.53 acc: 98.73
Epoch[152 (59432/64000) (1091.1483sec)] loss: 15.96 acc: 98.68
Epoch[153 (59823/64000) (1098.4126sec)] loss: 16.13 acc: 98.60
Epoch[154 (60214/64000) (1105.4872sec)] loss: 15.56 acc: 98.72
Epoch[155 (60605/64000) (1112.7199sec)] loss: 14.85 acc: 98.78
Epoch[156 (60996/64000) (1119.6450sec)] loss: 15.91 acc: 98.66
Epoch[157 (61387/64000) (1127.0011sec)] loss: 15.38 acc: 98.73
Epoch[158 (61778/64000) (1134.2234sec)] loss: 16.08 acc: 98.68
Epoch[159 (62169/64000) (1141.3570sec)] loss: 15.36 acc: 98.76
Epoch[160 (62560/64000) (1148.5884sec)] loss: 15.25 acc: 98.80
Epoch[161 (62951/64000) (1155.6970sec)] loss: 15.60 acc: 98.76
Epoch[162 (63342/64000) (1162.9399sec)] loss: 15.12 acc: 98.72
Epoch[163 (63733/64000) (1170.2920sec)] loss: 14.52 acc: 98.81
Epoch[164 (64124/64000) (1177.5772sec)] loss: 14.71 acc: 98.79
Plot the training losses of both models and compute the test performance of the plain CNN.
plt.plot(losses, label='resnet')
plt.plot(losses_plain, label='cnn')
plt.legend()
plt.show()
net_plain.eval()
test_correct = 0.0
test_total = 0.0
with torch.no_grad():  # no gradients needed for evaluation
    for batch_idx, (x, y) in enumerate(cifar_loader['test']):
        x, y = x.to(device), y.to(device)
        out = net_plain(x)
        test_correct += (out.argmax(1) == y).float().sum().item()
        test_total += x.size(0)
print(test_correct / test_total * 100.)
90.47
Reference
- AI504: Programming for AI Lecture at KAIST AI