Spaces:
Sleeping
Sleeping
File size: 6,457 Bytes
434fa88 bbdcf26 70b5416 28b88f4 434fa88 c762552 434fa88 8861adb 28b88f4 58dd3c9 19e9235 58dd3c9 28b88f4 58dd3c9 28b88f4 58dd3c9 28b88f4 58dd3c9 28b88f4 434fa88 28b88f4 434fa88 5314289 58dd3c9 5314289 28b88f4 58dd3c9 28b88f4 58dd3c9 28b88f4 5314289 d5022a9 5314289 da05610 434fa88 5314289 d5022a9 434fa88 da05610 28b88f4 da05610 28b88f4 da05610 28b88f4 da05610 58dd3c9 da05610 5314289 f73dc70 5314289 f73dc70 5314289 5e384e2 19e9235 5d8e98a bc48994 5d8e98a 02b7ee6 5d8e98a 02b7ee6 bc48994 02b7ee6 5314289 bcc34fa 5314289 5e384e2 5314289 58dd3c9 5e384e2 ee77edf 5e384e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import gradio as gr
import pickle
import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt
# Deserialize the trained classifier once, at import time, from a file
# expected in the current working directory.
# NOTE: pickle executes arbitrary code while loading; only ship a
# best_model.pkl that comes from a trusted source.
with open(r'best_model.pkl', mode='rb') as model_file:
    model = pickle.load(model_file)
# Minimum and maximum of every input variable, consumed by the min-max
# scaling step. Rows are (name, min, max); the resulting dict keeps this order.
# NOTE(review): presumably these bounds come from the training data and must
# stay in sync with the pickled model -- confirm before editing any value.
_RANGE_TABLE = (
    ('Age', 21.00, 97.00),
    ('Weight', 20.50, 345.00),
    ('WBC', 0.41, 196.57),
    ('PlateletCount', 17.82, 1071.61),
    ('Albumin', 1.10, 4.90),
    ('Potassium', 2.87, 6.35),
    ('Glucose', 70.38, 327.00),
    ('AnionGap', 7.43, 40.71),
    ('pH', 7.13, 7.68),
    ('pO2', 27.00, 358.00),
    ('Lactate', -0.20, 18.93),
    ('PT', 10.00, 71.80),
    ('SpO2', 87.60, 116.58),
    ('Temperature', 94.46, 110.08),
    ('APSIII', 9.00, 148.00),
)
var_ranges = {
    name: {'min': lower, 'max': upper}
    for name, lower, upper in _RANGE_TABLE
}
def min_max_scale(value, min_val, max_val):
    """Rescale *value* linearly so that min_val maps to 0 and max_val to 1.

    Values outside ``[min_val, max_val]`` are not clipped; they simply map
    below 0 or above 1.

    Args:
        value: The raw number to rescale.
        min_val: Lower bound of the reference range.
        max_val: Upper bound of the reference range.

    Returns:
        ``(value - min_val) / (max_val - min_val)``, or ``value - min_val``
        when the range is degenerate (``max_val == min_val``).
    """
    span = max_val - min_val
    if span == 0:
        # A zero-width range would raise ZeroDivisionError (or yield inf/nan
        # for NumPy scalars). Fall back to a unit span, the same convention
        # scikit-learn's MinMaxScaler applies to constant features.
        span = 1
    return (value - min_val) / span
# Model inputs, in the order the scaled feature vector is assembled and
# passed to the model. Also used as the SHAP plot's feature labels.
_FEATURE_NAMES = (
    'Age', 'Weight', 'WBC', 'PlateletCount', 'Albumin', 'Potassium',
    'Glucose', 'AnionGap', 'pH', 'pO2', 'Lactate', 'PT',
    'SpO2', 'Temperature', 'APSIII',
)


def _message_figure(message):
    """Return a new 10x3 matplotlib figure showing *message* centred."""
    figure = plt.figure(figsize=(10, 3))
    plt.text(0.5, 0.5, message, ha='center', va='center')
    return figure


def _positive_class_shap(explainer, input_data):
    """Return ``(base_value, shap_values)`` for the positive class (index 1).

    Normalises the output layouts this code has to cope with: a per-class
    list, a 3-D ``(samples, features, classes)`` array, or an already
    single-output 2-D array; and a scalar, list or array ``expected_value``.
    """
    shap_values = explainer.shap_values(input_data)
    if isinstance(shap_values, list):
        # One (samples, features) array per class.
        shap_values = shap_values[1]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        # Classes on the last axis; take class 1 when it exists.
        class_index = 1 if shap_values.shape[2] > 1 else 0
        shap_values = shap_values[:, :, class_index]

    base_value = explainer.expected_value
    if np.ndim(base_value) > 0:
        # A plain isinstance(..., list) check misses ndarray base values.
        per_class = np.ravel(base_value)
        base_value = per_class[1] if per_class.size > 1 else per_class[0]
    return float(base_value), shap_values


def predict_mortality(Age, Weight, WBC, PlateletCount, Albumin, Potassium,
                      Glucose, AnionGap, pH, pO2, Lactate, PT,
                      SpO2, Temperature, APSIII):
    """Predict the positive-class (mortality) probability and explain it.

    Each input is min-max scaled with the bounds in ``var_ranges``, fed to
    ``model.predict_proba``, and explained with a SHAP force plot. If the
    SHAP step fails for any reason, a placeholder figure is returned instead
    so the prediction text is still shown.

    Args:
        Age, Weight, WBC, PlateletCount, Albumin, Potassium, Glucose,
        AnionGap, pH, pO2, Lactate, PT, SpO2, Temperature, APSIII:
            Raw (unscaled) numeric inputs. ``None`` is treated as "not
            provided" and reported back to the user.

    Returns:
        A two-element list ``[text, figure]``: the bilingual result text and
        a matplotlib figure (force plot or placeholder message).
    """
    raw_values = (Age, Weight, WBC, PlateletCount, Albumin, Potassium,
                  Glucose, AnionGap, pH, pO2, Lactate, PT,
                  SpO2, Temperature, APSIII)

    # Deregister figures left over from earlier requests; otherwise every
    # call would leave its figures open in pyplot for the process lifetime.
    plt.close('all')

    # Blank fields would otherwise crash the scaling arithmetic with a
    # TypeError. NOTE(review): assumes an emptied Gradio Number field is
    # delivered as None -- confirm against the installed Gradio version.
    missing = [name for name, value in zip(_FEATURE_NAMES, raw_values)
               if value is None]
    if missing:
        message = ("请填写所有输入项 (Please fill in every input): "
                   + ", ".join(missing))
        return [message, _message_figure(message)]

    # Min-max scaling
    scaled_input = [
        min_max_scale(value, var_ranges[name]['min'], var_ranges[name]['max'])
        for name, value in zip(_FEATURE_NAMES, raw_values)
    ]
    input_data = np.array([scaled_input])

    # Probability of the positive class (column 1 of predict_proba).
    prediction = model.predict_proba(input_data)[0][1]

    # Distance from the 0.5 decision threshold, rescaled to [0, 1].
    confidence_score = abs(0.5 - prediction) * 2
    prediction_text = "高风险 (High Risk)" if prediction >= 0.5 else "低风险 (Low Risk)"

    try:
        explainer = shap.TreeExplainer(model)
        base_value, shap_values = _positive_class_shap(explainer, input_data)
        # force_plot draws on its own figure, so the size has to be passed
        # here rather than set on a separately created plt.figure().
        shap.force_plot(
            base_value,
            shap_values,
            input_data[0],
            feature_names=list(_FEATURE_NAMES),
            matplotlib=True,
            show=False,
            figsize=(10, 3),
        )
        figure = plt.gcf()
    except Exception as e:
        # Best effort: the explanation is optional, the prediction is not.
        print(f"SHAP visualization error: {str(e)}")
        plt.close('all')  # discard any half-drawn figure
        figure = _message_figure('无法生成 SHAP 解释图 (Unable to generate SHAP plot)')

    # Report the mortality probability itself. Showing 1 - p for low-risk
    # patients would put the survival probability under a mortality label.
    prediction_result = (
        "\n"
        f"预测结果 (Prediction Result): {prediction_text}\n"
        f"死亡风险概率 (Mortality Risk Probability): {prediction:.2%}\n"
        f"预测置信度 (Prediction Confidence): {confidence_score:.2%}\n"
    )
    return [prediction_result, figure]
# Gradio UI: fifteen numeric inputs laid out in three columns, a submit
# button, and a text box plus plot that receive predict_mortality's outputs.
def _number_field(label):
    """Create a numeric input with the given label and an initial value of 0."""
    return gr.Number(label=label, value=0)


with gr.Blocks() as iface:
    gr.Markdown(
        "# 冠心病合并肺炎患者静脉置管术后28天死亡风险管理\n"
        "# 28-Day Mortality Risk Management Following Intravenous Catheterization in Patients with Coronary Heart Disease and Pneumonia"
    )

    # Input fields, five per column.
    with gr.Row():
        with gr.Column():
            age = _number_field("年龄 (Age)")
            weight = _number_field("体重 (Weight, kg)")
            wbc = _number_field("白细胞计数 (White Blood Cell Count)")
            platelet = _number_field("血小板计数 (Platelet Count)")
            albumin = _number_field("白蛋白 (Albumin)")
        with gr.Column():
            potassium = _number_field("钾离子 (Potassium)")
            glucose = _number_field("血糖 (Glucose)")
            anion_gap = _number_field("阴离子间隙 (Anion Gap)")
            ph = _number_field("pH值 (pH)")
            po2 = _number_field("氧分压 (pO2)")
        with gr.Column():
            lactate = _number_field("乳酸 (Lactate)")
            pt = _number_field("凝血酶原时间 (Prothrombin Time)")
            spo2 = _number_field("血氧饱和度 (SpO2)")
            temperature = _number_field("体温 (Temperature)")
            apsiii = _number_field("APSIII评分 (APSIII Score)")

    with gr.Row():
        predict_btn = gr.Button("提交 (Submit)")

    # Output sections
    with gr.Row():
        prediction_text = gr.Textbox(label="预测结果 (Prediction Result)")
    with gr.Row():
        plot_output = gr.Plot(label="特征贡献度图 (Feature Contribution Plot)")

    # Component order must match predict_mortality's positional parameters.
    model_inputs = [
        age, weight, wbc, platelet, albumin, potassium,
        glucose, anion_gap, ph, po2, lactate, pt,
        spo2, temperature, apsiii,
    ]
    predict_btn.click(
        fn=predict_mortality,
        inputs=model_inputs,
        outputs=[prediction_text, plot_output],
    )

iface.launch()
|