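# Data loading is omitted here (데이터로드는 생략). The code below assumes test_x, test_y, model, tf, np, plt and datetime are already defined.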







import io
from sklearn.metrics import confusion_matrix
def plot_to_image(figure):
    """Render a Matplotlib figure to a PNG image tensor and close the figure.

    Args:
        figure: The ``matplotlib.figure.Figure`` to render.

    Returns:
        A tensor of shape ``(1, height, width, 4)`` (RGBA pixels with a
        leading batch axis), the layout ``tf.summary.image`` accepts.
    """
    with io.BytesIO() as buf:
        # Save *this* figure; plt.savefig would render whichever figure is
        # current, which need not be the one passed in.
        figure.savefig(buf, format="png")
        # Close the figure so repeated calls don't accumulate open figures.
        plt.close(figure)
        image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension expected by tf.summary.image.
    return tf.expand_dims(image, 0)
def plot_confusion_matrix(cm, class_names, cmap="Blues"):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        cm: 2-D array of counts. Rows are drawn on the y axis ("True label")
            and columns on the x axis ("Predicted label").
        class_names: One tick label per row/column of ``cm``.
        cmap: Colormap for ``plt.imshow``. The default "Blues" is light for
            low values and dark for high ones, which matches the
            white-on-high / black-on-low text colouring used for the cells.

    Returns:
        The newly created ``matplotlib.figure.Figure``.
    """
    figure = plt.figure(figsize=(8, 8))
    plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.title("Confusion matrix")
    plt.colorbar()

    # Label both axes with the class names (tick_marks was previously
    # computed but never applied, so class_names had no visible effect).
    tick_marks = np.arange(len(class_names))
    tick_labels = [str(name) for name in class_names]
    plt.xticks(tick_marks, tick_labels, rotation=45)
    plt.yticks(tick_marks, tick_labels)

    # Annotate every cell; use white text on dark (high) cells and black on
    # light (low) cells so the numbers stay readable.
    threshold = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            color = "white" if cm[i, j] > threshold else "black"
            plt.text(
                j,
                i,
                cm[i, j],
                horizontalalignment="center",
                verticalalignment="center",
                color=color,
            )

    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    # Run the layout pass after all labels exist so room is made for them.
    plt.tight_layout()
    return figure
# One TensorBoard log directory per run, stamped with the local start time,
# plus the writer that receives the confusion-matrix images.
logdir = "logs/fit/cm/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer_cm = tf.summary.create_file_writer(logdir)

# Fixed evaluation sample: the first 100 test examples. Labels become integer
# class ids via argmax over axis 1 (presumably test_y holds one-hot rows —
# confirm against the omitted data-loading code).
test_images, test_labels = test_x[:100], np.argmax(test_y[:100], axis=1)
def log_confusion_matrix(epoch, logs):
    """Log a confusion-matrix image for the fixed evaluation sample.

    The ``(epoch, logs)`` signature looks intended for a Keras
    ``LambdaCallback(on_epoch_end=...)`` hook — confirm with the caller.
    ``logs`` is accepted for that signature but not used.

    Args:
        epoch: Step number under which the image is written.
        logs: Unused.
    """
    # verbose=0: don't print an extra progress bar inside every epoch.
    test_pred_raw = model.predict(test_images, verbose=0)
    test_pred = np.argmax(test_pred_raw, axis=1)
    # Take the class list from the model's output width instead of a
    # hard-coded 10, so models with other class counts work unchanged.
    classes = np.arange(test_pred_raw.shape[1])
    cm = confusion_matrix(test_labels, test_pred, labels=classes)

    figure = plot_confusion_matrix(cm, class_names=classes)
    cm_image = plot_to_image(figure)

    with file_writer_cm.as_default():
        tf.summary.image("Confusion Matrix", cm_image, step=epoch)



