Fashion-MNIST TensorFlow Tutorial

jiho·2019년 12월 2일
0

ML

목록 보기
2/2

이 포스트는 TensorFlow 2.0의 "Hello World" 격인 Fashion-MNIST 학습 코드에 대한 설명입니다. 각 단계는 코드 주석으로 자세히 설명하겠습니다.

import tensorflow as tf
from tensorflow import keras

import numpy as np
import matplotlib.pyplot as plt
# 행렬 연산을 쉽게 해주는 Numpy
# 결과를 그래프로 보여줄 시각화 matplotlib

# Load Fashion-MNIST as (train, test) splits, each an (images, labels) pair.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Scale pixel intensities into the 0.0-1.0 range. This assumes raw pixels are
# in 0-255 (hence the 255 divisor) — small, consistently-scaled inputs help the
# optimizer converge faster and more stably.
train_images = train_images / 255.
test_images = test_images / 255.

# Human-readable class names, indexed by the integer class label.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# Network: flatten each 28x28 image into a 784-value vector, pass it through
# one hidden ReLU layer, and finish with a 10-way softmax (one probability per
# class in class_names).
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),    # 28 * 28 = 784 input values
    keras.layers.Dense(128, activation="relu"),    # hidden layer: 128 ReLU units
    keras.layers.Dense(10, activation="softmax"),  # output: 10 class probabilities
])

# Training configuration. Sparse categorical cross-entropy consumes integer
# class labels directly, so the labels need no one-hot encoding.
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# Train. A single epoch keeps the demo quick; more epochs typically improve
# accuracy.
model.fit(train_images, train_labels, epochs=1)

# Measure loss/accuracy on the held-out test split.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)

# Softmax class probabilities for every test image (10 values per image).
predictions = model.predict(test_images)


def prediction_caption(predictions_array, true_label, names):
    """Build the caption text and its colour for a single prediction.

    Args:
        predictions_array: Per-class probabilities for one sample (the
            model's softmax output), indexable by class id.
        true_label: Integer index of the ground-truth class.
        names: Sequence mapping class index to a human-readable name.

    Returns:
        A ``(text, color)`` tuple. ``text`` reads
        "<predicted name> <confidence>% <true name>" and ``color`` is
        "blue" for a correct prediction and "red" for a wrong one.
    """
    predicted_label = np.argmax(predictions_array)
    # Softmax outputs are fractions in [0, 1]; scale to a percentage so the
    # "%" suffix is truthful (otherwise it prints only "0%" or "1%").
    confidence = 100 * predictions_array[predicted_label]
    # Blue = correct, red = wrong; matches plot_value_array, where the true
    # class bar is blue and a mispredicted bar is red.
    color = "blue" if predicted_label == true_label else "red"
    text = f"{names[predicted_label]} {confidence:2.0f}% {names[true_label]}"
    return text, color


def plot_image(i, predictions_arrays, true_labels, images, names=None):
    """Draw sample ``i`` on the current axes with a prediction caption.

    Args:
        i: Index of the sample to draw.
        predictions_arrays: Per-sample class-probability rows.
        true_labels: Per-sample ground-truth class indices.
        images: Per-sample images, shown with a binary colormap.
        names: Optional class-name sequence; defaults to the module-level
            ``class_names``.
    """
    if names is None:
        names = class_names
    predictions_array, true_label, image = predictions_arrays[i], true_labels[i], images[i]

    # Hide grid and ticks: this is a picture, not a plot.
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(image, cmap=plt.cm.binary)

    text, color = prediction_caption(predictions_array, true_label, names)
    plt.xlabel(text, color=color)


def plot_value_array(i, predictions_arrays, true_labels):
    """Draw a bar chart of the class probabilities for sample ``i``.

    Every bar starts grey. The predicted (argmax) class is then coloured red
    and the true class blue. When the prediction is correct both refer to the
    same bar, which ends up blue because it is coloured last.

    Args:
        i: Index of the sample to draw.
        predictions_arrays: Per-sample class-probability rows.
        true_labels: Per-sample ground-truth class indices.
    """
    predictions_array, true_label = predictions_arrays[i], true_labels[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    # One bar per class, derived from the data instead of a hard-coded 10 so
    # the helper works for any number of classes.
    bars = plt.bar(range(len(predictions_array)), predictions_array, color="#777777")
    # Probabilities lie in [0, 1]; a fixed y-range keeps panels comparable.
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)
    bars[predicted_label].set_color("red")
    bars[true_label].set_color("blue")


# Show the first num_rows * num_cols test images in a grid. Each grid cell is
# two subplots side by side: the image with its caption (left) and its class
# probability bar chart (right), hence the doubled column count.
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2*num_cols, 2*i + 1)
    plot_image(i, predictions, test_labels, test_images)
    plt.subplot(num_rows, 2*num_cols, 2*i + 2)
    plot_value_array(i, predictions, test_labels)

# Re-space the subplots so the per-image captions don't overlap neighbouring
# axes in this dense 5 x 6 layout.
plt.tight_layout()
plt.show()
profile
Scratch, Under the hood, Initial version analysis

0개의 댓글