SHAP

wandajeong·2023년 9월 15일
0

Machine Learning

목록 보기
15/15
post-custom-banner

SHAP Summary Plot & Feature Importance

  • Tree 계열 모델
# SHAP summary plot and mean-|SHAP| feature importance for a tree-based model.
# Assumes `XGBRegressor` (xgboost) and the splits `tr_x`, `tr_y`, `te_x` are
# already defined in the session; `te_x` must expose `.columns`.
import numpy as np
import pandas as pd
import shap

model = XGBRegressor(random_state=45)
model.fit(tr_x, tr_y)

# Explain the fitted tree ensemble on the held-out rows and plot the summary.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(te_x)
shap.summary_plot(shap_values, te_x)

# Global importance of a feature = mean absolute SHAP value over the test rows.
shap_mean = np.abs(shap_values).mean(axis=0)

# Build the frame column-wise so `shap_importance` keeps a numeric dtype;
# transposing a list of lists would turn every column into `object`.
importance_df = pd.DataFrame(
    {
        "column_name": te_x.columns.tolist(),
        "shap_importance": shap_mean,
    }
)
importance_df = importance_df.sort_values("shap_importance", ascending=False)
importance_df
  • Kernel (SVM) 모델
    • KernelExplainer는 계산 시간이 오래 걸리므로, 배경 데이터는 `shap.sample(tr_x, 100)`으로 100개만 샘플링해 사용한다.
# SHAP bar summary and mean-|SHAP| feature importance for a kernel (SVM) model.
# Assumes `svm` (presumably `sklearn.svm` -- confirm) and the splits `tr_x`,
# `tr_y`, `te_x` are already defined; `te_x` must expose `.columns`.
import numpy as np
import pandas as pd
import shap

clf = svm.SVC(kernel='linear', C=0.1)
clf.fit(tr_x, tr_y)

# KernelExplainer only needs a prediction function, so it is paired with a
# 100-row sample of the training data as background to limit the cost.
explainer = shap.KernelExplainer(clf.predict, shap.sample(tr_x, 100))
shap_values = explainer.shap_values(te_x)
shap.summary_plot(shap_values, te_x, plot_type='bar')

# Global importance of a feature = mean absolute SHAP value over the test rows.
shap_mean = np.abs(shap_values).mean(axis=0)

# Build the frame column-wise so `shap_importance` keeps a numeric dtype;
# transposing a list of lists would turn every column into `object`.
importance_df = pd.DataFrame(
    {
        "column_name": te_x.columns.tolist(),
        "shap_importance": shap_mean,
    }
)
importance_df = importance_df.sort_values("shap_importance", ascending=False)
importance_df

SHAP by KFold Cross Validation

# Out-of-fold SHAP values with K-fold cross validation: every row of X is
# explained by a model that did not see that row during fitting.
# Assumes `KFold` (presumably sklearn.model_selection -- confirm),
# `XGBRegressor` (xgboost), a feature DataFrame `X` and a target `y` that
# supports `.iloc` are already defined in the session.
import numpy as np
import pandas as pd
import shap

kf = KFold(n_splits=5, shuffle=True, random_state=0)

model = XGBRegressor(random_state=456)

# Explicit float buffer: `np.zeros_like(X)` would inherit X's dtype, so an
# all-integer X would silently truncate the SHAP values written below.
shap_values = np.zeros(X.shape, dtype=float)

for tr_idx, val_idx in kf.split(X, y):
    tr_x, tr_y = X.iloc[tr_idx], y.iloc[tr_idx]
    val_x, val_y = X.iloc[val_idx], y.iloc[val_idx]

    model.fit(tr_x, tr_y)
    # TreeExplainer is named explicitly because it provides `.shap_values`.
    explainer = shap.TreeExplainer(model)
    shap_values_val = explainer.shap_values(val_x)

    # `val_idx` holds positional row indices, so this fills the matching rows.
    shap_values[val_idx, :] = shap_values_val

shap.summary_plot(shap_values, X)

# Feature selection based on SHAP values: rank by mean absolute SHAP value.
mean_shap_values = np.mean(np.abs(shap_values), axis=0)
shap_importance = pd.DataFrame(
    list(zip(X.columns, mean_shap_values)),
    columns=['column_name', 'shap_importance'],
)
shap_importance = shap_importance.sort_values('shap_importance', ascending=False)
profile
ML/DL swimmer
post-custom-banner

0개의 댓글