# ============================================================
# 수정할 내용 (TODO notes)
# ============================================================
# - seed = 12 (ExtraTrees 기준)로 고정하기
# - 백그라운드 데이터: train data 전체에서 stratified sampling으로 1000개 추출
# - 샘플(설명 대상) 데이터: 테스트 데이터 전체 사용
# - bar plot
#   - 전체 데이터/클래스별 데이터에 대해 생성 ⇒ 모두 mean(|SHAP value|)로 수정
#   - plot 종류: global bar plot, cohort bar plot, 클러스터링
# - beeswarm plot
#   - 전체 클래스 평균을 그리는 코드는 삭제하고, 클래스별로만 plot 그리기
# - decision plot 코드 추가
# - scatter plot 코드 추가
#   - metallicity와 다른 피처 간의 SHAP 상관관계만 그리기
# - predicted class로 먼저 필터링한 뒤 → confidence로 필터링하기
#   - confidence: 0, 20, 40, 60, 80(%) 기준으로 필터링해서 plot 그리기
import os
import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt
import joblib
from sklearn.model_selection import train_test_split
# ============================================================
# USER SETTINGS
# ============================================================
# --- Input / output locations ---
DATA_PATH = "/proj/home/ibs/spaceai_2025/mohan0226/SYNERGI/ImageExclusiveModel/data/preprocessed/IllustrisTNG_preprocessed.csv"
MODEL_PATH = "/proj/home/ibs/spaceai_2025/mohan0226/SYNERGI/ImageExclusiveModel/output/model/classical_machine_learning/ExtraTrees_multiseed/seed_12/ExtraTrees_seed12.pkl"
SCALER_PATH = "/proj/home/ibs/spaceai_2025/mohan0226/SYNERGI/ImageExclusiveModel/output/model/preprocess/StandardScaler.pkl"
SAVE_DIR = "/proj/home/ibs/spaceai_2025/mohan0226/SYNERGI/ImageExclusiveModel/output/analysis/shap/ExtraTrees_hannah"
MODEL_NAME = "ExtraTrees_multiseed_12"  # prefix of every saved file name

# --- Reproducibility / data split ---
SEED = 12            # numpy global seed, train/test split and background sampling
TEST_SIZE = 0.1      # fraction of rows held out as the explained (test) set
N_BACKGROUND = 1000  # stratified rows of the training split used as SHAP background

# --- Dataset schema ---
TARGET_COL = "Phase"
DROP_COLS = ["SubHaloID", "Snapshot"]  # non-feature columns removed before modelling
# class name -> class index on the predict_proba / SHAP class axis
CLASS_ID_MAP = {"non": 0, "pre": 1, "post": 2}
METALLICITY_COL = "Metallicity"  # TODO: replace with the actual column name

# --- Plot options ---
CONF_THRESHOLDS = [0.0, 0.2, 0.4, 0.6, 0.8]  # minimum max-class-probability filters
DECISION_MAX = 20  # number of features shown in decision plots
DPI = 300
# ============================================================
# OUTPUT FOLDER STRUCTURE (created automatically)
#   SAVE_DIR/{cls_name}/conf{pct}/
#     {tag}_bar_global.png / _bar_cls.png / _bar_cohort.png / _bar_clustering.png
#     {tag}_beeswarm.png / _decision.png / _scatter_metallicity_{feature}.png
# ============================================================
# One sub-folder per class name plus "all", each with one folder per threshold.
_output_groups = [*CLASS_ID_MAP, "all"]
for _group in _output_groups:
    for _thr in CONF_THRESHOLDS:
        _conf_dir = os.path.join(SAVE_DIR, _group, f"conf{int(_thr*100)}")
        os.makedirs(_conf_dir, exist_ok=True)
# ============================================================
# DATA / MODEL LOAD
# ============================================================
np.random.seed(SEED)
df = pd.read_csv(DATA_PATH)
X_all = df.drop(columns=[TARGET_COL] + DROP_COLS)
y_all = df[TARGET_COL].values

# NOTE(review): presumably this reproduces the split used when the model was trained
# (same seed / size / stratification) — confirm, otherwise "test" rows may be training rows.
X_train, X_test, y_train, y_test = train_test_split(
    X_all, y_all, test_size=TEST_SIZE, stratify=y_all, random_state=SEED
)

# Stratified background sample taken from the training split.
# train_test_split rejects an integer train_size >= n_samples, and stratification
# needs at least one leftover row per class, so fall back to the whole split otherwise.
_n_classes = len(np.unique(y_train))
if len(X_train) - N_BACKGROUND >= _n_classes:
    X_bg, _, y_bg, _ = train_test_split(
        X_train, y_train, train_size=N_BACKGROUND, stratify=y_train, random_state=SEED
    )
else:
    X_bg, y_bg = X_train, y_train
print(f"train={len(X_train)}, test={len(X_test)}, background={len(X_bg)}")

# joblib uses pickle: only load trusted model/scaler files.
model = joblib.load(MODEL_PATH)

# The CSV features are the scaled model inputs; inverse-transform the test set to get
# original-unit values. On failure, `scaler` is set to None so later tick helpers can
# skip the original-scale labels instead of hitting an undefined name.
scaler = None
try:
    scaler = joblib.load(SCALER_PATH)
    X_test_raw = pd.DataFrame(
        scaler.inverse_transform(X_test), columns=X_test.columns, index=X_test.index
    )
    print("[완료] scaler inverse_transform 적용")
except Exception as err:  # best effort: the plots still work on scaled values
    scaler = None
    X_test_raw = X_test.copy()
    print(f"[오류] scaler 로드 실패. scaled 값 그대로 사용 ({type(err).__name__}: {err})")
# ============================================================
# SHAP COMPUTE
# ============================================================
print("[진행] SHAP values 계산 중...")
# A background dataset is supplied, so attributions are relative to X_bg.
explainer = shap.TreeExplainer(model, data=X_bg)
shap_raw = explainer.shap_values(X_test)
# Depending on the shap version, multi-class output is either a list of C arrays of
# shape (N, F) or one (N, F, C) ndarray; normalise to (N, F, C).
if isinstance(shap_raw, list):
    shap_values = np.stack(shap_raw, axis=2)
else:
    shap_values = np.asarray(shap_raw)

proba = model.predict_proba(X_test)
pred_class = np.argmax(proba, axis=1)  # predict_proba column index == SHAP class index
max_conf = np.max(proba, axis=1)       # confidence = highest class probability

# Fail fast with a clear message rather than an opaque IndexError in the plot loops.
_expected_shape = (*X_test.shape, proba.shape[1])
if shap_values.shape != _expected_shape:
    raise ValueError(
        f"unexpected SHAP output shape {shap_values.shape}; expected (N, F, C) = {_expected_shape}"
    )
if max(CLASS_ID_MAP.values()) >= proba.shape[1]:
    raise ValueError(
        f"CLASS_ID_MAP {CLASS_ID_MAP} refers to class ids beyond the model's {proba.shape[1]} outputs"
    )
# NOTE(review): CLASS_ID_MAP assumes model.classes_ is ordered (non, pre, post) — check this print.
print(f"[정보] model.classes_ = {getattr(model, 'classes_', None)}")

# Feature clustering shared by every *_bar_clustering plot. With y given, shap's default
# metric is XGBoost-based; fall back to correlation distance if that is unavailable.
try:
    feat_cluster = shap.utils.hclust(X_test, y_test)
except Exception as err:
    print(f"[경고] hclust(X, y) 실패 ({type(err).__name__}: {err}). correlation 거리로 대체")
    feat_cluster = shap.utils.hclust(X_test, metric="correlation")
print(f"[완료] SHAP shape: {shap_values.shape}")
# ============================================================
# PLOT HELPERS
# ============================================================
def savefig(path):
    """Tight-layout, save (at DPI) and close the current matplotlib figure, then log the path."""
    current = plt.gcf()
    current.tight_layout()
    current.savefig(path, dpi=DPI)
    plt.close(current)
    print(f"[저장] {path}")
def plot_bar_global(sv, X, path):
    """Save a mean(|SHAP|)-per-feature bar chart of a 2-D (samples, features) matrix.

    Callers in this script pass the class-averaged |SHAP| matrix.
    """
    plt.figure()
    shap.summary_plot(sv, features=X, plot_type="bar", show=False)
    savefig(path)
def plot_bar_cls(sv_cls, X, path):
    """Save a mean(|SHAP|)-per-feature bar chart for one class's (samples, features) SHAP values."""
    magnitudes = np.abs(sv_cls)
    plt.figure()
    shap.summary_plot(magnitudes, features=X, plot_type="bar", show=False)
    savefig(path)
def plot_bar_cohort(sv, X, path):
    """Save a grouped bar chart comparing feature importance across classes.

    Each "cohort" is one class's slice ``sv[:, :, class_id]`` over the *same* samples,
    with that class's explainer base value, so bars are per-class mean(|SHAP|).

    Args:
        sv: (n_samples, n_features, n_classes) SHAP values.
        X: feature DataFrame aligned row-wise with ``sv``.
        path: output PNG path.
    """
    n_rows = len(sv)
    cohorts = {}
    for label, class_id in CLASS_ID_MAP.items():
        base = np.full(n_rows, explainer.expected_value[class_id])
        cohorts[label] = shap.Explanation(
            values=sv[:, :, class_id],
            base_values=base,
            data=X.values,
            feature_names=list(X.columns),
        )
    plt.figure()
    shap.plots.bar(cohorts, show=False)
    savefig(path)
def plot_bar_clustering(sv, X, path):
    """Save a bar chart of class-averaged |SHAP| annotated with the ``feat_cluster`` dendrogram.

    Args:
        sv: (n_samples, n_features, n_classes) SHAP values.
        X: feature DataFrame aligned row-wise with ``sv``.
        path: output PNG path.
    """
    class_avg_abs = np.abs(sv).mean(axis=2)
    explanation = shap.Explanation(
        values=class_avg_abs,
        base_values=np.zeros(len(sv)),  # base values are irrelevant for a bar chart
        data=X.values,
        feature_names=list(X.columns),
    )
    plt.figure()
    shap.plots.bar(explanation, clustering=feat_cluster, show=False)
    savefig(path)
def _add_secondary_xticks(ax, scaler, col, scaled_ticks=None):
    """Relabel x ticks as ``"scaled\\n(original)"`` using ``scaler``'s inverse transform.

    When ``scaled_ticks`` is None, only the axis ticks inside the current x limits are
    used: ``get_xticks`` usually returns ticks beyond the view, and ``set_xticks``
    expands the view limits to include every tick it is given, which would pad the plot.

    Args:
        ax: matplotlib Axes whose x axis shows scaled values of ``col``.
        scaler: fitted per-column scaler exposing ``feature_names_in_`` and
            ``inverse_transform``. ``None``, a scaler without feature names, or a
            ``col`` it does not know all make this a no-op.
        col: feature name to convert back to original units.
        scaled_ticks: explicit tick positions in scaled units.
    """
    names = list(getattr(scaler, "feature_names_in_", []))
    if col not in names:
        return
    col_idx = names.index(col)
    if scaled_ticks is None:
        lo, hi = sorted(ax.get_xlim())
        scaled_ticks = [t for t in ax.get_xticks() if lo <= t <= hi]
    scaled_ticks = np.asarray(scaled_ticks, dtype=float)
    if scaled_ticks.size == 0:
        return
    # Invert a single column: put the ticks in that column of an otherwise-zero matrix
    # (width from feature_names_in_, which also works when scaler.mean_ is None).
    dummy = np.zeros((scaled_ticks.size, len(names)))
    dummy[:, col_idx] = scaled_ticks
    orig_vals = np.asarray(scaler.inverse_transform(dummy))[:, col_idx]
    labels = [f"{s:.2f}\n({o:.2f})" for s, o in zip(scaled_ticks, orig_vals)]
    ax.set_xticks(scaled_ticks)
    ax.set_xticklabels(labels, fontsize=8)
    # English text: matplotlib's default font has no Hangul glyphs.
    ax.set_xlabel(f"{ax.get_xlabel()}\n(top: scaled, bottom: original)", fontsize=9)
def _add_secondary_yticks(ax, scaler, col, scaled_ticks=None):
    """Relabel y ticks as ``"scaled (original)"`` using ``scaler``'s inverse transform.

    Intended for a regular y axis, not a colour bar. When ``scaled_ticks`` is None, only
    ticks inside the current y limits are used, because ``set_yticks`` expands the view
    limits to include every tick it is given.

    Args:
        ax: matplotlib Axes whose y axis shows scaled values of ``col``.
        scaler: fitted per-column scaler exposing ``feature_names_in_`` and
            ``inverse_transform``. ``None``, a scaler without feature names, or a
            ``col`` it does not know all make this a no-op.
        col: feature name to convert back to original units.
        scaled_ticks: explicit tick positions in scaled units.
    """
    names = list(getattr(scaler, "feature_names_in_", []))
    if col not in names:
        return
    col_idx = names.index(col)
    if scaled_ticks is None:
        lo, hi = sorted(ax.get_ylim())
        scaled_ticks = [t for t in ax.get_yticks() if lo <= t <= hi]
    scaled_ticks = np.asarray(scaled_ticks, dtype=float)
    if scaled_ticks.size == 0:
        return
    # Invert a single column via an otherwise-zero matrix as wide as the scaler's input.
    dummy = np.zeros((scaled_ticks.size, len(names)))
    dummy[:, col_idx] = scaled_ticks
    orig_vals = np.asarray(scaler.inverse_transform(dummy))[:, col_idx]
    labels = [f"{s:.2f} ({o:.2f})" for s, o in zip(scaled_ticks, orig_vals)]
    ax.set_yticks(scaled_ticks)
    ax.set_yticklabels(labels, fontsize=8)
def _add_dual_colorbar(fig, scaler, col):
    """Relabel a colour bar's ticks as ``"scaled\\n(original)"`` for feature ``col``.

    The colour bar is taken to be the first axes in ``fig.axes`` other than
    ``fig.axes[0]``. plot_beeswarm deliberately does not use this: a beeswarm colour
    bar encodes each feature's relative value (low/high), not one feature's scale.
    Only ticks inside the colour bar's current y limits are relabelled, because
    ``set_yticks`` expands the view limits to include every tick it is given.

    Args:
        fig: matplotlib Figure containing the main axes and a colour bar axes.
        scaler: fitted per-column scaler exposing ``feature_names_in_`` and
            ``inverse_transform``. ``None``, a scaler without feature names, or a
            ``col`` it does not know all make this a no-op.
        col: feature name whose values the colour bar shows (in scaled units).
    """
    names = list(getattr(scaler, "feature_names_in_", []))
    if col not in names:
        return
    col_idx = names.index(col)
    cb_ax = next((a for a in fig.axes if a is not fig.axes[0]), None)
    if cb_ax is None:
        return
    lo, hi = sorted(cb_ax.get_ylim())
    scaled_ticks = np.array([t for t in cb_ax.get_yticks() if lo <= t <= hi], dtype=float)
    if scaled_ticks.size == 0:
        return
    # Invert a single column via an otherwise-zero matrix as wide as the scaler's input.
    dummy = np.zeros((scaled_ticks.size, len(names)))
    dummy[:, col_idx] = scaled_ticks
    orig_vals = np.asarray(scaler.inverse_transform(dummy))[:, col_idx]
    labels = [f"{s:.2f}\n({o:.2f})" for s, o in zip(scaled_ticks, orig_vals)]
    cb_ax.set_yticks(scaled_ticks)
    cb_ax.set_yticklabels(labels, fontsize=7)
    # English text: matplotlib's default font has no Hangul glyphs.
    cb_ax.set_ylabel("Feature value\n(top: scaled, bottom: original)", fontsize=8)
def plot_beeswarm(sv_cls, X_scaled, path_prefix):
    """Save a per-class beeswarm (dot summary) plot to ``{path_prefix}.png``.

    The colour bar only encodes each feature's relative value (low -> high), so a
    scaled/original dual scale cannot be applied; the plot uses the scaled features.
    """
    target = f"{path_prefix}.png"
    plt.figure()
    shap.summary_plot(sv_cls, features=X_scaled, plot_type="dot", show=False)
    savefig(target)
def plot_decision(sv_cls, X_scaled, X_raw, expected_val, path_prefix,
                  max_samples=2000, random_state=SEED):
    """Save a decision plot for one class to ``{path_prefix}.png``.

    The x axis is the cumulative SHAP value and the y axis lists feature names, so
    feature values never appear on an axis; only the scaled features are used.

    Args:
        sv_cls: (n_samples, n_features) SHAP values for one class.
        X_scaled: features aligned row-wise with ``sv_cls`` (DataFrame or array).
        X_raw: unused; kept for interface compatibility.
        expected_val: explainer base value for the class.
        path_prefix: output path without extension.
        max_samples: ``shap.decision_plot`` raises RuntimeError for more than 2000
            observations (unless ``ignore_warnings=True``), which would abort the whole
            run. Larger inputs are randomly subsampled to this many rows.
        random_state: seed for that subsample.
    """
    n_rows = len(sv_cls)
    if n_rows > max_samples:
        keep = np.sort(
            np.random.default_rng(random_state).choice(n_rows, size=max_samples, replace=False)
        )
        sv_cls = np.asarray(sv_cls)[keep]
        X_scaled = X_scaled.iloc[keep] if hasattr(X_scaled, "iloc") else np.asarray(X_scaled)[keep]
        print(f"  [정보] decision plot: {n_rows} -> {max_samples} samples (random subsample)")
    plt.figure()
    shap.decision_plot(
        expected_val, sv_cls, X_scaled,
        show=False, feature_display_range=slice(-1, -DECISION_MAX - 1, -1),
    )
    savefig(f"{path_prefix}.png")
def plot_scatter_metallicity(sv_cls, X_scaled, X_raw, prefix):
    """Save one SHAP dependence plot of METALLICITY_COL per other feature.

    Each figure ``{prefix}_{feature}.png`` plots metallicity (scaled) on x, its SHAP
    value for the class on y, and colours points by the other feature's value. When
    the module-level ``scaler`` is usable, metallicity ticks also show original units.

    Args:
        sv_cls: (n_samples, n_features) SHAP values for one class.
        X_scaled: feature DataFrame aligned row-wise with ``sv_cls``.
        X_raw: unused; kept for interface compatibility.
        prefix: output path prefix.
    """
    columns = list(X_scaled.columns)
    if METALLICITY_COL not in columns:
        print(f"[오류] '{METALLICITY_COL}' 컬럼 없음. scatter 건너뜀")
        return
    met_idx = columns.index(METALLICITY_COL)
    values = X_scaled.values
    for other_idx, col in enumerate(columns):
        if col == METALLICITY_COL:
            continue
        # dependence_plot creates its own figure when ax is None, so a separate
        # plt.figure() would leak an empty, never-closed figure on every call.
        _, ax = plt.subplots(figsize=(7.5, 5))
        shap.dependence_plot(
            met_idx, sv_cls, values,
            feature_names=columns,
            interaction_index=other_idx, ax=ax, show=False,
        )
        # x axis (metallicity): add original-unit values under the scaled ticks.
        try:
            _add_secondary_xticks(ax, scaler, METALLICITY_COL)
        except NameError:
            pass  # no scaler was loaded: keep scaled ticks only
        except Exception as err:  # best effort: tick labels must not abort the run
            print(f"  [경고] original-scale tick 추가 실패 ({col}): {err}")
        savefig(f"{prefix}_{col}.png")
# ============================================================
# ALL: bar plots over every test sample (no class filter), per confidence threshold
# ============================================================
print(f"\n{'='*50}")
print("[전체] all samples bar plots")
for thresh in CONF_THRESHOLDS:
    pct = int(thresh * 100)
    rows = np.flatnonzero(max_conf >= thresh)
    if rows.size == 0:
        print(f" [건너뜀] conf>={pct}% 샘플 없음")
        continue
    print(f" [필터] conf>={pct}%: {len(rows)} samples")
    out_dir = os.path.join(SAVE_DIR, "all", f"conf{pct}")
    stem = f"{out_dir}/{MODEL_NAME}_all_conf{pct}"
    sv_subset = shap_values[rows]                        # (n, F, C)
    X_subset = X_test.iloc[rows].reset_index(drop=True)
    plot_bar_global(np.abs(sv_subset).mean(axis=2), X_subset, f"{stem}_bar_global.png")
    plot_bar_cohort(sv_subset, X_subset, f"{stem}_bar_cohort.png")
    plot_bar_clustering(sv_subset, X_subset, f"{stem}_bar_clustering.png")
# ============================================================
# MAIN LOOP: filter by predicted class first, then by confidence
# ============================================================
for cls_name, cls_id in CLASS_ID_MAP.items():
    print(f"\n{'='*50}")
    print(f"[클래스] {cls_name} (id={cls_id})")
    in_class = pred_class == cls_id
    if not in_class.any():
        print(f" [건너뜀] predicted={cls_name} 샘플 없음")
        continue
    for thresh in CONF_THRESHOLDS:
        pct = int(thresh * 100)
        idx = np.flatnonzero(in_class & (max_conf >= thresh))
        if idx.size == 0:
            print(f" [건너뜀] conf>={pct}% 샘플 없음")
            continue
        print(f" [필터] conf>={pct}%: {len(idx)} samples")
        sv = shap_values[idx]        # (n, F, C)
        sv_cls = sv[..., cls_id]     # (n, F): this class's SHAP values
        X_sub = X_test.iloc[idx].reset_index(drop=True)
        X_raw = X_test_raw.iloc[idx].reset_index(drop=True)
        out_dir = os.path.join(SAVE_DIR, cls_name, f"conf{pct}")
        stem = f"{out_dir}/{MODEL_NAME}_{cls_name}_conf{pct}"
        plot_bar_global(np.abs(sv).mean(axis=2), X_sub, f"{stem}_bar_global.png")
        plot_bar_cls(sv_cls, X_sub, f"{stem}_bar_cls.png")
        plot_bar_cohort(sv, X_sub, f"{stem}_bar_cohort.png")
        plot_bar_clustering(sv, X_sub, f"{stem}_bar_clustering.png")
        plot_beeswarm(sv_cls, X_sub, f"{stem}_beeswarm")
        plot_decision(sv_cls, X_sub, X_raw, explainer.expected_value[cls_id], f"{stem}_decision")
        plot_scatter_metallicity(sv_cls, X_sub, X_raw, f"{stem}_scatter_metallicity")
print("\n[완료] 모든 SHAP plot 저장 완료")