tensorflow 실습(MNIST)

Lee Hyun Joon ·2022년 7월 10일

ML_Basic

목록 보기
5/14

28*28 데이터셋 train 및 callback custom 실습

import os
import tensorflow as tf
from tensorflow import keras

data_path = # local datapath에서 추출 
# Discard test set
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data(path=data_path)
        
# Normalize pixel values
x_train = x_train / 255.0 

class myCallback(tf.keras.callbacks.Callback): # 상속 후 새 class 생성
        def on_epoch_end(self, epoch, logs={}): # on_epoch_end는 epoch 끝날때마다 함수 호출
            if logs.get('accuracy') > 0.99:
                print("원하는 accuracy가 나왔기에 이제 그만 학습!") 
                self.model.stop_training = True 
def train_mnist(x_train, y_train):

    callbacks = myCallback()

    model = tf.keras.models.Sequential([ 
        keras.layers.Flatten(input_shape=(28,28)),
        keras.layers.Dense(512, activation=tf.nn.relu),
        keras.layers.Dense(10, activation=tf.nn.softmax)
    ]) # 모델 구조 선언 (multi-class이기에 softmax로 진행 )
	
	
    model.compile(optimizer='adam', 
                  loss='sparse_categorical_crossentropy', 
                  metrics=['accuracy']) 

    history = model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

    return history        
    
model_run = train_mnist(x_train, y_train)

해당 실습 코드를 올린 이유는 callback에 대해 정리하고자 올렸습니다.
이전부터 early stopping 이나 다른 방법들을 적용해봤지만, callback은 낯설게 느껴져 사용하지 않은 방법이었습니다.

특정 조건에 해당될 경우 train을 그만하고 싶다는 판단할 때 사용하는 방법이라고 이해했습니다.
위 코드에서는 정확도가 99프로 넘으면 그만 학습을 하는 것으로 표현했습니다.
이 방식은 early stopping 보다 orthogonalization 요건을 고려해 stop 조건을 세울 수 있다고 생각이 들긴 한데, 이후 실습을 통해 한 번 비교해보겠습니다.

profile
우당탕탕 개발 지망생

0개의 댓글