# ML_predict_LNM / app.py
# Last commit: 1cee686 "Update app.py" by Ashleygxr (verified)
import numpy as np
import pandas as pd
import xgboost as xgb
import shap
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("Agg") # 防止服务器无图形界面时报错
# Model input features. Order matters: it defines the column order of the
# DataFrame built for prediction and is assigned to the booster's feature
# names, so it presumably mirrors the column order used during training.
feature_names = [
    "CT value(HU)", "Tumor size(mm)", "ctDNA", "CEA", "Location",
    "CYFRA21-1", "AAPR", "CA125", "LDH", "ANC",
    "ALT", "GGT", "CREA", "UREA", "Pleural indentation",
]
# Load the trained classifier from its JSON dump. Per the original author's
# note, the file must have been trained and saved through the sklearn-style
# XGBClassifier API.
model = xgb.XGBClassifier()
model.load_model(fname="./xgb_model.json")
# Attach feature_names to the underlying booster so its feature names match
# the DataFrame columns built in predict_probability.
model.get_booster().feature_names = feature_names
# Build the SHAP explainer once at import time; it is reused for every request.
explainer = shap.Explainer(model=model)
# Prediction function
def predict_probability(
    CT_value,
    Tumor_size,
    ctDNA,
    CEA,
    Location,
    CYFRA21_1,
    AAPR,
    CA125,
    LDH,
    ANC,
    ALT,
    GGT,
    CREA,
    UREA,
    Pleural_indentation,
):
    """Predict the positive probability and render a SHAP waterfall plot.

    The arguments are the raw values from the Gradio form, in the same order as
    ``feature_names``. ``Location`` is expected to be "Central"/"Peripheral";
    ``ctDNA`` and ``Pleural_indentation`` are expected to be
    "Positive"/"Negative". These three are encoded as 1/0 before prediction.

    Returns:
        tuple: ``(message, image_path)``. ``message`` is the result text (or an
        error description). ``image_path`` is the saved SHAP plot path, or
        ``None`` when validation, prediction or plotting fails.
    """
    input_data = pd.DataFrame(
        [
            [
                CT_value,
                Tumor_size,
                ctDNA,
                CEA,
                Location,
                CYFRA21_1,
                AAPR,
                CA125,
                LDH,
                ANC,
                ALT,
                GGT,
                CREA,
                UREA,
                Pleural_indentation,
            ]
        ],
        columns=feature_names,
    )

    # Encode the categorical dropdown inputs as 0/1 numbers for the model.
    positive_negative = {"Positive": 1, "Negative": 0}
    input_data["Location"] = input_data["Location"].map({"Central": 1, "Peripheral": 0})
    input_data["ctDNA"] = input_data["ctDNA"].map(positive_negative)
    input_data["Pleural indentation"] = input_data["Pleural indentation"].map(positive_negative)

    # Blank numeric fields arrive as None and unselected/unknown dropdown values
    # become NaN after .map(); report them by name instead of letting XGBoost
    # fail with an opaque dtype error.
    missing = [col for col in feature_names if input_data[col].isna().any()]
    if missing:
        return f"请填写以下变量: {', '.join(missing)}", None

    # Prediction: probability of the positive class (column 1).
    try:
        prob = model.predict_proba(input_data)[0][1]
    except Exception as e:
        return f"预测出错: {e}", None

    # SHAP explanation and waterfall plot.
    try:
        shap_values = explainer(input_data)
        shap.plots.waterfall(shap_values[0], show=False)
        plt.title("SHAP Waterfall Plot")
        plt.savefig("shap_plot.png", bbox_inches="tight", dpi=300)
    except Exception as e:
        # The prediction itself succeeded, so still report it to the user.
        return f"阳性概率: {prob:.2%}（SHAP 图生成失败: {e}）", None
    finally:
        # Always release the pyplot figure. Previously a failure here left it
        # open, leaking memory and possibly being reused as the current figure
        # by the next plot.
        plt.close()
    return f"阳性概率: {prob:.2%}", "shap_plot.png"
# Gradio form. The input components are listed in the same order as the
# parameters of predict_probability (and as feature_names). The two outputs
# receive the (message, image_path) pair that predict_probability returns.
demo = gr.Interface(
    predict_probability,
    [
        *(
            gr.Number(label=text)
            for text in ("CT value(HU)", "Tumor size(mm)")
        ),
        gr.Dropdown(choices=["Positive", "Negative"], label="ctDNA"),
        gr.Number(label="CEA (ng/mL) Normal range: 0-5"),
        # Categorical input; predict_probability encodes Central=1, Peripheral=0.
        gr.Dropdown(choices=["Central", "Peripheral"], label="Location"),
        *(
            gr.Number(label=text)
            for text in (
                "CYFRA21-1 (ng/mL) Normal range: 0-5",
                # NOTE(review): unit and range are identical to the CEA and
                # CYFRA21-1 labels. If AAPR is the albumin-to-alkaline-phosphatase
                # ratio, it is unitless, so confirm this label is correct.
                "AAPR (ng/mL) Normal range: 0-5",
                "CA125 (U/mL) Normal range: 0-35",
                "LDH (U/L) Normal range: 120-250",
                "ANC (10^9/L) Normal range: 1.8-6.3",
                "ALT (U/L) Normal range: 7-40",
                "GGT (U/L) Normal range: 7-45",
                "CREA (μmol/L) Normal range: 41-81",
                "UREA (mmol/L) Normal range: 2.6-8.8",
            )
        ),
        gr.Dropdown(choices=["Positive", "Negative"], label="Pleural indentation"),
    ],
    [
        gr.Textbox(label="Results of prediction"),
        gr.Image(type="filepath", label="SHAP Waterfall Plot"),
    ],
    title="Prediction of Lymph Node Metastasis",
    description="Variables were entered to obtain the predicted positive probability and SHAP interpretation map",
)
demo.launch(share=True)