[PYTORCH] 손글씨 숫자 이미지 분류

jarden·2023년 3월 8일
0

pytorch

목록 보기
1/1

Part 3. Pytorch

3.2 예제 : 손글씨 숫자 이미지 분류 문제

3.2.2 CNN으로 손글씨 숫자 이미지 분류하기

- 모듈 및 분석 환경 설정

[모듈 불러오기]
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim # 가중치 추정에 필요한 최적화 알고리즘을 포함
from torchvision import datasets, transforms

from matplotlib import pyplot as plt
[분석 환경 설정]
#### -- 1-2. 분석 환경 설정 -- ####
is_cuda = torch.cuda.is_available()
device = torch.device('cuda' if is_cuda else 'cpu')

print ('Current cuda device is', device)

가중치 업데이트 연산 과정에서 어떠한 장비를 선택할지에 대한 코드이다. CUDA를 통해 GPU를 사용할 수 있다면 torch.cuda.is.available()에 True값이, 사용할 수 없다면 False 값이 저장된다. 이에 따라 device에 'cuda' 또는 'cpu'로 설정된다.

- 데이터 불러오기

[MNIST 데이터 불러오기]
#### -- 2-1. MNIST 데이터 불러오기 -- ####
train_data = datasets.MNIST(root = './data', train = True, download = True, transform = transforms.ToTensor())
test_data = datasets.MNIST(root = './data', train = False, transform = transforms.ToTensor())

print('number of training data: ', len(train_data))
print('number of test data: ', len(test_data))
[MNIST 데이터 확인하기]
#### -- 2-2. MNIST 데이터 확인하기 -- ####
image, label = train_data[0]

plt.imshow(image.squeeze().numpy(), cmap = 'gray')
plt.title('label : %s' % label)
plt.show()

MNIST 데이터는 단일 채널로 [1, 28, 28] 3차원 텐서이다. 3차원 텐서를 2차원으로 줄이기 위해 image.squeeze()를 사용한다. squeeze() 함수는 크기가 1인 차원을 없애는 함수로 2차원인 [28, 28]로 만들어준다.

[미니 배치 구성하기]
#### -- 2-3. Mini-Batch 구성하기 -- ####
train_loader = torch.utils.data.DataLoader(dataset = train_data,
                                           batch_size = batch_size, shuffle = True)
test_loader  = torch.utils.data.DataLoader(dataset = test_data,
                                           batch_size = batch_size, shuffle = True)

first_batch = train_loader.__iter__().__next__()
print('{:15s} | {:<25s} | {}'.format('name', 'type', 'size'))
print('{:15s} | {:<25s} | {}'.format('Num of Batch', '', len(train_loader)))
print('{:15s} | {:<25s} | {}'.format('first_batch', str(type(first_batch)), len(first_batch)))
print('{:15s} | {:<25s} | {}'.format('first_batch[0]', str(type(first_batch[0])), first_batch[0].shape))
print('{:15s} | {:<25s} | {}'.format('first_batch[1]', str(type(first_batch[1])), first_batch[1].shape))
[CNN 구조 설계하기]
#### -- 3-1. CNN 구조 설계하기 -- ####
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
[Optimizer 및 손실 함수 정의]
#### -- 3-2. Optimizer 및 손실함수 정의 -- ####
model = CNN().to(device)
optimizer = optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss( )

MNIST는 다중 클래스 분류 문제이기에 교차 엔트로피를 손실함수로 사용한다.

[설계한 CNN 모형 확인하기]
#### -- 3-3. 설계한 CNN 모형 확인하기 -- ####
print(model)
[모델 학습]
#### -- 3-4. 모델 학습하기 -- ###
model.train()
i = 1
for epoch in range(epoch_num):
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if i % 1000 == 0:
            print('Train Step: {}\tLoss: {:.3f}'.format(i, loss.item()))
        i += 1

- 모델 평가

#### -- 4. 모델 평가하기 -- ###
model.eval()
correct = 0
for data, target in test_loader:
#     data, target = Variable(data, volatile=True), Variable(target)
    data = data.to(device)
    target = target.to(device)
    output = model(data)
    prediction = output.data.max(1)[1]
    correct += prediction.eq(target.data).sum()

print('Test set: Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
profile
늦더라도 한 발짝씩 나아가는 컴공생

0개의 댓글