📒 MNIST data set
📝 What is MNIST?
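- A dataset of handwritten digits, covering the digits 0 through 9.
- It can be downloaded from the link.
- Each image is 28 x 28 pixels and grayscale, with a single channel.
- torchvision provides ready-made access to many well-known datasets, including MNIST.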
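To make the image shape concrete, here is a minimal sketch that uses a random tensor as a stand-in for a batch of MNIST images. It does not download the real data, and the names `fake_batch` and `flat` are illustrative assumptions. It shows how single-channel 28 x 28 images flatten to 784 values each, which is the input size a linear model would take.

```python
import torch

# Stand-in for a batch of 100 MNIST images: (batch, channel, height, width).
# Random values are used only to illustrate shapes; these are not real digits.
fake_batch = torch.rand(100, 1, 28, 28)

# Flatten each image into a vector of 28 * 28 = 784 pixel values.
flat = fake_batch.view(-1, 28 * 28)

print(fake_batch.shape)  # torch.Size([100, 1, 28, 28])
print(flat.shape)        # torch.Size([100, 784])
```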
📝 In PyTorch (Training)
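import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms

# Download MNIST if needed; ToTensor() converts each image to a float tensor in [0, 1]
mnist_train = dsets.MNIST(root="MNIST_data/", train=True, transform=transforms.ToTensor(), download=True)
mnist_test = dsets.MNIST(root="MNIST_data/", train=False, transform=transforms.ToTensor(), download=True)

training_epochs = 15
batch_size = 100

# Shuffle every epoch and drop the final batch if it is smaller than batch_size
data_loader = torch.utils.data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

# Single linear layer: 784 input pixels -> 10 class scores
linear = torch.nn.Linear(784, 10, bias=True)

# CrossEntropyLoss applies log-softmax internally, so the model outputs raw scores
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = len(data_loader)
    for X, Y in data_loader:
        # Flatten (batch, 1, 28, 28) images to (batch, 784)
        X = X.view(-1, 28 * 28)
        optimizer.zero_grad()
        hypothesis = linear(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()
        # .item() keeps avg_cost a plain float, detached from the autograd graph
        avg_cost += cost.item() / total_batch
    print(f'cost = {avg_cost}')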
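The training code feeds the raw output of the linear layer straight into the loss, with no explicit softmax. The sketch below, on a made-up mini-batch (the `logits` and `targets` values are assumptions for illustration), checks that `torch.nn.CrossEntropyLoss` with its default mean reduction matches a manual log-softmax followed by the mean negative log-likelihood of the true class.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical mini-batch: 4 samples, 10 classes (digits 0-9).
logits = torch.randn(4, 10)           # raw scores, like the output of a Linear layer
targets = torch.tensor([3, 0, 9, 1])  # integer class labels

# Built-in loss: takes raw logits and integer labels.
builtin_loss = torch.nn.CrossEntropyLoss()(logits, targets)

# Manual equivalent: log-softmax, then the mean negative log-probability
# assigned to the true class of each sample.
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(4), targets].mean()

print(torch.allclose(builtin_loss, manual_loss))  # True
```

Because the softmax is folded into the loss, adding a softmax layer to the model before this criterion would apply it twice.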
📝 In PyTorch (Test)
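# Disable gradient tracking; no parameters are updated during evaluation
with torch.no_grad():
    # .data holds raw uint8 pixels (0-255); scale to [0, 1] to match ToTensor() used in training
    X_test = mnist_test.data.view(-1, 28 * 28).float() / 255.0
    # .data and .targets replace the deprecated test_data and test_labels attributes
    Y_test = mnist_test.targets
    prediction = linear(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print(f"Accuracy: {accuracy.item()}")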
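The accuracy calculation above can be traced by hand on a tiny example. This is a minimal sketch with made-up scores for 4 samples and 3 classes (the tensors `prediction` and `labels` here are illustrative assumptions, not MNIST data): `argmax` picks the highest-scoring class per row, the comparison yields booleans, and their mean as floats is the fraction of correct predictions.

```python
import torch

# Hypothetical class scores for 4 samples and 3 classes.
prediction = torch.tensor([[0.1, 2.0, 0.3],
                           [1.5, 0.2, 0.1],
                           [0.0, 0.1, 3.0],
                           [2.0, 0.1, 0.5]])
labels = torch.tensor([1, 0, 2, 1])

# Index of the largest score in each row is the predicted class.
predicted_class = torch.argmax(prediction, dim=1)  # tensor([1, 0, 2, 0])

# Element-wise comparison: True where the prediction matches the label.
correct = predicted_class == labels  # tensor([True, True, True, False])

# Booleans become 1.0 / 0.0, so the mean is the accuracy.
accuracy = correct.float().mean()

print(accuracy.item())  # 0.75
```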