MNIST Classification in Tensorflow

Jiyeon Jeong·2021년 4월 26일
0

Introduction


MNIST Dataset을 이용하여 Classification 성능을 확인한다. 딥러닝 프레임워크는 Tensorflow 2.X를 사용하였다.

MNIST?



MNIST는 고등학생과 인구 조사국 직원 등의 손글씨 숫자 데이터이다. 미국 국립 표준 기술원에서 직접 제작했으며, 70000개의 숫자 영상으로 구성되어있다. 학습 데이터셋은 60000장, 테스트 데이터셋은 10000장이 된다. 각 영상의 가로와 세로 길이는 28 X 28로 이루어져 있으며 입력은 약 784픽셀이다.

모델 구조


MNIST Classification을 진행할 모델의 구조는 아래의 그림과 같다

LeNet-5를 기반으로 모델을 구성하였다. Input은 28 X 28이며, Convolution Layer가 3개, Dense Layer가 2개로 이루어져 있다. Activation 함수는 reLU를 사용하였으며, 마지막에는 Softmax를 사용해주었다. 2개의 Convolution Layer에서는 Max Pooling을 통해 영상의 크기를 줄여주었다.

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), input_shape=(28, 28, 1),
 activation='relu'))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.5))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.25))
model.add(BatchNormalization())
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Activation('relu'))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='adam', metrics=['accuracy'])

모델 구조에 관한 코드는 위와 같다.

구조를 정의한 후 모델을 저장할 폴더를 설정하고 Checkpoint를 지정해준다. 그리고 loss가 떨어지지 않으면 이른 중단을 통해 학습을 종료한다. 총 30번 반복을 하였으며, 다 돌고 나면 마지막에 정확도를 출력하고 그래프를 출력한다.

MODEL_DIR = './model/'
if not os.path.exists(MODEL_DIR):
    os.mkdir(MODEL_DIR)
modelpath="./model/{epoch:02d}-{val_loss:.4f}.hdf5"
checkpointer = 
 ModelCheckpoint(filepath=modelpath, monitor='val_loss', verbose=1,
 save_best_only=True)
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=10)

history = model.fit(X_train, Y_train, validation_data=(X_test, Y_test), epochs=30,
batch_size=200, verbose=0, callbacks=[early_stopping_callback,checkpointer])

print("\n Test Accuracy: %.4f" % (model.evaluate(X_test, Y_test)[1]))

결과



정확도는 99.48%로 나오며, Loss는 0.015정도로 확인할 수 있다. 학습을 돌리는데 Google Colab에서는 5분 정도 걸리는 듯 했다.

자세한 코드는 Github MNIST Classification을 참고

profile
기록용입니다.

0개의 댓글