MNIST

이규호·2021년 1월 31일
0

📒 MNIST data set


📝 What is MNIST?


  • 손으로 쓴 숫자의 dataset으로, 0 ~ 9 까지의 숫자이다.
  • 링크에서 다운로드 가능하다.
  • 28 x 28 해상도의 이미지이며, 1개의 채널을 갖는 gray image이다.
  • torchvision 에서 다양하고 유명한 dataset들을 사용할 수 있다.

📝 In PyTorch (학습)


import torchvision.datsets as dset

# train dataset
mnist_train = dsets.MNIST(root="MNIST_data/", train=True, transform=transforms.ToTensor(), download=True)
# test dataset
mnist_test = dsets.MNIST(root="MNIST_data/", train=False, transform=transforms.ToTensor(), download=True)
# minibatch를 위한 데이터로더
data_loader = torch.utils.DataLoader(DataLoader=mnist_train, batch_size=100, shuffle=True, drop_last=True)
# MNIST 데이터 이미지는 28 * 28 = 784, 0 ~ 9 = 10개
linear = torch.nn.Linear(784, 10, bias=True)
traning_epochs = 15
batch_size = 100
# Cost function, optimizer
criterion = torch.nn.CrossEntrophyLoss()
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = len(data_loader)
    # (image, label)
    for X, Y in data_loader:
        X = X.view(-1, 28 * 28)
        optimizer.zero_grad()
        hypothesis = linear(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()
        avg_cost += cost / total_batch
    print(f'cost = {avg_cost}')

📝 In PyTorch (테스트)


With torch.no_grad():
    X_test = mnist_test.test_data.view(-1, 28 * 28).float()
    Y_test = mnist_test.test_labels
    
    prediction = linear(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_predicition.float().mean()
    print(f"Accuracy: {accuracy.item()}")
    
profile
Beginner

0개의 댓글