File size: 2,504 Bytes
14a6b31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e487b0
14a6b31
 
5e487b0
 
14a6b31
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import tempfile

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import xgboost as xgb

# Select the non-interactive Agg backend so plotting does not fail on a
# server that has no graphical display.
matplotlib.use("Agg")

# Order of the input feature columns handed to the model.
feature_names = [
    "CT value(HU)",
    "Tumor size(cm)",
    "ctDNA",
    "CEA",
    "Location",
    "CYFRA21-1",
    "CA125",
    "LDH",
]

# Load the classifier from its JSON dump (make sure it was trained and saved
# through the sklearn API), then attach the feature names to its booster.
model = xgb.XGBClassifier()
model.load_model("./xgb_model.json")
_booster = model.get_booster()
_booster.feature_names = feature_names

# Build the SHAP explainer once at import time; it is reused for every request.
explainer = shap.Explainer(model)


# Numeric encodings for the two categorical inputs. The keys are the only
# values accepted for these fields.
_CTDNA_CODES = {"Positive": 1, "Negative": 0}
_LOCATION_CODES = {"Central": 1, "Peripheral": 0}


# Prediction callback.
def predict_probability(CT_value, Tumor_size, ctDNA, CEA, Location, CYFRA21_1, CA125, LDH):
    """Predict the positive-class probability and render a SHAP waterfall plot.

    Args:
        CT_value, Tumor_size, CEA, CYFRA21_1, CA125, LDH: numeric feature
            values; ``None`` (an empty input field) is rejected.
        ctDNA: ``"Positive"`` or ``"Negative"``.
        Location: ``"Central"`` or ``"Peripheral"``.

    Returns:
        A ``(message, image_path)`` tuple. On success ``message`` holds the
        probability of class index 1 formatted as a percentage and
        ``image_path`` is the path of a freshly written PNG. On any failure
        ``message`` describes the problem and ``image_path`` is ``None``.
    """
    # Reject incomplete input up front. Without this check an unselected
    # dropdown would be mapped to NaN and reach the model as a missing value,
    # silently yielding a probability for input the user never completed.
    numeric_inputs = {
        "CT value(HU)": CT_value,
        "Tumor size(cm)": Tumor_size,
        "CEA": CEA,
        "CYFRA21-1": CYFRA21_1,
        "CA125": CA125,
        "LDH": LDH,
    }
    missing = [label for label, value in numeric_inputs.items() if value is None]
    if missing:
        return f"预测出错: 缺少输入 {', '.join(missing)}", None

    if ctDNA not in _CTDNA_CODES:
        return f"预测出错: ctDNA 的取值无效: {ctDNA!r}", None
    if Location not in _LOCATION_CODES:
        return f"预测出错: Location 的取值无效: {Location!r}", None

    # Single-row frame whose columns follow the module-level feature order,
    # with the categorical fields already encoded as integers.
    input_data = pd.DataFrame(
        [[
            CT_value,
            Tumor_size,
            _CTDNA_CODES[ctDNA],
            CEA,
            _LOCATION_CODES[Location],
            CYFRA21_1,
            CA125,
            LDH,
        ]],
        columns=feature_names,
    )

    # Broad catch is deliberate: this is the UI boundary, and any model error
    # is reported to the user as text instead of crashing the request.
    try:
        prob = model.predict_proba(input_data)[0][1]
    except Exception as e:
        return f"预测出错: {e}", None

    try:
        shap_values = explainer(input_data)

        shap.plots.waterfall(shap_values[0], show=False)
        plt.title("SHAP Waterfall Plot")

        # Write to a unique file per call so that overlapping requests cannot
        # overwrite each other's image. The file is not deleted here because
        # the caller still has to read it.
        with tempfile.NamedTemporaryFile(
            prefix="shap_plot_", suffix=".png", delete=False
        ) as tmp:
            plot_path = tmp.name
        plt.savefig(plot_path, bbox_inches="tight", dpi=300)
    except Exception as e:
        return f"SHAP 图生成失败: {e}", None
    finally:
        # Always release the current figure, including when plotting or
        # saving raised, so figures do not pile up across requests.
        plt.close()

    return f"阳性概率: {prob:.2%}", plot_path


# Input widgets, listed in the same positional order as the parameters of
# predict_probability.
_input_widgets = [
    gr.Number(label="CT value(HU)"),
    gr.Number(label="Tumor size(cm)"),
    gr.Dropdown(choices=["Positive", "Negative"], label="ctDNA"),
    gr.Number(label="CEA (ng/mL) Normal range: 0-5"),
    # Location was changed to a Dropdown.
    gr.Dropdown(choices=["Central", "Peripheral"], label="Location"),
    gr.Number(label="CYFRA21-1 (ng/mL) Normal range: 0-5"),
    gr.Number(label="CA125 (U/mL) Normal range: 0-35"),
    gr.Number(label="LDH (U/L) Normal range: 120-250"),
]

# Output widgets: the result text and the SHAP image, passed as a file path.
_output_widgets = [
    gr.Textbox(label="Predicted results"),
    gr.Image(type="filepath", label="SHAP Waterfall Plot"),
]

demo = gr.Interface(
    fn=predict_probability,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="Prediction of lymph node metastasis",
    description="Variables were entered to obtain the predicted positive probability and SHAP interpretation map",
)

# Start the app with a public share link requested.
demo.launch(share=True)