[Data Mining] Deep Learning Basics - TensorFlow, MNIST Data

이혜윤 · March 31, 2023


MNIST Dataset

The MNIST dataset is a training dataset originally built for automatic recognition of handwritten postal codes. Let's implement this with a logistic (softmax) model!

  • Each image consists of 28 × 28 = 784 pixels, each encoded as an integer between 0 and 255
  • 784 independent variables X1–X784, with a label from 0 to 9 for each image

# mnistData.ipynb

import tensorflow as tf
import numpy as np
from tensorflow.keras import datasets

mnist = datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()
train_x.shape, test_x.shape
>> ((60000, 28, 28), (10000, 28, 28))

image = train_x[0]
image.shape
>> (28, 28)

import matplotlib.pyplot as plt
plt.imshow(image, 'gray')
plt.show()

train_y.shape
>> (60000,)

train_y[0]
>> 5

from tensorflow.keras.utils import to_categorical
to_categorical(1, 5) # class index 1 out of 5 classes is the 'hot' one
>> array([0., 1., 0., 0., 0.], dtype=float32)

label = train_y[0]
label
>> 5

to_categorical(label, 10) 
>> array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=float32)

1. one-hot encoding

Background

  • MNIST data: Y = 0, 1, ..., 9
    • These values do not carry magnitude; they are categories, not quantities.
  • A logistic regression model handles Y = 0, 1

Application

In this case, the dummy-variable idea from statistics is used to transform the labels into vectors of 0s and 1s, as sketched below.
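
A minimal NumPy sketch of this dummy-variable transformation (illustrative only; it produces the same vectors as the to_categorical call shown above):

import numpy as np

labels = np.array([5, 0, 4, 1, 9])   # example digit labels
one_hot = np.eye(10)[labels]         # row k of the 10x10 identity matrix = dummy vector for class k
print(one_hot[0])                    # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]  -> the digit 5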

Hypothesis

  • x1: P(Y=0) = 0.7, P(Y=1) = 0.2, P(Y=2) = 0.1, ...
  • x2: P(Y=0) = 0, P(Y=1) = 0.9, ...

2. Softmax

  • Each image in train_x and test_x is a 28 × 28 matrix
  • Flatten each image into 784 features, then divide by 255 so every value lies between 0 and 1
  • One-hot encode train_y and test_y
  • 784 inputs, 10 outputs (the softmax sketch below shows how the 10 outputs become probabilities)
  • Use the Stochastic Gradient Descent (SGD) algorithm
  • Use categorical cross-entropy as the loss function
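
Softmax turns the 10 raw output scores into 10 probabilities that sum to 1. A minimal NumPy sketch (the scores below are made-up values):

import numpy as np

def softmax(z):
    z = z - np.max(z)        # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()       # exponentiate, then normalize so the values sum to 1

scores = np.array([1.0, 2.0, 0.5, 0.1, 0.0, 3.0, 0.2, 0.3, 0.1, 0.4])
probs = softmax(scores)
print(probs, probs.sum())    # 10 probabilities, and their sum (1.0)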

# mnist_softmax_tf2.ipynb
import tensorflow as tf
import numpy as np
from tensorflow.keras import datasets
from tensorflow.keras.utils import to_categorical
mnist = datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()

print(train_x.shape)
np.max(train_x[0]), np.min(train_x[0])
>> (60000, 28, 28)
>> (255, 0)

train_x[0]
>> array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,
         18,  18,  18, 126, 136, 175,  26, 166, 255, 247, 127,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  30,  36,  94, 154, 170,
        253, 253, 253, 253, 253, 225, 172, 253, 242, 195,  64,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  49, 238, 253, 253, 253, 253,
        253, 253, 253, 253, 251,  93,  82,  82,  56,  39,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  18, 219, 253, 253, 253, 253,
        253, 198, 182, 247, 241,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  80, 156, 107, 253, 253,
        205,  11,   0,  43, 154,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  14,   1, 154, 253,
         90,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 139, 253,
        190,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  11, 190,
        253,  70,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  35,
        241, 225, 160, 108,   1,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         81, 240, 253, 253, 119,  25,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,  45, 186, 253, 253, 150,  27,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,  16,  93, 252, 253, 187,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0, 249, 253, 249,  64,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,  46, 130, 183, 253, 253, 207,   2,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  39,
        148, 229, 253, 253, 253, 250, 182,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  24, 114, 221,
        253, 253, 253, 253, 201,  78,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  23,  66, 213, 253, 253,
        253, 253, 198,  81,   2,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,  18, 171, 219, 253, 253, 253, 253,
        195,  80,   9,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,  55, 172, 226, 253, 253, 253, 253, 244, 133,
         11,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0, 136, 253, 253, 253, 212, 135, 132,  16,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0]], dtype=uint8)

train_x = train_x.reshape(-1,784) 
test_x = test_x.reshape(-1,784) 
train_x.shape, test_x.shape
>> ((60000, 784), (10000, 784))

train_x = train_x / 255 # scale pixel values to the 0-1 range (min-max normalization)
test_x = test_x / 255

train_x[0]
>> array([0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.01176471, 0.07058824, 0.07058824,
       0.07058824, 0.49411765, 0.53333333, 0.68627451, 0.10196078,
       0.65098039, 1.        , 0.96862745, 0.49803922, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.11764706, 0.14117647, 0.36862745, 0.60392157,
       0.66666667, 0.99215686, 0.99215686, 0.99215686, 0.99215686,
       0.99215686, 0.88235294, 0.6745098 , 0.99215686, 0.94901961,
       0.76470588, 0.25098039, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.19215686, 0.93333333,
       0.99215686, 0.99215686, 0.99215686, 0.99215686, 0.99215686,
       0.99215686, 0.99215686, 0.99215686, 0.98431373, 0.36470588,
       0.32156863, 0.32156863, 0.21960784, 0.15294118, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.07058824, 0.85882353, 0.99215686, 0.99215686,
       0.99215686, 0.99215686, 0.99215686, 0.77647059, 0.71372549,
       0.96862745, 0.94509804, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.31372549, 0.61176471, 0.41960784, 0.99215686, 0.99215686,
       0.80392157, 0.04313725, 0.        , 0.16862745, 0.60392157,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.05490196,
       0.00392157, 0.60392157, 0.99215686, 0.35294118, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.54509804,
       0.99215686, 0.74509804, 0.00784314, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.04313725, 0.74509804, 0.99215686,
       0.2745098 , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.1372549 , 0.94509804, 0.88235294, 0.62745098,
       0.42352941, 0.00392157, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.31764706, 0.94117647, 0.99215686, 0.99215686, 0.46666667,
       0.09803922, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.17647059,
       0.72941176, 0.99215686, 0.99215686, 0.58823529, 0.10588235,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.0627451 , 0.36470588,
       0.98823529, 0.99215686, 0.73333333, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.97647059, 0.99215686,
       0.97647059, 0.25098039, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.18039216, 0.50980392,
       0.71764706, 0.99215686, 0.99215686, 0.81176471, 0.00784314,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.15294118,
       0.58039216, 0.89803922, 0.99215686, 0.99215686, 0.99215686,
       0.98039216, 0.71372549, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.09411765, 0.44705882, 0.86666667, 0.99215686, 0.99215686,
       0.99215686, 0.99215686, 0.78823529, 0.30588235, 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.09019608, 0.25882353, 0.83529412, 0.99215686,
       0.99215686, 0.99215686, 0.99215686, 0.77647059, 0.31764706,
       0.00784314, 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.07058824, 0.67058824, 0.85882353,
       0.99215686, 0.99215686, 0.99215686, 0.99215686, 0.76470588,
       0.31372549, 0.03529412, 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.21568627, 0.6745098 ,
       0.88627451, 0.99215686, 0.99215686, 0.99215686, 0.99215686,
       0.95686275, 0.52156863, 0.04313725, 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.53333333, 0.99215686, 0.99215686, 0.99215686,
       0.83137255, 0.52941176, 0.51764706, 0.0627451 , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 0.        , 0.        ])
train_y_onehot = to_categorical(train_y) # one hot encoding
test_y_onehot = to_categorical(test_y)
train_y_onehot[0]
>> array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=float32)

from tensorflow.keras import layers
model = tf.keras.Sequential()
# one Dense layer with softmax on 784 inputs = multinomial logistic regression (784*10 + 10 = 7,850 parameters)
model.add(layers.Dense(10, activation='softmax', input_dim=784))
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

(i) Train the defined model on all 60,000 training samples for 5 epochs to estimate the (W, b) parameters.

  • The 10,000 test samples, which were not used for training, are used to check the result.
model.fit(train_x, train_y_onehot, validation_data=(test_x, test_y_onehot), epochs=5)

>> Epoch 1/5
1875/1875 [==============================] - 6s 3ms/step - loss: 0.7664 - accuracy: 0.8215 - val_loss: 0.4793 - val_accuracy: 0.8843
Epoch 2/5
1875/1875 [==============================] - 6s 3ms/step - loss: 0.4546 - accuracy: 0.8818 - val_loss: 0.3998 - val_accuracy: 0.8964
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4026 - accuracy: 0.8917 - val_loss: 0.3670 - val_accuracy: 0.9040
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3766 - accuracy: 0.8970 - val_loss: 0.3481 - val_accuracy: 0.9069
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.3599 - accuracy: 0.9003 - val_loss: 0.3370 - val_accuracy: 0.9098

model.evaluate(train_x, train_y_onehot)
>> 1875/1875 [==============================] - 3s 2ms/step - loss: 0.3527 - accuracy: 0.9031
[0.35273510217666626, 0.9031000137329102]

model.evaluate(test_x,test_y_onehot)
>> 313/313 [==============================] - 1s 3ms/step - loss: 0.4212 - accuracy: 0.8920
[0.4211805760860443, 0.8920000195503235]

predicted = model.predict(test_x)
predicted[0],test_y_onehot[0]
>> (array([3.04816407e-04, 9.52855601e-07, 1.77901296e-04, 2.51405081e-03,
        4.47172351e-05, 1.08292574e-04, 2.74398462e-06, 9.93954301e-01,
        2.43925446e-04, 2.64819362e-03], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], dtype=float32))
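
To turn the 10 predicted probabilities into a single digit, take the index of the largest probability. A small sketch reusing predicted and test_y from above; the accuracy it prints should roughly match model.evaluate:

pred_labels = np.argmax(predicted, axis=1)   # most probable class per test image
pred_labels[0]                               # 7 for the first test image above
np.mean(pred_labels == test_y)               # fraction correct, about 0.89 here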

(ii) Train the defined model on the 60,000 samples with batch_size=100 for a total of 5 epochs to estimate the (W, b) parameters.

  • If batch_size is not specified, the default is 32.
  • Accuracy comes out to about 89%.
model.fit(train_x, train_y_onehot, validation_data=(test_x, test_y_onehot), batch_size = 100, epochs=5)

>> Epoch 1/5
600/600 [==============================] - 3s 3ms/step - loss: 1.1547 - accuracy: 0.7371 - val_loss: 0.7217 - val_accuracy: 0.8473
Epoch 2/5
600/600 [==============================] - 2s 3ms/step - loss: 0.6479 - accuracy: 0.8513 - val_loss: 0.5521 - val_accuracy: 0.8721
Epoch 3/5
600/600 [==============================] - 3s 5ms/step - loss: 0.5402 - accuracy: 0.8668 - val_loss: 0.4842 - val_accuracy: 0.8819
Epoch 4/5
600/600 [==============================] - 2s 3ms/step - loss: 0.4883 - accuracy: 0.8759 - val_loss: 0.4463 - val_accuracy: 0.8888
Epoch 5/5
600/600 [==============================] - 2s 3ms/step - loss: 0.4563 - accuracy: 0.8815 - val_loss: 0.4212 - val_accuracy: 0.8920

model.evaluate(train_x, train_y_onehot)
>> 1875/1875 [==============================] - 3s 2ms/step - loss: 0.4435 - accuracy: 0.8838
[0.4435127377510071, 0.8837666511535645]

model.evaluate(test_x, test_y_onehot)
>> 313/313 [==============================] - 1s 3ms/step - loss: 0.4212 - accuracy: 0.8920
[0.4211805760860443, 0.8920000195503235]

predicted = model.predict(test_x)
>> 313/313 [==============================] - 1s 1ms/step

predicted[0], test_y_onehot[0]
>> (array([5.3943484e-04, 2.3173239e-05, 4.9038179e-04, 2.3329458e-03,
        5.1109842e-04, 4.1478052e-04, 5.0893617e-05, 9.8665601e-01,
        5.5489421e-04, 8.4262872e-03], dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], dtype=float32))


3. mini-batch, epoch

3.1 SGD (Stochastic Gradient Descent)

Background: limitations of full-batch gradient descent

  • One loss-function evaluation produces one weight update
  • With a large training set, a single weight update requires a pass over the entire dataset

Introducing SGD

  • Shuffle the training data, then compute the loss and update the weights on one batch-sized chunk at a time

Advantages of SGD

  • More weight updates in the same amount of time

Disadvantages of SGD

  • The solution oscillates
  • Convexity of the mini-batch loss is not guaranteed
    • Large batch_size: smaller oscillation, so the search follows a more accurate path to the solution, but each update is slow to compute
    • Small batch_size: larger oscillation, so the search can move in inaccurate directions, but each update is fast to compute

Remedies

  • Even if the current batch leads into a non-convex region, the next batch changes the gradient direction, allowing escape
  • Start with a large learning rate and let it decrease over time
    • decay: the rate at which lr shrinks
    • lr = lr0 / (1 + decay × iteration)   (see the schedule sketch below)
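
In tf.keras this schedule can be expressed with InverseTimeDecay (a sketch with illustrative values, not part of the model trained above; tf is imported earlier):

# lr follows lr0 / (1 + decay_rate * step / decay_steps)
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.01,   # start relatively large
    decay_steps=1,                # apply the decay every iteration
    decay_rate=0.001)             # decay factor
opt = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
# model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])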

Training procedure

  • Setup
    • batch_size=100, training_epochs=5, learning_rate=0.001
    • MNIST dataset = 60,000 training samples + 10,000 test samples
    • Computation: multiply the input matrix (60000 × 784) by the weight matrix
  • Problem
    • Computing the cost once over all 60,000 samples and then adjusting the weights once by gradient descent is inefficient
  • Solution
    • Compute the cost and update the weights per mini-batch (e.g. 100 samples); see the loop sketch below
    • Repeating 60000 / 100 = 600 times covers the full 60,000 samples = 1 epoch
    • When there is a lot of data, this uses less memory while still giving accurate estimates
    • Training for 20 epochs means 600 × 20 = 12,000 weight updates
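
A rough sketch of what one epoch of mini-batch training does under the hood (model.fit with batch_size=100 handles this internally; the variables are the ones defined earlier):

batch_size = 100
n = train_x.shape[0]                      # 60,000 training samples
idx = np.random.permutation(n)            # shuffle once per epoch
for start in range(0, n, batch_size):     # 600 iterations = 1 epoch
    batch = idx[start:start + batch_size]
    # one loss evaluation and one weight update per mini-batch
    model.train_on_batch(train_x[batch], train_y_onehot[batch])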

3.2 Cross-Entropy Function

Cross-entropy function for binomial logistic regression

Cross-entropy function for multinomial logistic regression

  • Each target value is either 0 or 1

One-hot encoding format (n by c)

  • Labels 0, 1, 2, 3, ...
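
Written out, the two loss functions (standard forms; the CrossEntropy.py snippet below evaluates exactly these):

$$E_{\text{binomial}} = -\frac{1}{N}\sum_{i=1}^{N}\left[\, y_i \log h_i + (1-y_i)\log(1-h_i) \,\right]$$

$$E_{\text{multinomial}} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{c} y_{ij}\,\log h_{ij} \qquad (y \text{ one-hot, } n \times c)$$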

In the end, it is all logistic regression

  • Softmax is introduced so that binomial regression can be extended to the multinomial case

  • Below, the same data is computed with the two cross-entropy formulas

#CrossEntropy.py
import math
y=[0,1,1,0] # binary labels
hypothesis=[0.3, 0.7, 0.6, 0.2] 
# predicted probability that each observation equals 1: 0.3, 0.7, 0.6, 0.2
total = 0.   # avoid shadowing the built-in sum()
for i in range(4):
    total += -y[i]*math.log(hypothesis[i]) - (1-y[i])*math.log(1-hypothesis[i])
crossEntropy = total/4
print(crossEntropy)
>> 0.36182976573941633

# Now express the same data in one-hot encoding.
y1Hot=[[1,0],[0,1],[0,1],[1,0]]
# hypothesis1Hot: [P(Y=0), P(Y=1)] per observation
hypothesis1Hot=[[0.7,0.3],[0.3,0.7],[0.4,0.6],[0.8,0.2]]
# (e.g. first observation: P(Y=0)=0.7, P(Y=1)=0.3)
sum1=0.
for i in range(4):
    sum2=0.
    for j in range(2):
        sum2+=-y1Hot[i][j]*math.log(hypothesis1Hot[i][j])
    sum1+=sum2

crossEntropy1Hot=sum1/4
print (crossEntropy1Hot)
>> 0.36182976573941633
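
The same number can be cross-checked against Keras' built-in loss (a sketch; it reuses y1Hot and hypothesis1Hot from above and should print roughly the same value):

import numpy as np
import tensorflow as tf

keras_loss = tf.keras.losses.categorical_crossentropy(
    np.array(y1Hot, dtype='float32'),
    np.array(hypothesis1Hot, dtype='float32'))
print(float(np.mean(keras_loss)))   # ~0.3618, matching the hand-computed value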

  • In the end, only the representation differs; both are logistic regression
  • Going from binomial to multinomial is just a matter of whether the labels are expressed with one-hot encoding!