[딥러닝 Express] Chapter 08. 심층 신경망 [예제: MNIST 필기체 숫자 인식]

배규리·2024년 1월 24일

AI 기초

목록 보기
20/32
post-thumbnail

문제

심층 신경망을 사용해서 MNIST 필기체 숫자 인식 모델 만들기

교재 코드💻

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape = (28,28)))
model.add(tf.keras.layers.Dense(512, activation="relu"))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation="softmax"))

model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

model.fit(x_train, y_train, epochs=5)
test_loss, test_acc = model.evaluate(x_test, y_test)
print("손실:", test_loss)
print("정확도:", test_acc)

직접 구현한 코드😊

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0

# 모델 구축
model = tf.keras.models.Sequential()
# 입력을 784x1 형태로 만들어준다.
model.add(tf.keras.layers.Flatten(input_shape = (28,28)))
model.add(tf.keras.layers.Dense(512, activation="relu"))
# 학습 시에 0.1 비율로 유닛을 학습에서 제외시켰다.
model.add(tf.keras.layers.Dropout(0.1))
model.add(tf.keras.layers.Dense(10, activation="softmax"))

# 학습
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

history = model.fit(x_train, 
                    y_train, 
                    epochs=10,
                    batch_size=256,
                    validation_data = (x_test, y_test),
                    verbose=2
                    )

# 결과 분석
history_dict = history.history
loss_values = history_dict['loss'] # 훈련 데이터의 손실함수 값
val_loss_values = history_dict['val_loss'] # 검증 데이터의 손실함수 값
acc = history_dict['accuracy']
epochs = range(1, len(acc)+1) # 에포크 수

plt.plot(loss_values)
plt.plot(val_loss_values)
plt.title("Loss Plot")
plt.ylabel("loss")
plt.xlabel("epochs")
plt.legend(["train error", "val error"], loc = "upper left")
plt.show()


profile
백엔드 개발은 취미인 AI 개발자🥹

0개의 댓글