[PyTorch] Lab07.2 - MNIST Introduction

Yun Geonil·2021년 2월 20일
0

📌 학습 목표


  • MNIST
  • Code : MNIST Classifier

MNIST

  • MNIST는 손으로 쓴 숫자 글씨 데이터이다.

    아래 그림과 같이 0 ~ 9의 숫자를 손으로 쓴 이미지 데이터가 포함되어있다. MNIST는 우체국에서 손으로 쓴 글씨를 분류하기 위해 만든 데이터셋이다.

  • MNIST Data example

- 28 x 28 image
- 1 chanel gray scale image
- 0~9 digits
  • torchvision

    torchvision library를 이용한다.

    torchvision 에는 datasets, models, transforms등 여러 패키지가 존재한다.

Code : MNIST Classifier

torchvision library를 이용해 손글씨 분류기를 만들어본다.

  • import and settings

    여기서 device는 연산을 수행할 공간을 정의하는 부분이다. GPU가 있다면 설정해주도록 하자.

    torchvision.datasetstorchvision.transforms를 import 해준다. datasets에는 MNIST데이터가 들어있고 transforms에는 이미지 데이터를 torch에 맞게 변환해주는 메서드들이 포함되어있다.

import torch
import torch.utils as utils
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dsets
import torchvision.transforms as transforms

device = 'cpu'
torch.manual_seed(1)

# hyper parameter
training_epochs = 15
batch_size = 100
learning_rate = 0.1
  • MNIST dataset

    root는 설치 경로, train은 train data인지, transform은 변환, download는 다운로드 할것인지를 설정하는 parameter이다.

mnist_train = dsets.MNIST(root='MNIST_data/', 
                          train=True, 
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False, 
                         transform=transforms.ToTensor(),
                         download=True)
  • dataset loader

    torch.utilsdata.DataLoader를 이용해 데이터셋을 로드한다. dataset, batch_size, shuffle, drop_last 등의 parameter를 설정해준다.

    여기서 shuffle은 섞을 것인지를 설정하는 parameter이고, drop_last는 마지막에 batch_size를 충족시키지 못하는 데이터들을 버릴것인지를 설정하는 parameter이다.

data_loader = utils.data.DataLoader(dataset=mnist_train,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   drop_last=True)
  • model architect

    nn.Module을 상속받는 softmax분류기 모델을 만든다. 여기서 input은 28 by 28의 이미지 데이터이므로 이를 1차원으로 변환해 입력하기 때문에 28*28로 설정하고, output은 0~9까지의 10개 이므로 10으로 설정한다.

class SoftmaxClassifierModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(28*28, 10, bias=True).to(device)
    
    def forward(self, x):
        return self.linear(x)

model = SoftmaxClassifierModel()
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
  • training and test

    여기서 X data는 28 by 28의 2차원 vector인데 이를 1차원 vector로 만들기 위해 view() 를 사용하여 변환해준다.

 for epoch in range(training_epochs):
    avg_cost = 0
    total_batch = len(data_loader)
    
    for X, Y in data_loader:
        #reshape
        X = X.view(-1, 28*28).to(device)
        Y = Y.to(device)
        
        optimizer.zero_grad()
        
        hypothesis = model(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()
        
        avg_cost += cost / total_batch
        
    print('Epoch: {:4d}/{} Cost: {:.9f}'.format(
        epoch+1, training_epochs, avg_cost
    ))
print('Learning finished')
'''
Epoch:    1/15 Cost: 0.532311916
Epoch:    2/15 Cost: 0.358782530
Epoch:    3/15 Cost: 0.330935270
Epoch:    4/15 Cost: 0.316523969
Epoch:    5/15 Cost: 0.306851596
Epoch:    6/15 Cost: 0.300046444
Epoch:    7/15 Cost: 0.295035958
Epoch:    8/15 Cost: 0.290838420
Epoch:    9/15 Cost: 0.287174433
Epoch:   10/15 Cost: 0.284407109
Epoch:   11/15 Cost: 0.282002598
Epoch:   12/15 Cost: 0.279584587
Epoch:   13/15 Cost: 0.277803093
Epoch:   14/15 Cost: 0.275842309
Epoch:   15/15 Cost: 0.274407804
Learning finished
'''

with torch.no_grad()란 gradient 계산을 하지 않겠다는 선언과 같다. 이 선언 아래에서 test데이터를 통해 accuracy를 계산한다.

훌륭한 성능은 아니지만 약 88%의 정확도를 가진 모델을 구축했다!

# Test model using test data
with torch.no_grad():
    X_test = mnist_test.data.view(-1, 28*28).float().to(device)
    Y_test = mnist_test.targets.to(device)
    
    prediction = model(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy', accuracy.item())
'''
Accuracy 0.8841999769210815
'''

0개의 댓글