A write-up and retrospective of developing a CNN deep learning model that diagnoses Alzheimer's disease and mild cognitive impairment (MCI) from brain MRI images.
(3) Prediction & Conclusion
Project brief (free topic):
- Define a problem to solve from a (medical/healthcare) data job position
- Select a dataset and justify the choice
- Build, train, and validate a deep learning pipeline
- Limitations and future directions
(Part1) Intro & Metadata
- Introduction
  - Position setting and project intent
  - Alzheimer's dementia and mild cognitive impairment
  - Why this study is needed
  - Goals and hypotheses
- Dataset & Metadata
  - Dataset overview
  - Data preparation
  - Metadata analysis
  - Dashboard

(Part2) Modeling
- Modeling
  - Overview and architecture
  - Data loading and splitting
  - Preprocessing layers
  - Baseline CNN model
  - Improving model performance
  - Transfer learning
  - Negative transfer
  - Hyperparameter tuning
  - Final model training and validation

(Part3) Prediction & Conclusion
- Prediction and comparative analysis
  - Test dataset
  - Predictions and confidence
  - Evaluation metrics and confusion matrix
- Conclusion
  - Summary
  - Limitations and future directions
  - Key takeaways and reflections
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
plt.rcParams['axes.unicode_minus'] = False
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.models import load_model
base_dir = "./imagedata/"
data_dir = base_dir + "Axial"
os.listdir(data_dir)
'''
['AD', 'CN', 'MCI']
'''
CLASSES = ['CN', 'MCI', 'AD']
IMG_SIZE = 256
BATCH_SIZE = 32
SEED = 42
dataset = image_dataset_from_directory(
    data_dir,
    shuffle=True,
    class_names=CLASSES,
    batch_size=BATCH_SIZE,
    image_size=(IMG_SIZE, IMG_SIZE),
    seed=SEED
)
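# Note (added): if class_names were omitted, labels would follow the
# alphanumeric directory order ['AD', 'CN', 'MCI']; passing CLASSES
# instead orders them by severity, CN -> MCI -> AD.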
print(dataset.class_names)
'''
Found 5154 files belonging to 3 classes.
['CN', 'MCI', 'AD']
'''
number_of_images = {}
for class_name in CLASSES:
    number_of_images[class_name] = len(os.listdir(data_dir+"/"+class_name))
image_count_df = pd.DataFrame(number_of_images.values(),
                              index=number_of_images.keys(),
                              columns=["Number of Images"])
display(image_count_df)
print("\nSum of Images: {}".format(image_count_df.sum().iloc[0]))
ax = sns.barplot(data=image_count_df.T)
for p in ax.patches:
    ax.annotate("%.0f" % p.get_height(),
                (p.get_x() + p.get_width()/2., p.get_height()-200),
                ha='center', va='center', fontsize=10, color='black',
                xytext=(0,10), textcoords='offset points')
plt.title('Number of Images')
plt.show()
explode = np.repeat(0.025, 3)
wedgeprops = {'width': 0.5, 'edgecolor': 'w', 'linewidth': 1.5}
textprops = {'size': 12}
icd = image_count_df.reset_index(drop=False)
icd.columns = ['Group', 'Counts']
plt.figure(figsize=(8,8))
plt.pie(icd.Counts, labels=icd.Group,
        labeldistance=0.725, startangle=0, autopct='%.1f%%',
        explode=explode, wedgeprops=wedgeprops, textprops=textprops)
plt.legend(loc='center')
plt.title('Percentage of Images')
plt.show()
for img_batch, label_batch in dataset.take(1):
    print(img_batch.shape)
    print(label_batch.numpy())
    print(label_batch.shape)
    break
'''
(32, 256, 256, 3)
[2 2 1 1 1 1 1 2 1 0 1 1 1 0 1 0 0 1 0 1 1 1 1 0 1 1 0 2 1 0 0 0]
(32,)
'''
print(np.unique(label_batch))
print(dataset.class_names)
class_dict = dict(zip(np.unique(label_batch),dataset.class_names))
class_dict
'''
[0 1 2]
['CN', 'MCI', 'AD']
{0: 'CN', 1: 'MCI', 2: 'AD'}
'''
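# Caveat (added sketch): zipping np.unique(label_batch) with the class names
# only yields the right mapping when that single batch happens to contain
# every class. A batch-independent construction uses the fact that
# image_dataset_from_directory assigns integer labels in class_names order:
class_dict = {i: name for i, name in enumerate(dataset.class_names)}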
plt.figure(figsize=(12,12))
plt.suptitle(f"Alzheimer MRI Images\n{class_dict}", fontsize=24)
for i in range(16):
    ax = plt.subplot(4,4,i+1)
    img_arr = img_batch[i].numpy()
    plt.imshow(img_arr.astype("uint8"))
    plt.axis("off")
    plt.title("Class : {}".format(CLASSES[label_batch[i].numpy()]))
def dataset_split(ds, tr=0.8, val=0.1, test=0.1, shuffle=True, buf_size=10000, SEED=42):
    ds_size = len(ds)
    if shuffle:
        ds = ds.shuffle(buf_size, seed=SEED)
    train_size = int(ds_size*tr)
    val_size = int(ds_size*val)
    train = ds.take(train_size)
    rest = ds.skip(train_size)
    val = rest.take(val_size)
    test = rest.skip(val_size)
    return train, val, test
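# Caveat (added sketch, assuming the same tf.data API): Dataset.shuffle
# reshuffles on every iteration by default, so take/skip splits are only
# stable within a single pass over the data. reshuffle_each_iteration=False
# pins the order and keeps the three splits disjoint across passes:
def dataset_split_stable(ds, tr=0.8, val=0.1, shuffle=True, buf_size=10000, seed=42):
    ds_size = len(ds)
    if shuffle:
        ds = ds.shuffle(buf_size, seed=seed, reshuffle_each_iteration=False)
    train_size = int(ds_size * tr)
    val_size = int(ds_size * val)
    rest = ds.skip(train_size)
    return ds.take(train_size), rest.take(val_size), rest.skip(val_size)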
train_ds, val_ds, test_ds = dataset_split(dataset)
print("Split전 Batched data 개수")
print(f"Dataset : {len(dataset)}")
print("\nSplit후 Batched data 개수")
print(f"Train : {len(train_ds)}")
print(f"Validation : {len(val_ds)}")
print(f"Test : {len(test_ds)}")
'''
Number of batches before split
Dataset : 162

Number of batches after split
Train : 129
Validation : 16
Test : 17
'''
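# Common follow-up (added sketch, not part of the original run): cache and
# prefetch each split so batch preparation overlaps with compute; caching
# also freezes each split's contents after the first pass, which sidesteps
# the reshuffle caveat noted above.
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(AUTOTUNE)
val_ds = val_ds.cache().prefetch(AUTOTUNE)
test_ds = test_ds.cache().prefetch(AUTOTUNE)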
model1_path = base_dir + "model1.hdf5"
model1 = load_model(filepath=model1_path)
model1.summary()
'''
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) (None, 256, 256, 3) 0
sequential_1 (Sequential) (None, 256, 256, 3) 0
conv2d (Conv2D) (None, 256, 256, 32) 896
max_pooling2d (MaxPooling2D)   (None, 128, 128, 32)      0
conv2d_1 (Conv2D) (None, 128, 128, 64) 18496
max_pooling2d_1 (MaxPooling2D) (None, 64, 64, 64)        0
flatten (Flatten) (None, 262144) 0
dense (Dense) (None, 64) 16777280
dense_1 (Dense) (None, 3) 195
=================================================================
Total params: 16,796,867
Trainable params: 16,796,867
Non-trainable params: 0
_________________________________________________________________
'''
model1.evaluate(test_ds, verbose=1)
'''
17/17 [==============================] - 7s 114ms/step - loss: 0.2975 - acc: 0.8934
[0.2974990904331207, 0.8933823704719543]
'''
final_model_path = base_dir + "best_model.hdf5"
final_model = load_model(filepath=final_model_path)
final_model.summary()
'''
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential (Sequential) (None, 256, 256, 3) 0
sequential_1 (Sequential) (None, 256, 256, 3) 0
conv2d_4 (Conv2D) (32, 256, 256, 16) 448
max_pooling2d_4 (MaxPooling2D) (32, 128, 128, 16)        0
conv2d_5 (Conv2D) (32, 128, 128, 16) 2320
max_pooling2d_5 (MaxPooling2D) (32, 64, 64, 16)          0
flatten_2 (Flatten) (32, 65536) 0
dense_4 (Dense) (32, 320) 20971840
dense_5 (Dense) (32, 3) 963
=================================================================
Total params: 20,975,571
Trainable params: 20,975,571
Non-trainable params: 0
_________________________________________________________________
'''
final_model.evaluate(test_ds, verbose=1)
'''
17/17 [==============================] - 5s 62ms/step - loss: 0.0443 - acc: 0.9853
[0.04426185041666031, 0.9852941036224365]
'''
# Collect test-set labels and each model's predictions as arrays
real_label = np.array([])
pred_model1 = np.array([])
pred_final_model = np.array([])
for imgs, labels in test_ds:
    real_label = np.concatenate([real_label, labels.numpy()])
    pred_model1 = np.concatenate([pred_model1, np.argmax(model1(imgs), axis=-1)])
    pred_final_model = np.concatenate([pred_final_model, np.argmax(final_model(imgs), axis=-1)])
# Combine into a DataFrame
label_df = pd.DataFrame(zip(real_label, pred_model1, pred_final_model),
                        columns=['True_Label', 'Pred_Vanilla', 'Pred_Tuned'])
label_df2 = label_df.replace(class_dict)
label_df2.head(10)
count_label = label_df2['True_Label'].value_counts()
label_dict = {}
for class_name in CLASSES:
    label_dict[class_name] = count_label.loc[class_name]
print(f"Label Counts : {label_dict}")
image_count_test = pd.DataFrame(label_dict.values(),
index = label_dict.keys(),
columns=["Test_Dataset"])
display(image_count_test)
print("Sum of Images : {}".format(image_count_test.sum()[0]))
# Test data count barplot
ax2 = sns.barplot(data=image_count_test.T)
for p2 in ax2.patches:
    ax2.annotate("%.0f" % p2.get_height(),
                 (p2.get_x() + p2.get_width()/2., p2.get_height()-20),
                 ha='center', va='center', fontsize=10, color='black',
                 xytext=(0,10), textcoords='offset points')
plt.title('Number of Images : Test Dataset')
plt.show()
# Test data pie plot
explode = np.repeat(0.025, 3)
wedgeprops = {'width': 0.5, 'edgecolor': 'w', 'linewidth': 1.5}
textprops = {'size': 12}
ict = image_count_test.reset_index(drop=False)
ict.columns = ['Group', 'Counts']
plt.figure(figsize=(8,8))
plt.pie(ict.Counts, labels=ict.Group,
        labeldistance=0.725, startangle=0, autopct='%.1f%%',
        explode=explode, wedgeprops=wedgeprops, textprops=textprops)
plt.legend(loc='center')
plt.title('Percentage of Images : Test Dataset')
plt.show()
for image_test, label_test in test_ds.take(1):
    print(image_test.shape)
    print(label_test.numpy())
    print(label_test.shape)
    break
'''
(32, 256, 256, 3)
[1 1 1 0 1 0 0 1 1 1 1 2 1 2 2 2 1 1 2 0 1 2 0 1 1 1 1 1 1 0 0 0]
(32,)
'''
# Predicted class and confidence for a single image
def prediction(model, img):
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    img_array = tf.expand_dims(img_array, axis=0)
    predictions = model.predict(img_array, verbose=0)
    predicted_class = CLASSES[np.argmax(predictions[0])]
    confidence = round(100*(np.max(predictions[0])), ndigits=2)
    return predicted_class, confidence
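# Quick sanity check (added example) on the first image of the test batch
# shown above; image_test/label_test come from the earlier cell.
sample_class, sample_conf = prediction(final_model, image_test[0].numpy())
print(f"Actual : {CLASSES[label_test[0].numpy()]}, "
      f"Predicted : {sample_class}, Confidence : {sample_conf}%")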
# Visualization helper
def pred_plot(model):
    plt.figure(figsize=(12,12))
    plt.suptitle(f"Alzheimer MRI Images {class_dict}", fontsize=24)
    for i in range(12):
        ax = plt.subplot(3,4,i+1)
        plt.imshow(image_test[i].numpy().astype("uint8"))
        predicted_class, confidence = prediction(model, image_test[i].numpy())
        label_true = CLASSES[label_test[i].numpy()]
        plt.title("Actual : {}\nPredicted : {}\nConfidence : {}".format(label_true, predicted_class, confidence))
        plt.axis("off")
pred_plot(model1)
pred_plot(final_model)
def report_matrix(y_real, pred_label, subtitle):
    print(classification_report(y_real, pred_label, target_names=CLASSES))
    cm = confusion_matrix(y_real, pred_label)
    plt.figure(figsize=(8,6), facecolor='w', edgecolor='k')
    ax = sns.heatmap(data=cm, cmap='Greens', annot=True, fmt='d',
                     xticklabels=CLASSES, yticklabels=CLASSES)
    plt.title(f"Confusion Matrix of Alzheimer-MCI Prediction\nModel : {subtitle}")
    plt.xlabel('Prediction')
    plt.ylabel('True(Real)')
    plt.show()
report_matrix(real_label, pred_model1, subtitle='Vanilla CNN model')
'''
              precision    recall  f1-score   support

          CN       0.96      0.88      0.92       172
         MCI       0.89      0.95      0.92       268
          AD       0.88      0.83      0.85       104

    accuracy                           0.91       544
   macro avg       0.91      0.89      0.90       544
weighted avg       0.91      0.91      0.91       544
'''
report_matrix(real_label, pred_final_model, subtitle='Hyperparameter Tuned CNN model')
'''
              precision    recall  f1-score   support

          CN       0.97      0.98      0.98       172
         MCI       0.99      0.97      0.98       268
          AD       0.97      0.99      0.98       104

    accuracy                           0.98       544
   macro avg       0.98      0.98      0.98       544
weighted avg       0.98      0.98      0.98       544
'''
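# For reference (added sketch): per-class recall can be read straight off
# the confusion matrix, whose rows are true labels in CN/MCI/AD order.
cm_final = confusion_matrix(real_label, pred_final_model)
for name, r in zip(CLASSES, cm_final.diagonal() / cm_final.sum(axis=1)):
    print(f"{name} : recall = {r:.2f}")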