# app.py -- last commit "Update app.py" (70b5416, verified), author: Curvature
import gradio as gr
import pickle
import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt
# Load the pickled classifier once at import time.
# NOTE: pickle executes arbitrary code while loading, so 'best_model.pkl'
# must only ever be a trusted file shipped with the app.
with open(r'best_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)
# Minimum and maximum of every input variable, used for min-max scaling.
# Each entry maps a feature name to {'min': lower bound, 'max': upper bound}.
var_ranges = {
    feature: {'min': lower, 'max': upper}
    for feature, lower, upper in (
        ('Age', 21.00, 97.00),
        ('Weight', 20.50, 345.00),
        ('WBC', 0.41, 196.57),
        ('PlateletCount', 17.82, 1071.61),
        ('Albumin', 1.10, 4.90),
        ('Potassium', 2.87, 6.35),
        ('Glucose', 70.38, 327.00),
        ('AnionGap', 7.43, 40.71),
        ('pH', 7.13, 7.68),
        ('pO2', 27.00, 358.00),
        ('Lactate', -0.20, 18.93),
        ('PT', 10.00, 71.80),
        ('SpO2', 87.60, 116.58),
        ('Temperature', 94.46, 110.08),
        ('APSIII', 9.00, 148.00),
    )
}
def min_max_scale(value, min_val, max_val):
    """Linearly rescale *value* so that min_val maps to 0 and max_val to 1.

    Values outside [min_val, max_val] are not clamped; they map outside
    [0, 1].

    Args:
        value: The raw measurement to scale.
        min_val: Lower bound of the reference range.
        max_val: Upper bound of the reference range.

    Returns:
        ``(value - min_val) / (max_val - min_val)``, or ``0.0`` when the
        range is degenerate (``max_val == min_val``) and the division
        would raise ``ZeroDivisionError``.
    """
    try:
        return (value - min_val) / (max_val - min_val)
    except ZeroDivisionError:
        # A zero-width range carries no information; a constant feature
        # scales to 0 under min-max normalisation.
        return 0.0
# Order in which the features are fed to the model and labelled in the
# SHAP plot; it mirrors the parameter order of predict_mortality().
_FEATURE_NAMES = ['Age', 'Weight', 'WBC', 'PlateletCount', 'Albumin', 'Potassium',
                  'Glucose', 'AnionGap', 'pH', 'pO2', 'Lactate', 'PT',
                  'SpO2', 'Temperature', 'APSIII']


def _select_positive_class(shap_values, expected_value):
    """Reduce SHAP output to the class-1 values and a scalar base value.

    Handles the shapes this function may receive: ``shap_values`` as a list
    (one array per class), a 3-D array whose last axis is the class, or a
    plain 2-D array; ``expected_value`` as a scalar, list or ndarray.
    """
    if isinstance(shap_values, list):
        shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]
    else:
        shap_values = np.asarray(shap_values)
        if shap_values.ndim == 3:
            class_index = 1 if shap_values.shape[-1] > 1 else 0
            shap_values = shap_values[..., class_index]

    # np.ravel turns a scalar into a length-1 array, so one code path
    # covers scalar, list and ndarray base values.
    base = np.ravel(expected_value)
    base_value = float(base[1] if base.size > 1 else base[0])
    return shap_values, base_value


def _render_shap_figure(input_data):
    """Return a matplotlib figure explaining the prediction for *input_data*.

    Falls back to a placeholder figure with an explanatory message when the
    SHAP force plot cannot be produced (the error is printed).
    """
    try:
        explainer = shap.TreeExplainer(model)
        shap_values, base_value = _select_positive_class(
            explainer.shap_values(input_data), explainer.expected_value
        )
        shap.force_plot(
            base_value,
            shap_values,
            input_data[0],
            feature_names=_FEATURE_NAMES,
            matplotlib=True,
            show=False
        )
        fig = plt.gcf()
    except Exception as e:
        # Best effort: the prediction is still useful without the plot.
        print(f"SHAP visualization error: {str(e)}")
        plt.close('all')  # drop any half-drawn figure left by SHAP
        fig = plt.figure(figsize=(10, 3))
        plt.text(0.5, 0.5, '无法生成 SHAP 解释图 (Unable to generate SHAP plot)', ha='center', va='center')

    # Detach the figure from pyplot's global registry so figures do not
    # accumulate across requests; the Figure object itself stays usable.
    plt.close(fig)
    return fig


def predict_mortality(Age, Weight, WBC, PlateletCount, Albumin, Potassium,
                      Glucose, AnionGap, pH, pO2, Lactate, PT,
                      SpO2, Temperature, APSIII):
    """Predict mortality risk from 15 clinical inputs and explain it with SHAP.

    Each input is min-max scaled with the bounds in ``var_ranges`` and the
    scaled row is passed to ``model.predict_proba``; column 1 of the result
    is taken as the mortality probability. Inputs outside their range are
    not clamped.

    Args:
        Age, Weight, WBC, PlateletCount, Albumin, Potassium, Glucose,
        AnionGap, pH, pO2, Lactate, PT, SpO2, Temperature, APSIII:
            Raw numeric values, one per model feature.

    Returns:
        A two-element list ``[summary_text, figure]``: the formatted
        prediction summary and a matplotlib figure (SHAP force plot, or a
        placeholder when it cannot be generated).

    Raises:
        gr.Error: If any input is ``None`` (e.g. a cleared input box).
    """
    raw_inputs = dict(zip(_FEATURE_NAMES, (
        Age, Weight, WBC, PlateletCount, Albumin, Potassium,
        Glucose, AnionGap, pH, pO2, Lactate, PT,
        SpO2, Temperature, APSIII,
    )))

    # An emptied number box arrives as None; report it clearly instead of
    # failing with a TypeError inside the scaling arithmetic.
    missing = [name for name, value in raw_inputs.items() if value is None]
    if missing:
        raise gr.Error(
            f"请填写所有输入项 (Please fill in all inputs): {', '.join(missing)}"
        )

    # Min-max scaling
    scaled_input = [
        min_max_scale(raw_inputs[name], var_ranges[name]['min'], var_ranges[name]['max'])
        for name in _FEATURE_NAMES
    ]
    input_data = np.array([scaled_input])

    # Probability of the positive class (column 1 of predict_proba)
    prediction = model.predict_proba(input_data)[0][1]

    # Confidence: distance from the 0.5 decision boundary, rescaled to [0, 1]
    confidence_score = abs(0.5 - prediction) * 2
    prediction_text = "高风险 (High Risk)" if prediction >= 0.5 else "低风险 (Low Risk)"

    figure = _render_shap_figure(input_data)

    # Report the mortality probability itself. Showing the probability of
    # the predicted class would display e.g. 90% "mortality risk" for a
    # low-risk patient whose actual risk is 10%.
    prediction_result = (
        "\n"
        f"预测结果 (Prediction Result): {prediction_text}\n"
        f"死亡风险概率 (Mortality Risk Probability): {prediction:.2%}\n"
        f"预测置信度 (Prediction Confidence): {confidence_score:.2%}\n"
    )
    return [prediction_result, figure]
# Build the Gradio UI: a title, three columns of numeric inputs, a submit
# button, and two outputs (summary text and feature-contribution plot).
def _number_input(label):
    """Create a numeric input box with the given label, pre-filled with 0."""
    return gr.Number(label=label, value=0)


with gr.Blocks() as iface:
    gr.Markdown(
        "# 冠心病合并肺炎患者静脉置管术后28天死亡风险管理\n"
        "# 28-Day Mortality Risk Management Following Intravenous Catheterization"
        " in Patients with Coronary Heart Disease and Pneumonia"
    )

    # Input fields, five per column
    with gr.Row():
        with gr.Column():
            age = _number_input("年龄 (Age)")
            weight = _number_input("体重 (Weight, kg)")
            wbc = _number_input("白细胞计数 (White Blood Cell Count)")
            platelet = _number_input("血小板计数 (Platelet Count)")
            albumin = _number_input("白蛋白 (Albumin)")

        with gr.Column():
            potassium = _number_input("钾离子 (Potassium)")
            glucose = _number_input("血糖 (Glucose)")
            anion_gap = _number_input("阴离子间隙 (Anion Gap)")
            ph = _number_input("pH值 (pH)")
            po2 = _number_input("氧分压 (pO2)")

        with gr.Column():
            lactate = _number_input("乳酸 (Lactate)")
            pt = _number_input("凝血酶原时间 (Prothrombin Time)")
            spo2 = _number_input("血氧饱和度 (SpO2)")
            temperature = _number_input("体温 (Temperature)")
            apsiii = _number_input("APSIII评分 (APSIII Score)")

    with gr.Row():
        predict_btn = gr.Button("提交 (Submit)")

    # Output sections
    with gr.Row():
        prediction_text = gr.Textbox(label="预测结果 (Prediction Result)")

    with gr.Row():
        plot_output = gr.Plot(label="特征贡献度图 (Feature Contribution Plot)")

    # The input order must match predict_mortality's parameter order.
    model_inputs = [
        age, weight, wbc, platelet, albumin, potassium,
        glucose, anion_gap, ph, po2, lactate, pt,
        spo2, temperature, apsiii,
    ]
    model_outputs = [prediction_text, plot_output]
    predict_btn.click(fn=predict_mortality, inputs=model_inputs, outputs=model_outputs)

iface.launch()