Deep Learning Practice: MNIST

김건우 · January 7, 2022

Series: Machine Learning (post 15 of 21)

Kaggle API setup

import os
os.environ['KAGGLE_USERNAME'] = 'name' # username
os.environ['KAGGLE_KEY'] = 'key' # key

Download the dataset

!kaggle datasets download -d oddrationale/mnist-in-csv
!unzip mnist-in-csv.zip

Load packages

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.optimizers import Adam, SGD
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import OneHotEncoder

Load the dataset

train_df = pd.read_csv('mnist_train.csv')
test_df = pd.read_csv('mnist_test.csv')

print(train_df)
print(test_df)

Label distribution

sns.countplot(x=train_df['label'])  # one bar per digit 0-9
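To read the class balance as numbers instead of bar heights, here is a minimal sketch using NumPy. The helper `label_counts` and the toy `labels` array are hypothetical; on the real data you would pass `train_df['label'].values` (or use `train_df['label'].value_counts().sort_index()` in pandas).

```python
import numpy as np

def label_counts(labels, num_classes=10):
    """Count how many samples belong to each class 0..num_classes-1."""
    # bincount needs non-negative integers; labels may be float32 after astype
    return np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_classes)

# Hypothetical toy labels standing in for train_df['label'].values
labels = np.array([0, 1, 1, 3, 9, 9, 9])
counts = label_counts(labels)
print(counts)  # → [1 2 0 1 0 0 0 0 0 3]
```

`minlength=10` guarantees one slot per digit even when a class is absent, so the result always lines up with the 10 output units.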

Split inputs and outputs

train_df = train_df.astype(np.float32)
x_train = train_df.drop(columns=['label']).values  # 784 pixel columns
y_train = train_df[['label']].values               # label column, shape (n, 1)

test_df = test_df.astype(np.float32)
x_test = test_df.drop(columns=['label']).values    # must come from test_df, not train_df
y_test = test_df[['label']].values

print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
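The split above can be checked on a small frame. This is a sketch assuming a DataFrame shaped like the CSV (one `label` column plus 784 pixel columns); the helper name `split_inputs_outputs` and the column names `p0..p783` are hypothetical.

```python
import numpy as np
import pandas as pd

def split_inputs_outputs(df):
    """Separate the 'label' column (target) from the pixel columns (features)."""
    df = df.astype(np.float32)
    x = df.drop(columns=['label']).values  # shape (n_samples, 784)
    y = df[['label']].values               # double brackets keep a 2-D (n_samples, 1) array
    return x, y

# Hypothetical toy frame: 3 rows, a label column, and 784 pixel columns
pixel_cols = [f'p{i}' for i in range(784)]
toy = pd.DataFrame(np.zeros((3, 785), dtype=np.int64), columns=['label'] + pixel_cols)
toy['label'] = [5, 0, 4]

x, y = split_inputs_outputs(toy)
print(x.shape, y.shape)  # → (3, 784) (3, 1)
```

Keeping `y` two-dimensional matters here because `OneHotEncoder` expects a 2-D input.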

Preview the data

index = 1
plt.title(str(int(y_train[index][0])))  # show the label as a plain integer rather than a 1-element array
plt.imshow(x_train[index].reshape((28,28)),cmap='gray')
plt.show()
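`reshape((28,28))` works because each CSV row stores the image row by row. A small sketch with a hypothetical flattened vector, where each value equals its own index, shows the mapping:

```python
import numpy as np

# Hypothetical flattened image: values 0..783 so each entry records its own index
flat = np.arange(784)
image = flat.reshape((28, 28))  # default C (row-major) order

# Row-major means pixel (row r, column c) came from flat index r * 28 + c
r, c = 2, 5
print(image[r, c], r * 28 + c)  # → 61 61
```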

One-Hot encoding

encoder = OneHotEncoder()
y_train = encoder.fit_transform(y_train).toarray()  # learn the 10 categories from the training labels
y_test = encoder.transform(y_test).toarray()        # reuse the same mapping; do not refit on test data

print(y_train.shape)
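`OneHotEncoder` learns its categories from whatever it is fit on and sorts them, so when all 10 digits are present, output column j corresponds to digit j. A fixed-width NumPy equivalent makes that mapping explicit; the helper `one_hot` is a hypothetical illustration, not part of scikit-learn.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Map integer class labels of shape (n, 1) or (n,) to an (n, num_classes) 0/1 matrix."""
    idx = np.asarray(labels).astype(np.int64).ravel()
    return np.eye(num_classes, dtype=np.float32)[idx]

y = np.array([[5.0], [0.0], [4.0]], dtype=np.float32)  # same (n, 1) float layout as y_train
encoded = one_hot(y)
print(encoded.shape)            # → (3, 10)
print(encoded.argmax(axis=1))   # → [5 0 4]
```

Fixing `num_classes` guarantees that train and test always get the same 10 columns, which is the same reason the encoder should be fit once on the training labels and only `transform`ed on the test labels.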

Normalization (scale pixel values to the range 0 to 1)

x_train = x_train / 255
x_test = x_test / 255
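Dividing by 255 is min-max scaling with the known 8-bit pixel range, so every value lands in [0, 1]. A minimal sketch on a few hypothetical pixel values:

```python
import numpy as np

# Hypothetical pixel values covering the full 8-bit range
pixels = np.array([0, 51, 128, 255], dtype=np.float32)

scaled = pixels / 255  # fixed range [0, 255] -> [0, 1]
print(scaled.min(), scaled.max())  # → 0.0 1.0
```

Using the fixed bound 255, rather than each dataset's own maximum, keeps training and test inputs on exactly the same scale. `StandardScaler` (imported above but not used) would instead shift each pixel to zero mean and unit variance.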

Build the network

inputs = Input(shape=(784,))  # 28*28 flattened pixels; avoids shadowing Python's built-in input()
hidden = Dense(1024, activation='relu')(inputs)
hidden = Dense(512, activation='relu')(hidden)
hidden = Dense(256, activation='relu')(hidden)
output = Dense(10, activation='softmax')(hidden)  # one probability per digit 0-9


model = Model(inputs=inputs, outputs=output)

model.compile(loss='categorical_crossentropy', optimizer=Adam(learning_rate=0.001), metrics=['acc'])

model.summary()
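The parameter counts reported by `model.summary()` can be checked by hand: a `Dense` layer with a bias has `n_in * n_out` weights plus `n_out` biases, and the `Input` layer has none. A short sketch for the 784-1024-512-256-10 stack above:

```python
# Layer widths of the network above, input first
layer_sizes = [784, 1024, 512, 256, 10]

def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
print(per_layer)       # → [803840, 524800, 131328, 2570]
print(sum(per_layer))  # → 1462538
```

More than half of the roughly 1.46 million parameters sit in the first layer, which is why widening the first hidden layer is the most expensive change to this architecture.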

Training

history = model.fit(
    x_train,
    y_train,
    validation_data = (x_test, y_test),
    epochs=20
)
Epoch 1/20
1875/1875 [==============================] - 45s 23ms/step - loss: 0.1925 - acc: 0.9414 - val_loss: 0.0982 - val_acc: 0.9700
Epoch 2/20
1875/1875 [==============================] - 41s 22ms/step - loss: 0.0901 - acc: 0.9731 - val_loss: 0.0563 - val_acc: 0.9833
Epoch 3/20
1875/1875 [==============================] - 42s 22ms/step - loss: 0.0636 - acc: 0.9808 - val_loss: 0.0568 - val_acc: 0.9827
Epoch 4/20
1875/1875 [==============================] - 41s 22ms/step - loss: 0.0501 - acc: 0.9850 - val_loss: 0.0354 - val_acc: 0.9886
Epoch 5/20
1875/1875 [==============================] - 42s 22ms/step - loss: 0.0421 - acc: 0.9870 - val_loss: 0.0294 - val_acc: 0.9901
Epoch 6/20
1875/1875 [==============================] - 43s 23ms/step - loss: 0.0351 - acc: 0.9893 - val_loss: 0.0217 - val_acc: 0.9932
Epoch 7/20
1875/1875 [==============================] - 43s 23ms/step - loss: 0.0301 - acc: 0.9906 - val_loss: 0.0169 - val_acc: 0.9948
Epoch 8/20
1875/1875 [==============================] - 43s 23ms/step - loss: 0.0250 - acc: 0.9925 - val_loss: 0.0256 - val_acc: 0.9927
Epoch 9/20
1875/1875 [==============================] - 45s 24ms/step - loss: 0.0251 - acc: 0.9924 - val_loss: 0.0191 - val_acc: 0.9938
Epoch 10/20
1875/1875 [==============================] - 43s 23ms/step - loss: 0.0224 - acc: 0.9935 - val_loss: 0.0167 - val_acc: 0.9947
Epoch 11/20
1875/1875 [==============================] - 42s 23ms/step - loss: 0.0218 - acc: 0.9936 - val_loss: 0.0141 - val_acc: 0.9962
Epoch 12/20
1875/1875 [==============================] - 44s 23ms/step - loss: 0.0199 - acc: 0.9941 - val_loss: 0.0109 - val_acc: 0.9970
Epoch 13/20
1875/1875 [==============================] - 42s 22ms/step - loss: 0.0209 - acc: 0.9950 - val_loss: 0.0126 - val_acc: 0.9968
Epoch 14/20
1875/1875 [==============================] - 43s 23ms/step - loss: 0.0154 - acc: 0.9962 - val_loss: 0.0091 - val_acc: 0.9974
Epoch 15/20
1875/1875 [==============================] - 42s 22ms/step - loss: 0.0211 - acc: 0.9948 - val_loss: 0.0094 - val_acc: 0.9971
Epoch 16/20
1875/1875 [==============================] - 42s 23ms/step - loss: 0.0152 - acc: 0.9961 - val_loss: 0.0405 - val_acc: 0.9935
Epoch 17/20
1875/1875 [==============================] - 42s 23ms/step - loss: 0.0161 - acc: 0.9959 - val_loss: 0.0095 - val_acc: 0.9975
Epoch 18/20
1875/1875 [==============================] - 42s 22ms/step - loss: 0.0184 - acc: 0.9959 - val_loss: 0.0124 - val_acc: 0.9968
Epoch 19/20
1875/1875 [==============================] - 43s 23ms/step - loss: 0.0188 - acc: 0.9955 - val_loss: 0.0161 - val_acc: 0.9958
Epoch 20/20
1875/1875 [==============================] - 42s 23ms/step - loss: 0.0141 - acc: 0.9965 - val_loss: 0.0072 - val_acc: 0.9980
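The `1875/1875` in each log line counts mini-batches per epoch. `model.fit` was called without `batch_size`, so Keras uses its default of 32, and 60,000 training images / 32 = 1875 steps. A sketch of the arithmetic (the helper name is hypothetical):

```python
import math

def steps_per_epoch(num_samples, batch_size=32):
    """Number of mini-batches Keras runs per epoch; the last batch may be smaller."""
    return math.ceil(num_samples / batch_size)

print(steps_per_epoch(60000))       # → 1875
print(steps_per_epoch(60000, 128))  # → 469
```

Passing a larger `batch_size` to `model.fit` reduces the number of steps, and usually the time per epoch, at the cost of fewer weight updates.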

Plot the results

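plt.figure(figsize=(16,10))
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.xlabel('epoch')
plt.legend()

plt.figure(figsize=(16,10))
plt.plot(history.history['acc'], label='acc')
plt.plot(history.history['val_acc'], label='val_acc')
plt.xlabel('epoch')
plt.legend()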
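`history.history` is a plain dict mapping each metric name to a list with one value per epoch, so the best epoch can also be picked out numerically. The sketch below uses a hypothetical dict holding the values from epochs 1 to 5 of the log above; with the real object you would call `best_epoch(history.history)`.

```python
# Values copied from epochs 1-5 of the training log, in history.history's dict-of-lists shape
history_dict = {
    'loss':     [0.1925, 0.0901, 0.0636, 0.0501, 0.0421],
    'val_loss': [0.0982, 0.0563, 0.0568, 0.0354, 0.0294],
}

def best_epoch(hist, key='val_loss'):
    """Return the 1-based epoch with the lowest value of `key`, and that value."""
    values = hist[key]
    i = min(range(len(values)), key=values.__getitem__)
    return i + 1, values[i]

print(best_epoch(history_dict))  # → (5, 0.0294)
```

To have Keras act on this during training instead, `tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True)` can be passed to `model.fit` via `callbacks`.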
