ML_모델평가

이병찬·2024년 3월 15일

ML

목록 보기
8/14

회귀모델들은 실제 값과의 에러치를 가지고 계산

분류 모델은 평가 항목이 많다

이진 분류 모델의 평가

  • 참/거짓을 찾는 과정에서 답은 2가지지만 예측 결과는 4가지로 나온다

Accuracy

Precision, 정밀도

  • 모델이 True라고 분류한 것 중에서 실제 True인 것의 비율

Recall, 재현율

  • 실제 True인 것 중에서 모델이 True라고 예측한 것의 비율
  • 놓쳐서는 안되는 참 값인 데이터를 신경써야 될 때 봐야되는 지표(ex. 심장병 판단, 불량품 분류)

Fall-Out

threshold, 임계값

  • 분류 모델은 예측한 결과에 대한 확률을 반환하며 이때, 설정된 threshold/임계값에 의해 예측된 확률이 결과로 반환하는 기준 값으로 사용
  • 이진 분류에서는 주로 두 가지 클래스 중 하나를 선택하고 예측된 확률 값을 제공하며, 이 확률 값은 0과 1 사이의 범위에 존재
  • 이러한 예측된 확률을 기반으로 클래스를 할당할 때, threshold(임계값)을 사용하여 결정
  • 보통 threshold는 0.5로 설정되어 있어서, 예측된 확률 값이 0.5보다 크면 클래스 1로 분류하고, 그렇지 않으면 클래스 0으로 분류
  • 이 threshold, 임계값을 다르게 설정하면 각 평가지표의 값들이 다르게 나오게되고 이를 이용하여 모델 성능을 향상
  • 클래스 불균형이 있는 데이터셋에서는 threshold를 조정하여 재현율(recall)이나 정밀도(precision)를 최적화 : 임계값을 낮추면 더 많은 샘플이 positive로 분류되어 재현율은 증가하지만 정밀도는 감소하며 반대로 임계값을 높이면 재현율은 감소하고 정밀도는 증가

F1-Score

ROC 곡선


AUC 면적

데이터로 ROC 커브 그려보기

from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve)

from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, roc_curve)

print('accuracy :', accuracy_score(y_test, y_pred_test))
print('recall :', recall_score(y_test, y_pred_test))
print('precision :', precision_score(y_test, y_pred_test))
print('auc score :', roc_auc_score(y_test, y_pred_test))
print('f1 score :', f1_score(y_test, y_pred_test))

fpr = fall-out, tpr = recall, thresholds 확인

wine_tree.predict_proba(X_test)
# predict_proba() : 예측된 값이 속성(여기선 taste의 0 or 1)들 중 얼마나 비중을 차지하는지 보여주는 함수
# ['0'일 확률, '1'일 확률]

import matplotlib.pyplot as plt
%matplotlib inline

pred_proba = wine_tree.predict_proba(X_test)[:, 1]
                    #.predict_proba(X_test)[:, 1] : 여기선 taste의 '1'일 확률들만의 값으로 가져온다는 설정
fpr, tpr, thresholds = roc_curve(y_test, pred_proba)
# fpr = fall-out, tpr = recall

ROC 커브 그리기

plt.figure(figsize=(5, 4))
plt.plot([0, 1], [0, 1])
# plt.plot([0, 1], [0, 1]) : X축과 Y축 값이 모두 0부터 1까지인 대각선
plt.plot(fpr, tpr)
plt.grid()
plt.show()

profile
비전공 데이터 분석가 도전

0개의 댓글