Spaces:
Running
Running
# app.py (7-feature aligned, server-safe)
import gradio as gr | |
import pandas as pd | |
import joblib | |
import shap | |
import numpy as np | |
import matplotlib | |
matplotlib.use("Agg") # 非交互后端,服务器端更稳 | |
import matplotlib.pyplot as plt | |
import warnings | |
warnings.filterwarnings("ignore") | |
# ====== Model and background data ======
MODEL_PATH = "models/svm_pipeline.joblib"
BG_PATH = "data/bg.csv"

# The seven features the model consumes; the order must match training.
FEATURES = ["ALB", "TP", "TBA", "AST_ALT", "CREA", "PNI", "AAPR"]

# Load the fitted pipeline and the SHAP background table once, at import time.
pipeline = joblib.load(MODEL_PATH)
bg_df = pd.read_csv(BG_PATH)

# Fail fast if the background table lacks any feature column.
missing_bg = [
    feature_name
    for feature_name in FEATURES
    if feature_name not in bg_df.columns
]
if missing_bg:
    raise ValueError(f"背景集缺少列: {missing_bg}")

# Select the columns in FEATURES order (not the CSV's own column order).
bg_array = bg_df[FEATURES].to_numpy(dtype=np.float64)
# Prediction function handed to KernelExplainer: returns the positive-class
# probability, or a score when the pipeline has no probability output.
def _predict_proba_nd(x_nd: np.ndarray) -> np.ndarray:
    """Score a batch of feature rows with the module-level ``pipeline``.

    Args:
        x_nd: Array of feature rows whose columns follow ``FEATURES`` order.

    Returns:
        A NumPy array of per-row outputs, chosen by what the pipeline offers:
        the positive-class column of ``predict_proba``, else the
        ``decision_function`` output, else the ``predict`` output.
    """
    # Wrap in a DataFrame so the pipeline receives named columns.
    frame = pd.DataFrame(x_nd, columns=FEATURES)

    if hasattr(pipeline, "predict_proba"):
        proba = pipeline.predict_proba(frame)
        # The positive class is assumed to carry label 1 (adjust here if it
        # does not). When ``classes_`` is unavailable or has no label 1, fall
        # back to column 1 instead of raising IndexError.
        classes_ = getattr(pipeline, "classes_", None)
        pos_idx = 1
        if classes_ is not None:
            pos_idx = next(
                (i for i, label in enumerate(classes_) if label == 1), 1
            )
        return proba[:, pos_idx]

    if hasattr(pipeline, "decision_function"):
        return np.asarray(pipeline.decision_function(frame))

    return np.asarray(pipeline.predict(frame))
# Build the explainer a single time at import and reuse it for every request
# (steadier performance than rebuilding per call).
explainer = shap.KernelExplainer(model=_predict_proba_nd, data=bg_array)
def _render_force_plot(base_val: float, shap_1d: np.ndarray, feat_1d: np.ndarray, fnames):
    """Draw a SHAP force plot for one sample and return the matplotlib Figure.

    Uses the legacy ``shap.force_plot(..., matplotlib=True)`` interface.

    Args:
        base_val: Base (expected) value the contributions start from.
        shap_1d: Per-feature SHAP contributions; flattened to 1-D before plotting.
        feat_1d: Raw feature values for the sample, paired with ``fnames``.
        fnames: Feature names, used for labels and to pick rounding precision.

    Returns:
        The current matplotlib figure after the plot is drawn and resized.
    """
    # Drop any figures left over from earlier calls before drawing a new one.
    plt.close("all")

    # Decimal places shown per feature to keep labels short: two for
    # ALB/TP/TBA/AST_ALT, one for CREA/PNI, three for AAPR. Names missing
    # from the table are rounded to two decimals.
    decimals_by_feature = {
        "ALB": 2,
        "TP": 2,
        "TBA": 2,
        "AST_ALT": 2,
        "CREA": 1,
        "PNI": 1,
        "AAPR": 3,
    }
    raw_values = np.asarray(feat_1d, dtype=float).copy()
    display_values = []
    for value, name in zip(raw_values, fnames):
        display_values.append(np.round(value, decimals_by_feature.get(name, 2)))

    # Feature names are passed through unchanged (no abbreviation).
    contributions = np.asarray(shap_1d).reshape(-1)
    shown_features = np.asarray(display_values).reshape(-1)
    shap.force_plot(
        base_val,
        contributions,
        shown_features,
        feature_names=list(fnames),
        matplotlib=True,
        show=False,
    )

    figure = plt.gcf()
    figure.set_size_inches(14, 3.6)  # wider and a little shorter
    figure.set_dpi(180)  # higher resolution
    plt.tight_layout(pad=0.6)  # tighter layout
    return figure
def _coerce_float(x):
    """Convert a raw form value to ``float``, mapping unusable input to NaN.

    Args:
        x: Value to convert, e.g. a number, a numeric string, ``None`` or ``""``.

    Returns:
        ``float(x)`` when the conversion succeeds; ``np.nan`` when ``x`` is
        ``None``, an empty or whitespace-only string, or anything ``float()``
        rejects (so callers can treat every bad input as "missing").
    """
    if x is None:
        return np.nan
    if isinstance(x, str) and not x.strip():
        return np.nan
    try:
        return float(x)
    except (TypeError, ValueError):
        # Unparseable text or an unsupported type counts as missing rather
        # than propagating an exception to the caller.
        return np.nan
def _positive_class_index(model) -> int:
    """Return the index of class label 1 in ``model.classes_`` (1 if absent)."""
    classes_ = getattr(model, "classes_", None)
    if classes_ is None:
        return 1
    return next((i for i, label in enumerate(classes_) if label == 1), 1)


def _shap_vector_for_row(shap_out, pos_idx: int) -> np.ndarray:
    """Reduce the explainer output for a single row to a 1-D contribution vector."""
    if isinstance(shap_out, list):
        # One array per class: take the positive class (first entry if the
        # list is shorter than expected).
        chosen = shap_out[pos_idx if len(shap_out) > pos_idx else 0]
        sv = np.asarray(chosen, dtype=np.float64)
        return sv[0, :] if sv.ndim == 2 else sv.reshape(-1)

    sv = np.asarray(shap_out, dtype=np.float64)
    if sv.ndim == 3:  # (1, n_features, n_classes)
        return sv[0, :, pos_idx if sv.shape[2] > pos_idx else 0]
    if sv.ndim == 2:  # (1, n_features)
        return sv[0, :]
    return sv.reshape(-1)


def _shap_base_value(expected_value, pos_idx: int) -> float:
    """Return the scalar base value, picking the positive class when per-class."""
    if isinstance(expected_value, (list, np.ndarray)):
        ev = np.asarray(expected_value).reshape(-1)
        return float(ev[pos_idx if len(ev) > pos_idx else 0])
    return float(expected_value)


def _render_bar_fallback(sv: np.ndarray):
    """Draw a horizontal bar chart of the largest-magnitude SHAP values."""
    order = np.argsort(np.abs(sv))[::-1]
    topk = order[:min(7, sv.shape[0])]
    plt.close('all')
    fig = plt.figure(figsize=(8, 5), dpi=160)
    plt.barh(np.array(FEATURES)[topk], sv[topk])
    plt.xlabel("SHAP value")
    plt.title("Top features (single-sample contribution)")
    plt.gca().invert_yaxis()
    plt.tight_layout()
    return fig


def predict_and_explain(ALB, TP, TBA, AST_ALT, CREA, LYM, ALP, nsamples=200):
    """Predict for one patient and explain the prediction with SHAP.

    Used as the Gradio "Predict" callback. The seven raw inputs are validated,
    PNI and AAPR are derived, the module-level ``pipeline`` scores the
    resulting ``FEATURES`` row, and the shared ``explainer`` supplies
    per-feature contributions for the plot.

    Args:
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP: Raw inputs; each is passed
            through ``_coerce_float`` and must be a finite number. ``ALP``
            must be > 0 because it is the denominator of AAPR.
        nsamples: Sample budget forwarded to ``explainer.shap_values``;
            ``None`` means 200.

    Returns:
        A ``(value, figure, status_text)`` tuple. ``value`` is the prediction
        rounded to three decimals: the positive-class probability, or a
        decision score / predicted label when no probability is available.
        On a validation failure or any unexpected exception the tuple is
        ``(None, None, message)``; this function does not raise.
    """
    status = []
    try:
        # ---- 1) Coerce and validate the inputs ----
        ALB, TP, TBA, AST_ALT, CREA, LYM, ALP = (
            _coerce_float(raw) for raw in (ALB, TP, TBA, AST_ALT, CREA, LYM, ALP)
        )
        vals = [ALB, TP, TBA, AST_ALT, CREA, LYM, ALP]
        # isfinite rejects NaN (missing) and also +/-inf, which would
        # otherwise flow into the derived indices and the model.
        if not all(np.isfinite(v) for v in vals):
            return None, None, "Error: 所有输入必须为数值且不可缺失。"
        if ALP <= 0:
            return None, None, "Error: ALP 必须 > 0(用于计算 AAPR=ALB/ALP)。"

        # ---- 2) Derived indices ----
        PNI = ALB + 5.0 * LYM
        AAPR = ALB / ALP
        status.append(f"Derived: PNI={PNI:.1f}, AAPR={AAPR:.2f}")

        # ---- 3) Assemble the 7-feature row and predict ----
        x_row = np.array([[ALB, TP, TBA, AST_ALT, CREA, PNI, AAPR]], dtype=np.float64)
        x_frame = pd.DataFrame(x_row, columns=FEATURES)
        pos_idx = _positive_class_index(pipeline)
        if hasattr(pipeline, "predict_proba"):
            prob = float(pipeline.predict_proba(x_frame)[0, pos_idx])
            status.append(f"Pred prob: {prob:.3f}")
        else:
            # No probability output: report a decision score (or the label).
            if hasattr(pipeline, "decision_function"):
                score = float(pipeline.decision_function(x_frame)[0])
            else:
                score = float(pipeline.predict(x_frame)[0])
            prob = score
            status.append(f"Pred score: {score:.3f}")

        # ---- 4) SHAP values ----
        ns = int(nsamples) if nsamples is not None else 200
        shap_out = explainer.shap_values(x_row, nsamples=ns)
        sv = _shap_vector_for_row(shap_out, pos_idx)
        status.append(f"SHAP vector shape: {sv.shape}")
        base_val = _shap_base_value(explainer.expected_value, pos_idx)

        # ---- 5) Plot: force plot first, bar chart if that fails ----
        try:
            fig = _render_force_plot(base_val, sv, x_row[0, :], FEATURES)
            status.append("Rendered force plot (matplotlib).")
        except Exception as e_force:
            # Best-effort rendering: any force-plot failure degrades to bars.
            status.append(f"Force-plot failed: {repr(e_force)}; fallback=bar")
            fig = _render_bar_fallback(sv)
            status.append("Rendered bar fallback.")
        return round(float(prob), 3), fig, "\n".join(status)
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, None, f"Fatal error: {repr(e)}"
# ====== Example input (the 7 raw fields followed by nsamples) ======
example_values = [41.7, 64.9, 0.87, 0.89, 55.9, 1.95, 51.96, 200]

# ====== Gradio interface ======
with gr.Blocks() as demo:
    gr.Markdown(
        value=(
            "### Meige Risk Prediction (SVM) with SHAP Explanation\n"
            "Please enter ALB, TP, TBA, AST/ALT, CREA, LYM, and ALP. "
            "The values for PNI and AAPR will be calculated automatically.\n\n"
            "**Units**: ALB(g/L), TP(g/L), TBA(μmol/L), AST/ALT(ratio), "
            "CREA(μmol/L), LYM(×10⁹/L), ALP(U/L)."
        )
    )
    with gr.Row():
        # Left column: the seven raw inputs, the SHAP budget and the buttons.
        with gr.Column(scale=1):
            # Component order must match predict_and_explain's parameters.
            inputs = [
                gr.Number(label=caption)
                for caption in (
                    "ALB (g/L)",
                    "TP (g/L)",
                    "TBA (μmol/L)",
                    "AST/ALT",
                    "CREA (μmol/L)",
                    "LYM (×10⁹/L)",
                    "ALP (U/L)",
                )
            ]
            ns_slider = gr.Slider(
                minimum=100,
                maximum=500,
                value=200,
                step=50,
                label="SHAP nsamples",
            )
            btn_fill = gr.Button("Fill Example")
            btn_predict = gr.Button("Predict")
        # Right column: prediction, explanation plot and status log.
        with gr.Column(scale=1):
            out_prob = gr.Number(label="Predicted Probability / Score")
            out_plot = gr.Plot(label="SHAP Force Plot (fallback: bar)")
            out_log = gr.Textbox(label="Status", lines=8)

    def _fill_example():
        """Return the example values, one per input component plus the slider."""
        return tuple(example_values)

    form_components = inputs + [ns_slider]
    btn_fill.click(fn=_fill_example, outputs=form_components)
    btn_predict.click(
        fn=predict_and_explain,
        inputs=form_components,
        outputs=[out_prob, out_plot, out_log],
    )

if __name__ == "__main__":
    demo.launch()