여기서는 콜백 API를 사용하여 지정된 측정항목이 충족되면 학습을 중지한다.
이는 유용한 기능이므로 이 임계값에 도달하면 모든 에포크를 완료할 필요가 없다.
예를 들어, 1000 Epoch를 설정하고 원하는 정확도가 이미 Epoch 200에 도달한 경우 훈련이 자동으로 중지된다.
Fashion MNIST 데이터세트를 로드하고, 학습을 위해서 픽셀 값을 정규화한다.
import tensorflow as tf
fmnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fmnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
'''
Halts the training when the loss falls below 0.4
Args:
epoch (integer) - index of epoch (required but unused in the function definition below)
logs (dict) - metric results from the training epoch
'''
if(logs.get('loss') < 0.4):
print("\nLoss is lower than 0.4 so cancelling training!")
self.model.stop_training = True
callbacks = myCallback()
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer=tf.optimizers.Adam(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])