SHAP - 코드 수정

2한나·2026년 4월 19일

수정할 내용

  • seed = 12(ExtraTrees 기준)로 고정하기
  • 백그라운드 데이터는 train data 전체를 stratified sampling해서 1000개 넣기
  • 샘플(설명 대상) 데이터로는 테스트 데이터 전체만 넣기
  • bar plot
    • 전체 데이터/클래스별 데이터에 대해 생성 ⇒ 모두 mean(|SHAP value|)로 수정
    • plot 종류: global bar plot, cohort bar plot, 클러스터링
  • beeswarm plot
    • 전체 클래스 평균 그리는 코드는 삭제, 클래스별로만 plot 그리기
  • decision plot 코드 추가
  • scatter plot 코드 추가
    • metallicity랑 다른 피처 shap 상관관계만 그리기
  • predicted class로 먼저 필터링한 뒤 → confidence로 필터링하기
    • confidence(최대 예측 확률) ≥ 0, 20, 40, 60, 80% 기준으로 각각 필터링해서 plot 그리기

수정 코드


import os
import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt
import joblib
from sklearn.model_selection import train_test_split

# ============================================================
# USER SETTINGS
# ============================================================
# Input CSV, fitted model, fitted scaler, and root folder for all saved plots.
DATA_PATH    = "/proj/home/ibs/spaceai_2025/mohan0226/SYNERGI/ImageExclusiveModel/data/preprocessed/IllustrisTNG_preprocessed.csv"
MODEL_PATH   = "/proj/home/ibs/spaceai_2025/mohan0226/SYNERGI/ImageExclusiveModel/output/model/classical_machine_learning/ExtraTrees_multiseed/seed_12/ExtraTrees_seed12.pkl"
SCALER_PATH  = "/proj/home/ibs/spaceai_2025/mohan0226/SYNERGI/ImageExclusiveModel/output/model/preprocess/StandardScaler.pkl"
SAVE_DIR     = "/proj/home/ibs/spaceai_2025/mohan0226/SYNERGI/ImageExclusiveModel/output/analysis/shap/ExtraTrees_hannah"

# Prefix used in every saved file name.
MODEL_NAME       = "ExtraTrees_multiseed_12"
# Random seed for the train/test split and the background subsample.
SEED             = 12
# Fraction of rows held out as the test set (the rows SHAP explains).
TEST_SIZE        = 0.1
# Number of stratified training rows used as the TreeExplainer background.
N_BACKGROUND     = 1000
# Label column, and extra columns excluded from the feature matrix.
TARGET_COL       = "Phase"
DROP_COLS        = ["SubHaloID", "Snapshot"]
# Class name -> class index, used to index predicted classes and the SHAP class axis.
# NOTE(review): assumes these ids match the column order of model.predict_proba
# (i.e. model.classes_) — confirm against how the model was trained.
CLASS_ID_MAP     = {"non": 0, "pre": 1, "post": 2}
METALLICITY_COL  = "Metallicity"    # TODO: change to the actual column name
# Minimum max-predicted-probability ("confidence") thresholds; one plot set per value.
CONF_THRESHOLDS  = [0.0, 0.2, 0.4, 0.6, 0.8]
# Number of top features shown in the decision plot.
DECISION_MAX     = 20
# Resolution of every saved figure.
DPI              = 300

# ============================================================
# OUTPUT FOLDER STRUCTURE (created automatically)
# SAVE_DIR/{cls_name}/conf{pct}/   (cls_name = each CLASS_ID_MAP key, or "all")
#   {tag}_bar_global.png / _bar_cls.png / _bar_cohort.png / _bar_clustering.png
#   {tag}_beeswarm.png / _decision.png / _scatter_metallicity_{feature}.png
# The "all" folders only receive _bar_global / _bar_cohort / _bar_clustering.
# ============================================================
for cls_name in list(CLASS_ID_MAP.keys()) + ["all"]:
    for t in CONF_THRESHOLDS:
        os.makedirs(os.path.join(SAVE_DIR, cls_name, f"conf{int(t*100)}"), exist_ok=True)

# ============================================================
# DATA / MODEL LOAD
# ============================================================
np.random.seed(SEED)

df    = pd.read_csv(DATA_PATH)
# Features: every column except the target and the identifier columns in DROP_COLS.
X_all = df.drop(columns=[TARGET_COL] + DROP_COLS)
y_all = df[TARGET_COL].values

# Stratified hold-out split; SHAP values are computed for the test part only.
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=TEST_SIZE, stratify=y_all, random_state=SEED
)
# Background set for TreeExplainer: stratified subsample of N_BACKGROUND training rows.
# train_test_split rejects an integer train_size >= n_samples, so fall back to the
# whole training set when it is not larger than N_BACKGROUND.
if len(X_train) > N_BACKGROUND:
    X_bg, _, y_bg, _ = train_test_split(
        X_train, y_train, train_size=N_BACKGROUND, stratify=y_train, random_state=SEED
    )
else:
    X_bg, y_bg = X_train, y_train
print(f"train={len(X_train)}, test={len(X_test)}, background={len(X_bg)}")

model = joblib.load(MODEL_PATH)

# Best effort: recover original-unit feature values for the test rows.
# `scaler` is always bound afterwards (None on failure), so later code that
# references it does not hit a NameError. The failure cause is printed instead
# of being hidden.
try:
    scaler     = joblib.load(SCALER_PATH)
    X_test_raw = pd.DataFrame(scaler.inverse_transform(X_test), columns=X_test.columns, index=X_test.index)
    print("[완료] scaler inverse_transform 적용")
except Exception as err:
    scaler     = None
    X_test_raw = X_test.copy()
    print(f"[오류] scaler 로드/역변환 실패. scaled 값 그대로 사용 ({type(err).__name__}: {err})")

# ============================================================
# SHAP COMPUTE
# ============================================================
print("[진행] SHAP values 계산 중...")
explainer = shap.TreeExplainer(model, data=X_bg)   # background = stratified training subsample
shap_raw = explainer.shap_values(X_test)

# shap may return either a list holding one (N, F) array per class, or a single
# array that is already (N, F, C). Normalise to (N, F, C) in both cases.
# NOTE(review): the non-list branch assumes a 3-D (multi-class) array — confirm.
shap_values = np.stack(shap_raw, axis=2) if isinstance(shap_raw, list) else shap_raw

# Arg-max predicted class and its probability ("confidence") for every test row.
proba = model.predict_proba(X_test)
pred_class = proba.argmax(axis=1)
max_conf = proba.max(axis=1)

# Feature dendrogram reused by every clustering bar plot.
# NOTE(review): with a label argument, shap's hclust may need xgboost and a numeric y — confirm.
feat_cluster = shap.utils.hclust(X_test, y_test)
print(f"[완료] SHAP shape: {shap_values.shape}")

# ============================================================
# PLOT HELPERS
# ============================================================
def savefig(path):
    """Tighten, save (at DPI) and close the current matplotlib figure, then log *path*."""
    current = plt.gcf()
    current.tight_layout()
    current.savefig(path, dpi=DPI)
    plt.close(current)
    print(f"[저장] {path}")

def plot_bar_global(sv, X, path):
    """Save a bar summary plot of the 2-D SHAP array *sv* (rows = samples) to *path*.

    Callers pass SHAP magnitudes already averaged over the class axis.
    """
    plt.figure()
    shap.summary_plot(shap_values=sv, features=X, plot_type="bar", show=False)
    savefig(path)

def plot_bar_cls(sv_cls, X, path):
    """Save a bar summary plot of the absolute single-class SHAP values *sv_cls* to *path*."""
    magnitudes = np.abs(sv_cls)
    plt.figure()
    shap.summary_plot(shap_values=magnitudes, features=X, plot_type="bar", show=False)
    savefig(path)

def plot_bar_cohort(sv, X, path):
    """Save one multi-bar plot comparing the SHAP values of each class in CLASS_ID_MAP.

    sv: SHAP array indexed as sv[:, :, class_id] (samples, features, classes).
    Each class becomes one cohort Explanation whose base value is the explainer's
    expected value for that class.
    """
    feature_names = list(X.columns)
    n_rows = len(sv)
    cohorts = {}
    for label, class_idx in CLASS_ID_MAP.items():
        cohorts[label] = shap.Explanation(
            values=sv[:, :, class_idx],
            base_values=np.full(n_rows, explainer.expected_value[class_idx]),
            data=X.values,
            feature_names=feature_names,
        )
    plt.figure()
    shap.plots.bar(cohorts, show=False)
    savefig(path)

def plot_bar_clustering(sv, X, path):
    """Save a bar plot of class-averaged |SHAP| with the shared feature dendrogram.

    sv: SHAP array whose last axis is the class axis; |sv| is averaged over it.
    base_values are placeholder zeros. The clustering comes from the
    module-level ``feat_cluster``.
    """
    class_mean_abs = np.abs(sv).mean(axis=2)
    explanation = shap.Explanation(
        values=class_mean_abs,
        base_values=np.zeros(len(sv)),
        data=X.values,
        feature_names=list(X.columns),
    )
    plt.figure()
    shap.plots.bar(explanation, clustering=feat_cluster, show=False)
    savefig(path)

def _add_secondary_xticks(ax, scaler, col, scaled_ticks=None):
    """Relabel x ticks as "scaled" on line 1 and "(original)" on line 2 for feature *col*.

    Args:
        ax: matplotlib Axes whose x-axis holds scaled values of *col*.
        scaler: fitted sklearn-style scaler. It needs ``feature_names_in_``,
            ``n_features_in_`` and ``inverse_transform``. If it has no
            ``feature_names_in_`` (e.g. it is None), or *col* is not among
            those names, the axis is left untouched.
        col: feature name to map back to original units.
        scaled_ticks: explicit tick positions. Defaults to the current ticks
            that lie inside the visible x-range.
    """
    names = getattr(scaler, "feature_names_in_", None)
    if names is None:
        return
    names = list(names)
    if col not in names:
        return
    col_idx = names.index(col)
    if scaled_ticks is None:
        # get_xticks() may include locator ticks beyond the view limits; setting
        # those explicitly would widen the axis, so keep only the visible ones.
        lo, hi = sorted(ax.get_xlim())
        scaled_ticks = [t for t in ax.get_xticks() if lo <= t <= hi]
    scaled_ticks = np.asarray(scaled_ticks, dtype=float)
    if scaled_ticks.size == 0:
        return
    # Invert only this column: the other (zero) columns are discarded.
    # n_features_in_ works for any fitted sklearn scaler, unlike mean_.
    dummy = np.zeros((scaled_ticks.size, scaler.n_features_in_))
    dummy[:, col_idx] = scaled_ticks
    orig_vals = scaler.inverse_transform(dummy)[:, col_idx]
    labels = [f"{s:.2f}\n({o:.2f})" for s, o in zip(scaled_ticks, orig_vals)]
    ax.set_xticks(scaled_ticks)
    ax.set_xticklabels(labels, fontsize=8)
    # ASCII-only caption: matplotlib's default font has no Hangul glyphs.
    ax.set_xlabel(f"{ax.get_xlabel()}\n(top: scaled, bottom: original)", fontsize=9)

def _add_secondary_yticks(ax, scaler, col, scaled_ticks=None):
    """Relabel y ticks as "scaled (original)" for feature *col* (not for colour bars).

    Args:
        ax: matplotlib Axes whose y-axis holds scaled values of *col*.
        scaler: fitted sklearn-style scaler. It needs ``feature_names_in_``,
            ``n_features_in_`` and ``inverse_transform``. If it has no
            ``feature_names_in_`` (e.g. it is None), or *col* is not among
            those names, the axis is left untouched.
        col: feature name to map back to original units.
        scaled_ticks: explicit tick positions. Defaults to the current ticks
            that lie inside the visible y-range.
    """
    names = getattr(scaler, "feature_names_in_", None)
    if names is None:
        return
    names = list(names)
    if col not in names:
        return
    col_idx = names.index(col)
    if scaled_ticks is None:
        # get_yticks() may include ticks beyond the view limits; setting them
        # explicitly would widen the axis, so keep only the visible ones.
        lo, hi = sorted(ax.get_ylim())
        scaled_ticks = [t for t in ax.get_yticks() if lo <= t <= hi]
    scaled_ticks = np.asarray(scaled_ticks, dtype=float)
    if scaled_ticks.size == 0:
        return
    # Invert only this column; n_features_in_ works for any fitted sklearn scaler.
    dummy = np.zeros((scaled_ticks.size, scaler.n_features_in_))
    dummy[:, col_idx] = scaled_ticks
    orig_vals = scaler.inverse_transform(dummy)[:, col_idx]
    labels = [f"{s:.2f} ({o:.2f})" for s, o in zip(scaled_ticks, orig_vals)]
    ax.set_yticks(scaled_ticks)
    ax.set_yticklabels(labels, fontsize=8)

def _add_dual_colorbar(fig, scaler, col):
    """Relabel colour-bar ticks of *fig* as "scaled" / "(original)" for feature *col*.

    The colour-bar axes is taken to be the second axes of *fig* (the first one is
    the main plot). Nothing happens if there is no second axes, or if the scaler
    has no ``feature_names_in_`` containing *col*.

    NOTE(review): not called anywhere in the visible script. If the colour bar
    encodes relative (low/high) values rather than feature units, the relabelled
    ticks are not meaningful — check before using it on a plot.
    """
    names = getattr(scaler, "feature_names_in_", None)
    if names is None:
        return
    names = list(names)
    if col not in names:
        return
    col_idx = names.index(col)
    # First axes = main plot; the next one is assumed to be the colour bar.
    cb_ax = fig.axes[1] if len(fig.axes) > 1 else None
    if cb_ax is None:
        return
    scaled_ticks = np.asarray(cb_ax.get_yticks(), dtype=float)
    # n_features_in_ works for any fitted sklearn scaler, unlike mean_.
    dummy = np.zeros((scaled_ticks.size, scaler.n_features_in_))
    dummy[:, col_idx] = scaled_ticks
    orig_vals = scaler.inverse_transform(dummy)[:, col_idx]
    labels = [f"{s:.2f}\n({o:.2f})" for s, o in zip(scaled_ticks, orig_vals)]
    cb_ax.set_yticks(scaled_ticks)
    cb_ax.set_yticklabels(labels, fontsize=7)
    # ASCII-only caption: matplotlib's default font has no Hangul glyphs.
    cb_ax.set_ylabel("Feature value\n(top: scaled, bottom: original)", fontsize=8)

def plot_beeswarm(sv_cls, X_scaled, path_prefix):
    """Save a beeswarm (dot) summary plot for one class to ``{path_prefix}.png``.

    Only scaled features are plotted. The colour bar shows each feature's
    relative (low/high) value, so an original-unit second scale does not apply.
    """
    out_path = f"{path_prefix}.png"
    plt.figure()
    shap.summary_plot(shap_values=sv_cls, features=X_scaled, plot_type="dot", show=False)
    savefig(out_path)

def plot_decision(sv_cls, X_scaled, X_raw, expected_val, path_prefix, max_obs=2000):
    """Save a SHAP decision plot for one class to ``{path_prefix}.png``.

    The x-axis is the cumulative SHAP output and the y-axis lists feature names
    only (top DECISION_MAX features), so a single scaled version is produced.

    Args:
        sv_cls: (n, F) SHAP values for the class.
        X_scaled: DataFrame (or array) with the matching n rows of scaled features.
        X_raw: unused; kept for call compatibility.
        expected_val: explainer expected value for the class.
        path_prefix: output path without extension.
        max_obs: maximum number of samples drawn. Per the shap docs,
            decision_plot refuses more than 2000 observations unless
            ignore_warnings=True. Larger subsets are randomly subsampled
            (seeded with SEED) down to this size instead of aborting the run.
            None disables subsampling.
    """
    n_obs = len(sv_cls)
    if max_obs is not None and n_obs > max_obs:
        rng = np.random.default_rng(SEED)
        keep = np.sort(rng.choice(n_obs, size=max_obs, replace=False))
        sv_cls = sv_cls[keep]
        X_scaled = X_scaled.iloc[keep] if hasattr(X_scaled, "iloc") else X_scaled[keep]
        print(f"  [샘플링] decision plot: {n_obs} -> {max_obs} samples")
    plt.figure()
    shap.decision_plot(
        expected_val, sv_cls, X_scaled,
        show=False, feature_display_range=slice(-1, -DECISION_MAX - 1, -1),
        ignore_warnings=True,  # sample count is already bounded by max_obs above
    )
    savefig(f"{path_prefix}.png")

def plot_scatter_metallicity(sv_cls, X_scaled, X_raw, prefix):
    """Save one SHAP dependence plot of METALLICITY_COL for every other feature.

    Each plot uses METALLICITY_COL as the main feature and one other feature as
    ``interaction_index``. It is saved to ``{prefix}_{other_feature}.png``.
    When the module-level scaler allows it, the x ticks also show original
    (unscaled) metallicity values.

    Args:
        sv_cls: (n, F) SHAP values for one class.
        X_scaled: DataFrame with the matching n rows of scaled features.
        X_raw: unused; kept for call compatibility.
        prefix: output path prefix.
    """
    columns = list(X_scaled.columns)
    if METALLICITY_COL not in columns:
        print(f"[오류] '{METALLICITY_COL}' 컬럼 없음. scatter 건너뜀")
        return
    met_idx = columns.index(METALLICITY_COL)
    for other_idx, col in enumerate(columns):
        if col == METALLICITY_COL:
            continue
        # Hand dependence_plot our own Axes. Without `ax` it opens a fresh figure,
        # so a pre-created plt.figure() would be left open forever (figure leak).
        fig, ax = plt.subplots(figsize=(7.5, 5))
        shap.dependence_plot(
            met_idx, sv_cls, X_scaled.values,
            feature_names=columns,
            interaction_index=other_idx, ax=ax, show=False,
        )
        # x-axis (metallicity): add original values under the scaled ticks.
        # Best effort: keep the plot even if this fails (e.g. scaler unavailable),
        # but report why instead of silently ignoring it.
        try:
            _add_secondary_xticks(ax, scaler, METALLICITY_COL)
        except Exception as err:
            print(f"  [경고] original tick 추가 실패 ({col}): {type(err).__name__}: {err}")
        savefig(f"{prefix}_{col}.png")

# ============================================================
# ALL: bar plots over every test sample (no class filter), per confidence threshold
# ============================================================
print(f"\n{'='*50}")
print("[전체] all samples bar plots")
for thresh in CONF_THRESHOLDS:
    pct = int(thresh * 100)
    # Test rows whose max predicted probability reaches the threshold.
    sel = np.flatnonzero(max_conf >= thresh)
    if sel.size == 0:
        print(f"  [건너뜀] conf>={pct}% 샘플 없음")
        continue
    print(f"  [필터] conf>={pct}%: {len(sel)} samples")

    sv_sel = shap_values[sel]
    X_sel = X_test.iloc[sel].reset_index(drop=True)
    stem = f"{os.path.join(SAVE_DIR, 'all', f'conf{pct}')}/{MODEL_NAME}_all_conf{pct}"

    plot_bar_global(np.abs(sv_sel).mean(axis=2), X_sel, f"{stem}_bar_global.png")
    plot_bar_cohort(sv_sel, X_sel, f"{stem}_bar_cohort.png")
    plot_bar_clustering(sv_sel, X_sel, f"{stem}_bar_clustering.png")

# ============================================================
# MAIN LOOP: filter by predicted class, then by confidence
# ============================================================
for cls_name, cls_id in CLASS_ID_MAP.items():
    print(f"\n{'='*50}")
    print(f"[클래스] {cls_name} (id={cls_id})")

    # Test rows whose arg-max prediction is this class, and their confidences.
    members = np.flatnonzero(pred_class == cls_id)
    if members.size == 0:
        print(f"  [건너뜀] predicted={cls_name} 샘플 없음")
        continue
    member_conf = max_conf[members]

    for thresh in CONF_THRESHOLDS:
        pct = int(thresh * 100)
        sel = members[member_conf >= thresh]
        if sel.size == 0:
            print(f"  [건너뜀] conf>={pct}% 샘플 없음")
            continue
        print(f"  [필터] conf>={pct}%: {len(sel)} samples")

        sv_sel = shap_values[sel]              # (n, F, C)
        sv_one = sv_sel[:, :, cls_id]          # (n, F): SHAP toward this class
        X_sel = X_test.iloc[sel].reset_index(drop=True)
        X_sel_raw = X_test_raw.iloc[sel].reset_index(drop=True)

        out_dir = os.path.join(SAVE_DIR, cls_name, f"conf{pct}")
        stem = f"{out_dir}/{MODEL_NAME}_{cls_name}_conf{pct}"

        plot_bar_global(np.abs(sv_sel).mean(axis=2), X_sel, f"{stem}_bar_global.png")
        plot_bar_cls(sv_one, X_sel, f"{stem}_bar_cls.png")
        plot_bar_cohort(sv_sel, X_sel, f"{stem}_bar_cohort.png")
        plot_bar_clustering(sv_sel, X_sel, f"{stem}_bar_clustering.png")
        plot_beeswarm(sv_one, X_sel, f"{stem}_beeswarm")
        plot_decision(sv_one, X_sel, X_sel_raw, explainer.expected_value[cls_id], f"{stem}_decision")
        plot_scatter_metallicity(sv_one, X_sel, X_sel_raw, f"{stem}_scatter_metallicity")

print("\n[완료] 모든 SHAP plot 저장 완료")

0개의 댓글