tensorflow v2 (5)

km-ji · February 2, 2024


Imports

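import os
# Set this before importing TensorFlow so it takes effect: '2' hides the C++ backend's INFO and WARNING messages
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.keras import datasets, utils
from tensorflow.keras import models, layers, activations, initializers, losses, optimizers, metrics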

Data preparation

Load MNIST, which comes already split into a training set (60,000 images) and a test set (10,000 images).

(train_data, train_label), (test_data, test_label) = datasets.mnist.load_data()

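Flatten and normalize

Each 28×28 image is flattened into a 784-element vector, and pixel values are scaled from [0, 255] to [0, 1].

train_data = train_data.reshape(-1, 28 * 28) / 255.0   # (60000, 784)
test_data = test_data.reshape(-1, 28 * 28) / 255.0     # (10000, 784)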
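To see what the reshape and scaling do without downloading MNIST, here is a minimal NumPy sketch on a randomly generated stand-in batch. The `fake_images` array is hypothetical, not real MNIST data; it only mimics the uint8 dtype and 28×28 shape that `load_data()` returns.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 5 grayscale 28x28 images with uint8 pixels in [0, 255]
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Flatten each image row by row into a 784-vector; -1 lets NumPy infer the batch size
flat = fake_images.reshape(-1, 28 * 28)

# Dividing by 255.0 converts to float and maps [0, 255] onto [0, 1]
normalized = flat / 255.0

print(flat.shape)        # (5, 784)
print(normalized.dtype)  # float64
print(normalized.min() >= 0.0, normalized.max() <= 1.0)  # True True
```

Because the flattening is row-major, pixel `(row, col)` ends up at index `row * 28 + col` of the vector. Calling `.astype('float32')` before dividing would halve the memory used compared with the default float64.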

one-hot encoding

train_label = utils.to_categorical(train_label, num_classes=10)   # e.g. 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
test_label = utils.to_categorical(test_label, num_classes=10)
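One-hot encoding turns each integer label into a length-10 vector with a single 1 at the label's index. A minimal NumPy sketch of the same mapping, using row indexing into an identity matrix (an illustrative equivalent, not how `to_categorical` is implemented internally; the `one_hot` helper is hypothetical):

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Return a (len(labels), num_classes) float array with a 1 at each label's index."""
    labels = np.asarray(labels, dtype=int)
    # Row i of the identity matrix is exactly the one-hot vector for class i
    return np.eye(num_classes, dtype=np.float32)[labels]

encoded = one_hot([3, 0, 9])
print(encoded[0])              # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
print(encoded.argmax(axis=1))  # [3 0 9]
```

`argmax` recovers the original label, which is also how the predicted class is read off the model's softmax output.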

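Build the model

Two hidden blocks of Dense → BatchNormalization → ReLU → Dropout, followed by a softmax output layer. The Dense layers have no activation so that batch normalization is applied to the pre-activations.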

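model = models.Sequential()
model.add(layers.Input(shape=(28 * 28,)))   # 784 input features

# Hidden block 1: linear Dense -> BatchNormalization -> ReLU -> Dropout
model.add(layers.Dense(units=256, activation=None, kernel_initializer=initializers.he_uniform()))
model.add(layers.BatchNormalization())
model.add(layers.Activation('relu'))
model.add(layers.Dropout(rate=0.2))   # randomly zeroes 20% of activations, during training only

# Hidden block 2: same structure
model.add(layers.Dense(units=256, activation=None, kernel_initializer=initializers.he_uniform()))
model.add(layers.BatchNormalization())
model.add(layers.Activation('relu'))
model.add(layers.Dropout(rate=0.2))

# Output layer: softmax over the 10 digit classes
model.add(layers.Dense(units=10, activation='softmax'))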
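To make the order of operations inside each hidden block concrete, here is a NumPy sketch of one training-mode forward pass: He-uniform weight initialization, the linear Dense step, batch normalization with the current mini-batch statistics, ReLU, and inverted dropout. It is a simplified illustration under stated assumptions: the helpers `he_uniform` and `hidden_block_train` are hypothetical, `eps=1e-3` is assumed to match the Keras default, and Keras additionally keeps moving averages of the mean and variance for inference, which this sketch omits.

```python
import numpy as np

rng = np.random.default_rng(42)

def he_uniform(fan_in, fan_out):
    # He uniform: U(-limit, limit) with limit = sqrt(6 / fan_in)
    limit = np.sqrt(6.0 / fan_in)
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

def hidden_block_train(x, W, b, gamma, beta, rate=0.2, eps=1e-3):
    """Hypothetical helper: Dense -> BatchNorm -> ReLU -> Dropout, training mode only."""
    z = x @ W + b                              # Dense with no activation
    mean = z.mean(axis=0)                      # per-unit mean over the mini-batch
    var = z.var(axis=0)                        # per-unit variance over the mini-batch
    z = (z - mean) / np.sqrt(var + eps)        # normalize each unit
    z = gamma * z + beta                       # learnable scale and shift
    a = np.maximum(z, 0.0)                     # ReLU
    if rate > 0:
        keep = rng.random(a.shape) >= rate     # drop each unit with probability `rate`
        a = a * keep / (1.0 - rate)            # inverted dropout preserves the expected value
    return a

x = rng.random((100, 784))                     # a mini-batch of 100 flattened "images"
W = he_uniform(784, 256)
b = np.zeros(256)
gamma, beta = np.ones(256), np.zeros(256)      # Keras initializes gamma to 1 and beta to 0

out = hidden_block_train(x, W, b, gamma, beta)
print(out.shape)  # (100, 256)
```

At inference time Keras' Dropout becomes a no-op and BatchNormalization switches to its moving statistics, which is why `model.evaluate` and `model.predict` give deterministic outputs.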

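Compile

Use the Adam optimizer, categorical cross-entropy loss (which matches the one-hot labels), and categorical accuracy as the metric.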

model.compile(optimizer=optimizers.Adam(), 
              loss=losses.categorical_crossentropy, 
              metrics=[metrics.categorical_accuracy])
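For a one-hot target y and softmax output p, categorical cross-entropy per sample is -Σ_k y_k log p_k, which reduces to -log of the probability assigned to the true class; categorical accuracy checks whether argmax p equals argmax y. A NumPy sketch of both, averaged over a batch (the clipping `eps` is an assumption added to avoid log(0)):

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip so log never sees exactly 0
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # Per-sample loss: -sum_k y_k * log(p_k), then averaged over the batch
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

def categorical_accuracy(y_true, y_pred):
    # Fraction of samples whose highest-probability class matches the true class
    return float(np.mean(np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1)))

y_true = np.array([[0, 1, 0], [1, 0, 0]], dtype=float)
y_pred = np.array([[0.1, 0.8, 0.1], [0.3, 0.6, 0.1]])

print(categorical_crossentropy(y_true, y_pred))  # mean of -log(0.8) and -log(0.3), about 0.7136
print(categorical_accuracy(y_true, y_pred))      # 0.5
```

The second sample is misclassified (argmax 1 instead of 0) and also contributes most of the loss, since only 0.3 of its probability mass sits on the true class.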

Check the model summary

model.summary()
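The parameter counts that `model.summary()` reports can be checked by hand. A Dense layer has inputs × units weights plus one bias per unit; each BatchNormalization layer has 4 values per unit (trainable gamma and beta, non-trainable moving mean and moving variance); Activation and Dropout layers have no parameters. A small Python sketch of that arithmetic for this model (the helper names are hypothetical):

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a Dense layer (all trainable)."""
    return n_in * n_out + n_out

def batchnorm_params(units):
    """(trainable, non-trainable) counts: gamma/beta vs. moving mean/variance."""
    return 2 * units, 2 * units

bn_trainable, bn_frozen = batchnorm_params(256)

trainable = (dense_params(784, 256) + bn_trainable
             + dense_params(256, 256) + bn_trainable
             + dense_params(256, 10))
non_trainable = 2 * bn_frozen          # two BatchNormalization layers

print(dense_params(784, 256))          # 200960
print(trainable, non_trainable)        # 270346 1024
print(trainable + non_trainable)       # 271370
```

Almost three quarters of the parameters sit in the first Dense layer, because it connects all 784 input pixels to 256 units.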

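Train the model

Train for 15 epochs with mini-batches of 100, holding out 20% of the training data for validation.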

history = model.fit(train_data, train_label, batch_size=100, epochs=15, validation_split=0.2) 
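With `validation_split=0.2`, Keras holds out the last 20% of the samples passed to `fit` (taken before shuffling) as validation data and trains on the rest. A short sketch of the resulting sizes and the number of gradient steps per epoch; the helper `split_sizes` is hypothetical, and the exact rounding at the split boundary is an assumption here (it makes no difference for 60,000 samples):

```python
import math

def split_sizes(n_samples, validation_split, batch_size):
    """Hypothetical helper: (train_size, val_size, steps_per_epoch) for a Keras-style validation_split."""
    split_at = int(n_samples * (1.0 - validation_split))  # the first split_at samples are used for training
    train_size = split_at
    val_size = n_samples - split_at                        # the LAST samples become validation data
    steps_per_epoch = math.ceil(train_size / batch_size)   # a final partial batch still counts as a step
    return train_size, val_size, steps_per_epoch

print(split_sizes(60000, 0.2, 100))  # (48000, 12000, 480)
```

Because the split is taken from the end, data sorted by label would give a validation set containing only the last classes; shuffling the arrays before calling `fit` avoids that in such cases.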

Evaluate on the test set

result = model.evaluate(test_data, test_label, batch_size=100)

print('loss (cross-entropy) :', result[0])
print('test accuracy :', result[1])

Visualization

Plot training and validation accuracy for each epoch.

val_acc = history.history['val_categorical_accuracy']
acc = history.history['categorical_accuracy']

import numpy as np
import matplotlib.pyplot as plt

x_len = np.arange(1, len(acc) + 1)   # epochs numbered from 1
plt.plot(x_len, acc, marker='.', c='blue', label="Train-set Acc.")
plt.plot(x_len, val_acc, marker='.', c='red', label="Validation-set Acc.")

plt.legend(loc='lower right')
plt.grid()
plt.xlabel('epoch')
plt.ylabel('Accuracy')
plt.show()
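`history.history` is a plain dict mapping each metric name to a list with one value per epoch, so it can also be inspected numerically. A small sketch that reports the epoch (counted from 1) with the best validation accuracy and the train/validation gap at that epoch; the helper `best_epoch` and the sample numbers below are made up for illustration and are not results from this post:

```python
def best_epoch(history_dict, metric='categorical_accuracy'):
    """Hypothetical helper: (epoch, val_value, train_minus_val_gap) at the best validation score."""
    train = history_dict[metric]
    val = history_dict['val_' + metric]
    # Index of the highest validation value; +1 converts to 1-based epoch numbering
    i = max(range(len(val)), key=lambda k: val[k])
    return i + 1, val[i], train[i] - val[i]

# Made-up values shaped like history.history (NOT real training results)
fake_history = {
    'categorical_accuracy': [0.90, 0.95, 0.97],
    'val_categorical_accuracy': [0.94, 0.96, 0.955],
}
print(best_epoch(fake_history))
```

With the real training run, the call would be `best_epoch(history.history)`. A growing positive gap in later epochs is the usual sign of overfitting that Dropout and BatchNormalization are meant to reduce.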