import os
import tensorflow as tf
from tensorflow import keras
data_path = # local datapath에서 추출
# Discard test set
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data(path=data_path)
# Normalize pixel values
x_train = x_train / 255.0
class myCallback(tf.keras.callbacks.Callback): # 상속 후 새 class 생성
def on_epoch_end(self, epoch, logs={}): # on_epoch_end는 epoch 끝날때마다 함수 호출
if logs.get('accuracy') > 0.99:
print("원하는 accuracy가 나왔기에 이제 그만 학습!")
self.model.stop_training = True
def train_mnist(x_train, y_train):
callbacks = myCallback()
model = tf.keras.models.Sequential([
keras.layers.Flatten(input_shape=(28,28)),
keras.layers.Dense(512, activation=tf.nn.relu),
keras.layers.Dense(10, activation=tf.nn.softmax)
]) # 모델 구조 선언 (multi-class이기에 softmax로 진행 )
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
history = model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])
return history
model_run = train_mnist(x_train, y_train)
해당 실습 코드를 올린 이유는 callback에 대해 정리하고자 올렸습니다.
이전부터 early stopping 이나 다른 방법들을 적용해봤지만, callback은 낯설게 느껴져 사용하지 않은 방법이었습니다.
특정 조건에 해당될 경우 train을 그만하고 싶다는 판단할 때 사용하는 방법이라고 이해했습니다.
위 코드에서는 정확도가 99프로 넘으면 그만 학습을 하는 것으로 표현했습니다.
이 방식은 early stopping 보다 orthogonalization 요건을 고려해 stop 조건을 세울 수 있다고 생각이 들긴 한데, 이후 실습을 통해 한 번 비교해보겠습니다.