[Paper Review] An Architecture Combining Convolutional Neural Network (CNN) and Support Vector Machine (SVM) for Image Classification

์˜์„œ์ฟ  · November 1, 2021

The paper I review/translate/implement today was written by Abien Fred M. Agarap, who says the work was inspired by Yichuan Tang's "Deep Learning using Linear Support Vector Machines". Links to that paper and to this one are attached in the paper sources below.

(Reference) Paper sources

  • Deep Learning using Linear Support Vector Machines (link)
  • An Architecture Combining Convolutional Neural Network (CNN) and Support Vector Machine (SVM) for Image Classification (link)

Abstract

  • A CNN (convolutional neural network) consists of hidden layers and learnable parameters; each neuron takes an input, computes a dot product with its weights, and applies a nonlinearity. The network acts as the mediator between a raw image and its class scores. (A softmax function is typically used at the final stage of a CNN.)

  • However, several papers have raised issues with this approach:

    • Abien Fred Agarap. 2017. A Neural Network Architecture Combining Gated Recurrent Unit (GRU) and Support Vector Machine (SVM) for Intrusion Detection in Network Traffic Data. arXiv preprint arXiv:1709.03082 (2017).
    • Abdulrahman Alalshekmubarak and Leslie S Smith. 2013. A novel approach combining recurrent neural network and support vector machines for time series classification. In Innovations in Information Technology (IIT), 2013 9th International Conference on. IEEE, 42–47.
    • Yichuan Tang. 2013. Deep learning using linear support vector machines. arXiv preprint arXiv:1306.0239 (2013).
  • ์œ„์—์„œ ๋ณด์—ฌ์ค€ ๋…ผ๋ฌธ๋“ค์€ ๊ณตํ†ต์ ์œผ๋กœ linear SVM์„ ์ด์šฉํ•˜๋Š” ๊ฒƒ์„ ์ œ์•ˆํ•œ๋‹ค. ์ด์— ์ €์ž๋Š” CNN๋‹จ์— Softmax ๋Œ€์‹  SVM์„ ์ด์šฉํ•˜์—ฌ ๋ถ„์„์„ ์ˆ˜ํ–‰ํ•œ๋‹ค.

  • MNIST

    • CNN-SVM : 99.04%
    • CNN-Softmax : 99.23%
  • Fashion-MNIST

    • CNN-SVM : 90.72%
    • CNN-Softmax : 91.86%
  • ์ €์ž๋Š” ์„ฑ๋Šฅ์€ ๋น„๋ก ์กฐ๊ธˆ ๋‚ฎ์„ ์ˆ˜ ์žˆ์„์ง€๋ผ๋„, ์ข€ ๋” ๊ณ ๋„ํ™”๋œ CNN์„ ์ด์šฉํ•˜๋ฉด ์„ฑ๋Šฅ์„ ๋”์šฑ ๋” ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ์„ ๊ฒƒ์ด๋ผ๊ณ  ์ฃผ์žฅํ•œ๋‹ค.

💡 Why I chose this paper
This paper does not set a state-of-the-art (SOTA) result, but it later became a foundation for using an SVM classifier at the final stage in various vision tasks, which is why I chose it. In my own recent work I have been wondering whether a simple change to a model can improve its performance, so I looked this paper up and wrote this summary.


1. Introduction

  • ์œ„์— Abstract์—์„œ ๊ฐ„๋‹จํžˆ ์†Œ๊ฐœํ–ˆ๋“ฏ์ด NeuralNet(์ธ๊ณต์‹ ๊ฒฝ๋ง)์— softmax์ด์™ธ์— ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•๋ก (Ex. SVM)์„ ์ ์šฉํ•˜๋Š” ์—ฐ๊ตฌ๋“ค์ด ์ง„ํ–‰๋˜์–ด ์™”๋‹ค.
    • Abien Fred Agarap. 2017. A Neural Network Architecture Combining Gated Recurrent Unit (GRU) and Support Vector Machine (SVM) for Intrusion Detection in Network Traffic Data. arXiv preprint arXiv:1709.03082 (2017).
    • Abdulrahman Alalshekmubarak and Leslie S Smith. 2013. A novel approach combining recurrent neural network and support vector machines for time series classification. In Innovations in Information Technology (IIT), 2013 9th International Conference on. IEEE, 42–47.
    • Yichuan Tang. 2013. Deep learning using linear support vector machines. arXiv preprint arXiv:1306.0239 (2013).
  • These studies report that applying an SVM to an ANN works better than applying softmax. (This is limited to binary classification; the multinomial case adopts a one-versus-all scheme.)
  • ํ•ด๋‹น ๋…ผ๋ฌธ์—์„œ๋Š” 2013๋…„์— ๋‚˜์˜จ "Deep learning using linear support vector machines" ๋…ผ๋ฌธ์—์„œ CNN๋ชจ๋ธ์„ ์ข€ ๋” ์‰ฝ๊ณ  ๊ฐ„ํŽธํ•œ 2-Conv Layer with Max Pooling๋ชจ๋ธ์„ ์‚ฌ์šฉํ•œ๋‹ค.

2. Methodology

2.1 Machine Intelligence Library

  • ํ•ด๋‹น ๋…ผ๋ฌธ์€ Google์˜ Tensorflow์„ ์ด์šฉํ•˜์—ฌ ์—ฐ๊ตฌ๋ฅผ ์ง„ํ–‰ํ•˜์˜€๋‹ค.
  • ์ด๋ฒˆ ๋…ผ๋ฌธ ๊ตฌํ˜„์— ์žˆ์–ด์„œ๋Š” ์ตœ๊ทผ ๊ฐ€์žฅ ๋งŽ์ด ์‚ฌ์šฉ๋˜๋Š” PyTorch๋ฅผ ์ด์šฉํ•˜์—ฌ ๋…ผ๋ฌธ๊ตฌํ˜„์„ ์ˆ˜ํ–‰ํ•ด๋ณด์•˜๋‹ค.
# Load libraries
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.init
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Use the GPU if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fix the random seed
torch.manual_seed(123)

# Also fix the CUDA seed when a GPU is available
if device == 'cuda':
    torch.cuda.manual_seed_all(123)

# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])

2.2 The Dataset

  • MNIST : 10-class classification problem having 60,000 training examples, and 10,000 test cases – all in grayscale

  • Fashion-MNIST : the same number of classes, and the same color profile as MNIST


Table 1: Dataset distribution for both MNIST and Fashion-MNIST

Import Fashion-MNIST

# Download and load the training data
fashion_trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
fashion_trainloader = torch.utils.data.DataLoader(fashion_trainset, batch_size=128, shuffle=True)

# Download and load the test data
fashion_testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
fashion_testloader = torch.utils.data.DataLoader(fashion_testset, batch_size=128, shuffle=True)

Import MNIST

# Download and load the training data
mnist_trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
mnist_trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=128, shuffle=True)

# Download and load the test data
mnist_testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
mnist_testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=128, shuffle=True)
  • ๋ณ„๋„์˜ ์ „์ฒ˜๋ฆฌ๋Š” ์ˆ˜ํ–‰ํ•˜์ง€ ์•Š๋Š”๋‹ค. (No normalization or dimensionality reduction)

2.3 Support Vector Machine (SVM)

  • The Support Vector Machine (SVM) is a binary classification method developed by C. Cortes and V. Vapnik. Its aim is to find the optimal hyperplane f(w, x) = w·x + b that separates the two classes.
  • The SVM learns the parameter w by optimizing the following objectives.
    • L1-SVM

      $\min \|w\|_1 + C\sum_{i=1}^{p}\max\left(0,\ 1 - y_i'(w^{T}x_i + b)\right)$

    • $\|w\|_1$ is the Manhattan norm (L1 norm), $C$ is the penalty parameter, $y'$ is the true label, and $w^{T}x + b$ is the predicted value.
    • L2-SVM

      $\min \|w\|_2^2 + C\sum_{i=1}^{p}\max\left(0,\ 1 - y_i'(w^{T}x_i + b)\right)^2$

    • $\|w\|_2$ is the Euclidean norm (L2 norm), $C$ is the penalty parameter, $y'$ is the true label, and $w^{T}x + b$ is the predicted value.
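
To make the notation concrete, here is a small numeric sketch (my own addition; w, b, X, and y are made-up toy values) that evaluates the L2-SVM objective above:

import numpy as np

# Toy evaluation of the L2-SVM objective (my own sketch):
#   ||w||_2^2 + C * sum_i max(0, 1 - y_i * (w·x_i + b))^2
w = np.array([1.0, -1.0])
b = 0.0
C = 1.0
X = np.array([[2.0, -1.0],   # classified with margin > 1 -> no hinge term
              [0.5,  0.5]])  # lands exactly on the decision boundary -> hinge term of 1
y = np.array([1, -1])

margins = 1 - y * (X @ w + b)
objective = w @ w + C * np.sum(np.maximum(0, margins) ** 2)
print(objective)  # 3.0 = regularizer (2.0) + squared hinge losses (1.0)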
import numpy as np

class SVM:
    # set learning rate, lambda, and number of iterations
    def __init__(self, learning_rate=0.001, lambda_param=0.01, n_iters=1000):
        self.lr = learning_rate
        self.lambda_param = lambda_param
        self.n_iters = n_iters
        self.w = None
        self.b = None

    # fit the SVM with sub-gradient descent on the hinge loss
    def fit(self, X, y):
        n_samples, n_features = X.shape

        # map labels to {-1, +1}
        y_ = np.where(y <= 0, -1, 1)

        self.w = np.zeros(n_features)
        self.b = 0

        for _ in range(self.n_iters):
            for idx, x_i in enumerate(X):
                # sample outside the margin: only the regularizer contributes
                condition = y_[idx] * (np.dot(x_i, self.w) - self.b) >= 1
                if condition:
                    self.w -= self.lr * (2 * self.lambda_param * self.w)
                else:
                    self.w -= self.lr * (2 * self.lambda_param * self.w - np.dot(x_i, y_[idx]))
                    self.b -= self.lr * y_[idx]

    # predict the class sign for new samples
    def predict(self, X):
        approx = np.dot(X, self.w) - self.b
        return np.sign(approx)
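
A toy usage of the class above (my own example with made-up, linearly separable points; the CNN experiments below plug a hinge loss into the network instead of using this standalone class):

# Linearly separable toy data (illustrative only)
X = np.array([[2.0, 3.0], [1.0, 1.5], [-1.0, -2.0], [-2.0, -1.0]])
y = np.array([1, 1, -1, -1])   # labels in {-1, +1}

clf = SVM(learning_rate=0.001, lambda_param=0.01, n_iters=1000)
clf.fit(X, y)
print(clf.predict(np.array([[1.5, 2.0], [-1.5, -1.5]])))   # expect [ 1. -1.]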

2.4 Convolutional Neural Network (CNN)

  • A Convolutional Neural Network (CNN) is a deep feed-forward artificial neural network widely used in computer vision; besides MLP layers, it uses convolutional layers, pooling, and nonlinear activation functions such as tanh, sigmoid, and ReLU.
  • This study uses the following basic CNN model:
    • 5x5x1 filters
    • 2x2 max pooling
    • ReLU as the activation function (threshold = 0)

    • At the 10th layer, an L2-SVM is used instead of softmax (y ∈ {-1, +1}, trained with the Adam optimizer).


์ €์ž๊ฐ€ ์ด์šฉํ•œ ๋ชจ๋ธ ๊ตฌ์กฐ(์ง์ ‘ ์ œ์ž‘)


์ €์ž๊ฐ€ ์ด์šฉํ•œ ๋ชจ๋ธ ๊ตฌ์กฐ(๋…ผ๋ฌธ ์ˆ˜๋ก)


์ €์ž๊ฐ€ ์ด์šฉํ•œ ๋ชจ๋ธ ๊ตฌ์กฐ(์ง์ ‘ ๊ตฌํ˜„)

CNN model

class CNN(torch.nn.Module):

    def __init__(self):
        super(CNN, self).__init__()
        self.drop_prob = 0.5

        # layer1: 1x28x28 -> conv5 -> 32x24x24 -> maxpool(2, stride=1) -> 32x23x23
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, stride=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=1))

        # layer2: 32x23x23 -> conv5 -> 64x19x19 -> maxpool(2, stride=1) -> 64x18x18
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=5, stride=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=1))

        # fully connected layer (18*18*64 -> 1024) with dropout
        self.fc1 = torch.nn.Linear(18 * 18 * 64, 1024, bias=True)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.layer3 = torch.nn.Sequential(
            self.fc1,
            torch.nn.Dropout(p=self.drop_prob))

        # fully connected layer (1024 -> 10 classes)
        self.fc2 = torch.nn.Linear(1024, 10, bias=True)
        torch.nn.init.xavier_uniform_(self.fc2.weight)

    # define feed-forward
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)   # flatten for the FC layers
        out = self.layer3(out)
        out = self.fc2(out)
        return out
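
A quick shape check (my own addition) confirms the 18 x 18 x 64 flatten size implied by the two stride-1 poolings:

# Trace tensor shapes through the network with a dummy 28x28 input:
# conv5 -> 24x24, pool(2, stride=1) -> 23x23, conv5 -> 19x19, pool -> 18x18
model = CNN()
dummy = torch.zeros(1, 1, 28, 28)
features = model.layer2(model.layer1(dummy))
print(features.shape)       # torch.Size([1, 64, 18, 18])
print(model(dummy).shape)   # torch.Size([1, 10])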

CNN + SVM model (Multi-class Hinge Loss)

class multiClassHingeLoss(nn.Module):
    def __init__(self, p=1, margin=1, weight=None, size_average=True):
        super(multiClassHingeLoss, self).__init__()
        self.p = p
        self.margin = margin
        self.weight = weight
        self.size_average = size_average

    # define feed-forward
    def forward(self, output, y):
        # indices 0..batch_size-1 on the same device as the logits
        idx = torch.arange(0, y.size(0), device=output.device).long()

        # score of the correct class for each sample, shape (batch, 1)
        output_y = output[idx, y].view(-1, 1)

        # margin + output(i) - output(y)
        loss = output - output_y + self.margin

        # remove i=y items
        loss[idx, y] = 0

        # apply max function
        loss[loss < 0] = 0

        # apply power p function
        if self.p != 1:
            loss = torch.pow(loss, self.p)

        # apply per-class weights
        if self.weight is not None:
            loss = loss * self.weight

        # sum up
        loss = torch.sum(loss)

        if self.size_average:
            loss /= output.size(0)

        return loss
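
For reference, PyTorch also ships a built-in multi-class hinge loss, torch.nn.MultiMarginLoss. It averages over classes where the class above sums over them, so with p=1, margin=1, weight=None, and size_average=True the two should differ by a factor equal to the number of classes; a quick check of that claim:

# Compare the custom loss against torch.nn.MultiMarginLoss on random logits
logits = torch.randn(4, 10)
targets = torch.tensor([0, 3, 7, 9])

custom = multiClassHingeLoss()(logits, targets)
builtin = torch.nn.MultiMarginLoss(p=1, margin=1.0)(logits, targets)
print(custom, builtin * 10)   # should agree (10 = number of classes)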

💡 Wait!! What is the hinge loss?

  • A family of loss functions devised to find the decision boundary that separates the categories of the training data while staying as far away from the data as possible, thereby maximizing the margin between the data and the boundary.
  • In a binary classification problem, the hinge loss between the model's prediction y′ (a scalar) and the true label y (−1 or 1) is defined as:
    $loss = \max(0,\ 1 - y' \times y)$
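
A one-line sketch of this formula (my own illustration):

import torch

def binary_hinge_loss(y_pred, y_true):
    # y_true in {-1, +1}; y_pred is the raw (scalar) score from the model
    return torch.clamp(1 - y_true * y_pred, min=0).mean()

print(binary_hinge_loss(torch.tensor([2.5]),  torch.tensor([1.0])))   # tensor(0.)     - correct and confident
print(binary_hinge_loss(torch.tensor([-0.5]), torch.tensor([1.0])))   # tensor(1.5000) - wrong side, penalized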

2.5 Data Analysis

  • 2๊ฐœ์˜ phase(train/test)
  • 2๊ฐœ์˜ dataset(MNIST, fashion-MNIST)

3. Experiments

  • The figure below shows the hyperparameter settings used for each dataset.


Table 2: Hyper-parameters used for the CNN-Softmax and CNN-SVM models.

Set hyper-parameter

learning_rate = 0.001
training_epochs = 50
# training_epochs = 10000
# The paper trained for 10,000 epochs; due to limited computing power I run 50
batch_size = 128

Make Model for MNIST Data (CNN)

# MNIST CNN ๋ชจ๋ธ ์ •์˜
mnist_model = CNN().to(device)

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(mnist_model.parameters(), lr=learning_rate)

total_batch = len(mnist_trainloader)
print('์ด ๋ฐฐ์น˜์˜ ์ˆ˜ : {}'.format(total_batch))

Make Model for MNIST Data (CNN + SVM)

# Define the CNN+SVM model for MNIST
mnist_SVM_model = CNN().to(device)

mnist_svm_criterion = multiClassHingeLoss().to(device)
mnist_svm_optimizer = torch.optim.Adam(mnist_SVM_model.parameters(), lr=learning_rate)

total_batch = len(mnist_trainloader)
print('Total number of batches: {}'.format(total_batch))

Make Model for fashion-MNIST Data (CNN)

# fashion-MNIST CNN ๋ชจ๋ธ ์ •์˜
fashion_model = CNN().to(device)

criterion = torch.nn.CrossEntropyLoss().to(device)    # ๋น„์šฉ ํ•จ์ˆ˜์— ์†Œํ”„ํŠธ๋งฅ์Šค ํ•จ์ˆ˜ ํฌํ•จ๋˜์–ด์ ธ ์žˆ์Œ.
optimizer = torch.optim.Adam(fashion_model.parameters(), lr=learning_rate)

total_batch = len(fashion_trainloader)
print('์ด ๋ฐฐ์น˜์˜ ์ˆ˜ : {}'.format(total_batch))

Make Model for fashion-MNIST Data (CNN + SVM)

# Define the CNN+SVM model for Fashion-MNIST
fashion_SVM_model = CNN().to(device)

fashion_svm_criterion = multiClassHingeLoss().to(device)
fashion_svm_optimizer = torch.optim.Adam(fashion_SVM_model.parameters(), lr=learning_rate)

total_batch = len(fashion_trainloader)
print('Total number of batches: {}'.format(total_batch))

Train Models

# mnist_model (CNN)
for epoch in range(training_epochs):
    avg_cost = 0

    for X, Y in mnist_trainloader:
        X = X.to(device)
        Y = Y.to(device)

        mnist_optimizer.zero_grad()
        hypothesis = mnist_model(X)
        cost = mnist_criterion(hypothesis, Y)
        cost.backward()
        mnist_optimizer.step()

        avg_cost += cost.item() / total_batch   # .item() detaches the scalar so graphs aren't kept around

    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

# mnist_SVM_model (CNN + SVM)
for epoch in range(training_epochs):
    avg_cost = 0

    for X, Y in mnist_trainloader:
        X = X.to(device)
        Y = Y.to(device)

        mnist_svm_optimizer.zero_grad()
        hypothesis = mnist_SVM_model(X)
        cost = mnist_svm_criterion(hypothesis, Y)
        cost.backward()
        mnist_svm_optimizer.step()

        avg_cost += cost.item() / total_batch

    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

# fashion_model (CNN)
for epoch in range(training_epochs):
    avg_cost = 0

    for X, Y in fashion_trainloader:
        X = X.to(device)
        Y = Y.to(device)

        fashion_optimizer.zero_grad()
        hypothesis = fashion_model(X)
        cost = fashion_criterion(hypothesis, Y)
        cost.backward()
        fashion_optimizer.step()

        avg_cost += cost.item() / total_batch

    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

# fashion_SVM_model (CNN + SVM)
for epoch in range(training_epochs):
    avg_cost = 0

    for X, Y in fashion_trainloader:
        X = X.to(device)
        Y = Y.to(device)

        fashion_svm_optimizer.zero_grad()
        hypothesis = fashion_SVM_model(X)
        cost = fashion_svm_criterion(hypothesis, Y)
        cost.backward()
        fashion_svm_optimizer.step()

        avg_cost += cost.item() / total_batch

    print('[Epoch: {:>4}] cost = {:>.9}'.format(epoch + 1, avg_cost))

Test Models

# mnist_model (CNN)
mnist_model.eval()   # switch off dropout for evaluation
with torch.no_grad():
    correct = 0
    total = 0
    for X_test, Y_test in mnist_testloader:
        X_test = X_test.to(device)
        Y_test = Y_test.to(device)
        prediction = mnist_model(X_test)
        predicted = torch.argmax(prediction, 1)
        total += Y_test.size(0)
        correct += (predicted == Y_test).sum().item()

print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# mnist_SVM_model (CNN + SVM)
mnist_SVM_model.eval()   # switch off dropout for evaluation
with torch.no_grad():
    correct = 0
    total = 0
    for X_test, Y_test in mnist_testloader:
        X_test = X_test.to(device)
        Y_test = Y_test.to(device)
        prediction = mnist_SVM_model(X_test)
        predicted = torch.argmax(prediction, 1)
        total += Y_test.size(0)
        correct += (predicted == Y_test).sum().item()

print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# fashion_model (CNN)
fashion_model.eval()   # switch off dropout for evaluation
with torch.no_grad():
    correct = 0
    total = 0
    for X_test, Y_test in fashion_testloader:
        X_test = X_test.to(device)
        Y_test = Y_test.to(device)
        prediction = fashion_model(X_test)
        predicted = torch.argmax(prediction, 1)
        total += Y_test.size(0)
        correct += (predicted == Y_test).sum().item()

print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# fashion_SVM_model (CNN + SVM)
fashion_SVM_model.eval()   # switch off dropout for evaluation
with torch.no_grad():
    correct = 0
    total = 0
    for X_test, Y_test in fashion_testloader:
        X_test = X_test.to(device)
        Y_test = Y_test.to(device)
        prediction = fashion_SVM_model(X_test)
        predicted = torch.argmax(prediction, 1)
        total += Y_test.size(0)
        correct += (predicted == Y_test).sum().item()

print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
  • ๋ฐ‘์˜ ๊ทธ๋ฆผ์€ ๋ฐ์ดํ„ฐ ๋ถ„์„์˜ ๊ฒฐ๊ณผํ‘œ์ด๋‹ค.
    • Figure2 : CNN-Softmax์™€ CNN-SVM์˜ Training Accuracy๋ฅผ ์‹œ๊ฐํ™”ํ•œ ํ‘œ
      (MNIST)
    • Figure3 : CNN-Softmax์™€ CNN-SVM์˜ Training loss๋ฅผ ์‹œ๊ฐํ™”ํ•œ ํ‘œ
      (MNIST)
    • Figure4 : CNN-Softmax์™€ CNN-SVM์˜ Training Accuracy๋ฅผ ์‹œ๊ฐํ™”ํ•œ ํ‘œ
      (fashion-MNIST)
    • Figure5 : CNN-Softmax์™€ CNN-SVM์˜ Training loss๋ฅผ ์‹œ๊ฐํ™”ํ•œ ํ‘œ
      (fashion-MNIST)

  • ๋ชจ๋ธ ์„ฑ๋Šฅ (epoch = 10000)


Table 3: Test accuracy of CNN-Softmax and CNN-SVM on image classification using MNIST and Fashion-MNIST

  • ์ง์ ‘ ๊ตฌํ˜„ํ•œ ๋ชจ๋ธ ์„ฑ๋Šฅ (epoch = 50)
    • ํ•™์Šต์— ์ด์šฉํ•œ epoch์ˆ˜๊ฐ€ ์ƒ์ดํ•˜์—ฌ ์„ฑ๋Šฅ์— ์กฐ๊ธˆ์˜ ์ฐจ์ด๊ฐ€ ์žˆ์—ˆ์ง€๋งŒ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์‹คํ—˜ํ™˜๊ฒฝ์„ ๋™์ผํ•˜๊ฒŒ ๊ตฌ์ถ•ํ•ด๋ณผ ์ˆ˜ ์žˆ์—ˆ๋‹ค.
DatasetCNN-softmaxCNN-SVM
MNIST98.47%98.77%
FASHION-MNIST88.13%87.84%

4. Conclusion and Recommendation

  • The significance of these results lies in motivating an improved methodology to further validate the CNN-SVM proposed in "Deep Learning using Linear Support Vector Machines".

  • Although the results contradict the findings of "Deep Learning using Linear Support Vector Machines", quantitatively speaking, the test accuracies of CNN-Softmax and CNN-SVM are almost the same as in the related study.

  • Therefore, with additional data preprocessing and a relatively more sophisticated base CNN model, the results of that study should be reproducible.
