Machine Learning Evaluation Tools

Jewook·2022년 9월 12일


For skewed (imbalanced) datasets, accuracy alone can be misleading, so we need additional techniques beyond accuracy to evaluate a learning model.

Confusion Matrix

from sklearn.metrics import confusion_matrix

# Confusion matrix from the true labels (y_test) and predicted labels (pred).
# For binary labels the result is a 2x2 ndarray laid out as
# (rows = actual class, columns = predicted class):
#   TN  FP
#   FN  TP
confusion_matrix(y_true=y_test, y_pred=pred)

Precision and Recall

precision = TP / (FP + TP)
recall = TP / (FN + TP)

정밀도(Precision): P로 예측한 것 중에 실제 P인 것의 비율
재현율(Recall): 실제 P인 것 중에 P로 예측한 것의 비율

from sklearn.metrics import precision_score, recall_score

# Precision = TP / (FP + TP): share of predicted positives that are truly positive.
precision = precision_score(y_test, pred)
# Recall = TP / (FN + TP): share of actual positives that were predicted positive.
# (Previously computed with precision_score by mistake.)
recall = recall_score(y_test, pred)


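분류(classification) 모델의 predict_proba()는 결과값(레이블)이 아닌 각 클래스에 속할 확률을 리턴한다. 이 확률을 임계값(threshold) 기준으로 0/1로 변환하면, 임계값에 따라 정밀도와 재현율이 어떻게 달라지는지 확인할 수 있다.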


from sklearn.preprocessing import Binarizer

def get_eval_by_threshold(y_test, pred_proba_c1, thresholds):
    """Print evaluation metrics for each decision threshold in ``thresholds``.

    Args:
        y_test: True binary labels.
        pred_proba_c1: Predicted scores/probabilities for the positive class
            (class 1). Accepted either as a 1-D vector or as a single column.
        thresholds: Iterable of cut-off values to try.
    """
    import numpy as np

    # Binarizer operates on 2-D input, so present the scores as one column.
    proba_2d = np.asarray(pred_proba_c1).reshape(-1, 1)
    proba_1d = proba_2d.ravel()
    for custom_threshold in thresholds:
        # Scores strictly greater than the threshold become 1, the rest 0.
        binarizer = Binarizer(threshold=custom_threshold).fit(proba_2d)
        custom_predict = binarizer.transform(proba_2d).ravel()
        print(f'Threshold : {custom_threshold}')
        clf_eval(y_test, custom_predict, proba_1d)
def clf_eval(y_test, pred=None, pred_proba=None):
    """Print the confusion matrix and summary metrics for binary predictions.

    Args:
        y_test: True binary labels.
        pred: Predicted binary labels (required in practice despite the
            ``None`` default; every metric below needs it).
        pred_proba: Predicted positive-class scores, used only for ROC AUC.
            When omitted, AUC is skipped instead of raising.
    """
    # Local import: accuracy_score / f1_score / roc_auc_score are not imported
    # before this point in the file, and sklearn.metrics is already a dependency.
    from sklearn.metrics import (
        accuracy_score,
        confusion_matrix,
        f1_score,
        precision_score,
        recall_score,
        roc_auc_score,
    )

    confusion = confusion_matrix(y_test, pred)
    accuracy = accuracy_score(y_test, pred)
    precision = precision_score(y_test, pred)
    recall = recall_score(y_test, pred)
    f1 = f1_score(y_test, pred)

    print('Confusion matrix')
    print(confusion)

    summary = (f'Accuracy : {accuracy:.4f}, Precision : {precision:.4f}, '
               f'Recall : {recall:.4f}, F1 : {f1:.4f}')
    if pred_proba is not None:
        roc_auc = roc_auc_score(y_test, pred_proba)
        summary += f', AUC : {roc_auc:.4f}'
    print(summary)

Precision Recall Curve

from sklearn.metrics import precision_recall_curve
# precision_recall_curve 시각화

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
%matplotlib inline

def precision_recall_curve_plot(y_test, pred_proba_c1):
    """Plot precision and recall against the decision threshold.

    Args:
        y_test: True binary labels.
        pred_proba_c1: Predicted scores/probabilities for the positive class.
    """
    import numpy as np

    precisions, recalls, thresholds = precision_recall_curve(y_test, pred_proba_c1)

    plt.figure(figsize=(8, 6))
    # precision_recall_curve returns one more precision/recall value than
    # thresholds, so trim both to the threshold count to align the x axis.
    threshold_boundary = thresholds.shape[0]
    plt.plot(thresholds, precisions[0:threshold_boundary], linestyle='--', label='precision')
    plt.plot(thresholds, recalls[0:threshold_boundary], label='recall')

    # Put x-axis ticks every 0.1 across the current range (was np.arrange,
    # which does not exist).
    start, end = plt.xlim()
    plt.xticks(np.round(np.arange(start, end, 0.1), 2))

    plt.xlabel('Threshold value')
    plt.ylabel('Precision and Recall value')
    plt.legend()
    plt.grid()

ROC Curve

from sklearn.metrics import roc_curve, roc_auc_score

def roc_curve_plot(y_test, pred_proba_c1):
    """Plot the ROC curve against a random-guess diagonal, with AUC in the title.

    Args:
        y_test: True binary labels.
        pred_proba_c1: Predicted scores/probabilities for the positive class.
    """
    import numpy as np

    fprs, tprs, _thresholds = roc_curve(y_test, pred_proba_c1)
    roc_auc = roc_auc_score(y_test, pred_proba_c1)

    plt.plot(fprs, tprs, label='ROC')
    # The diagonal is the curve of a classifier that guesses at random.
    plt.plot([0, 1], [0, 1], 'k--', label='Random')

    # Fix the axes to [0, 1] BEFORE placing ticks; reading plt.xlim() first
    # would use matplotlib's autoscaled (margin-padded) limits instead.
    plt.xlim(0, 1); plt.ylim(0, 1)
    plt.xticks(np.round(np.linspace(0, 1, 11), 2))

    plt.title(f'ROC_Curve (AUC : {roc_auc:.4f})')
    # FPR = FP / (FP + TN) = 1 - Specificity (not 1 - Sensitivity).
    plt.xlabel('FPR(1 - Specificity )'); plt.ylabel('TPR( Recall )')
    plt.legend()


