MNIST๋Š” ์†์œผ๋กœ ์“ด ์ˆซ์ž ๊ธ€์”จ๋ฅผ ๋ชจ์•„๋†“์€ ๋ฐ์ดํ„ฐ์„ธํŠธ์ด๋‹ค. ํ‘๋ฐฑ ์ด๋ฏธ์ง€์ด๊ณ  ๋ฒ”์ฃผ๊ฐ€ 10๊ฐœ๋กœ ๊ตฌ์„ฑ๋˜์–ด ์žˆ์œผ๋ฉฐ 28 * 28 ํ”ฝ์…€์ด๋‹ค.

๐ŸŽ  ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()
import matplotlib.pyplot as plt

digit = X_train[1]
plt.imshow(digit, cmap = 'gray')
plt.show()

์œ„ ์ฝ”๋“œ๋กœ ๋ฐ์ดํ„ฐ ์ค‘ ํ•˜๋‚˜๋ฅผ ์ด๋ฏธ์ง€ ํ˜•ํƒœ๋กœ ํ™•์ธํ•ด๋ณผ ์ˆ˜ ์žˆ๋‹ค.

๐ŸŽ  ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

X_train = X_train.reshape((60000, 28 * 28))
X_test = X_test.reshape((10000, 28 * 28))

Dense ๋ ˆ์ด์–ด์— ๋„ฃ๊ธฐ ์œ„ํ•ด ์œ„์™€ ๊ฐ™์€ ํ˜•ํƒœ๋กœ ๋ฐ”๊ฟ”์ค€๋‹ค.

X_train = X_train.astype(float) / 255
X_test = X_test.astype(float) / 255

0 ~ 255 ์‚ฌ์ด์˜ ๊ฐ’์„ ๊ฐ€์ง€๊ธฐ ๋•Œ๋ฌธ์— ์œ„์™€ ๊ฐ™์ด normalization ํ•ด์ค€๋‹ค.

from keras.utils import to_categorical

y_train = to_categorical(y_train)
y_test = to_categorical(y_test)

y๊ฐ€ 0 ~ 9 ์‚ฌ์ด์˜ ๊ฐ’์„ ๊ฐ€์ง€๊ธฐ ๋•Œ๋ฌธ์— One-Hot Encoding์„ ํ†ตํ•ด ๊ฐ’์„ ๋ณ€๊ฒฝํ•ด์ค€๋‹ค.

๐ŸŽ  ๋ชจ๋ธ๋ง

from keras import models
from keras import layers

mnist = models.Sequential()
mnist.add(layers.Dense(512, activation = 'relu', input_shape = (28 * 28,)))
mnist.add(layers.Dense(256, activation = 'relu'))
mnist.add(layers.Dense(10, activation = 'softmax'))

mnist.compile(loss = 'categorical_crossentropy', optimizer = 'rmsprop', metrics = ['accuracy'])

๋‹ค์ค‘ ๋ถ„๋ฅ˜ ๋ฌธ์ œ์ด๊ธฐ ๋•Œ๋ฌธ์— categorical_crossentropy๋กœ ์„ค์ •ํ•œ๋‹ค.

hist = mnist.fit(X_train, y_train,
epochs = 100,
batch_size = 128,
validation_split = 0.2)

๋ณ„๋„๋กœ validation data๋ฅผ ์ง€์ •ํ•˜์ง€ ์•Š์•˜๊ธฐ ๋•Œ๋ฌธ์— validation_spilt์„ ํ†ตํ•ด 0.2์˜ validation data๋ฅผ ์ง€์ •ํ•œ๋‹ค.

๐ŸŽ  ๊ฒฐ๊ณผ

์œ„ ๋ชจ๋ธ์˜ loss, val_loss๋ฅผ ์‹œ๊ฐํ™”ํ•ด๋ณด๋ฉด Training Loss๋Š” Epoch๊ฐ€ ์ฆ๊ฐ€ํ• ์ˆ˜๋ก ๊ฐ์†Œํ•˜๋Š” ๋ชจ์Šต์„ ๋ณด์ด์ง€๋งŒ Validation Loss๋Š” ์˜คํžˆ๋Ÿฌ ๋” ์ฆ๊ฐ€ํ•˜๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค. ์ด๋ฅผ ํ†ตํ•ด ์šฐ๋ฆฌ๋Š” ์œ„ ๋ชจ๋ธ์ด Overfitting ๋˜์—ˆ๋‹ค๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค.

0๊ฐœ์˜ ๋Œ“๊ธ€

๊ด€๋ จ ์ฑ„์šฉ ์ •๋ณด

Powered by GraphCDN, the GraphQL CDN