Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pickle | |
import numpy as np | |
import pandas as pd | |
import shap | |
import matplotlib.pyplot as plt | |
# Load the fitted classifier that predict_mortality queries.
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only deploy a best_model.pkl that comes from a trusted source.
with open(r'best_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)
# Minimum and maximum of every input feature, consumed by min_max_scale.
# Presumably these are the bounds used when the model was trained — confirm
# before editing, since changing them changes every scaled input.
# NOTE(review): SpO2 max (116.58 > 100) and Lactate min (-0.20 < 0) look like
# artifacts of the source data; they are kept verbatim so scaling stays
# consistent with whatever the model was fitted on.
_FEATURE_BOUNDS = (
    # (feature name, minimum, maximum)
    ('Age', 21.00, 97.00),
    ('Weight', 20.50, 345.00),
    ('WBC', 0.41, 196.57),
    ('PlateletCount', 17.82, 1071.61),
    ('Albumin', 1.10, 4.90),
    ('Potassium', 2.87, 6.35),
    ('Glucose', 70.38, 327.00),
    ('AnionGap', 7.43, 40.71),
    ('pH', 7.13, 7.68),
    ('pO2', 27.00, 358.00),
    ('Lactate', -0.20, 18.93),
    ('PT', 10.00, 71.80),
    ('SpO2', 87.60, 116.58),
    ('Temperature', 94.46, 110.08),
    ('APSIII', 9.00, 148.00),
)
var_ranges = {name: {'min': low, 'max': high} for name, low, high in _FEATURE_BOUNDS}
def min_max_scale(value, min_val, max_val):
    """Map *value* linearly so that ``min_val`` -> 0 and ``max_val`` -> 1.

    Values outside ``[min_val, max_val]`` are not clipped, so they land
    outside ``[0, 1]``. With plain Python numbers, ``min_val == max_val``
    raises ZeroDivisionError.
    """
    width = max_val - min_val
    shifted = value - min_val
    return shifted / width
# Feature order expected by the model. It matches both the parameter order
# of predict_mortality and the order of the original scaled-input list.
_FEATURE_ORDER = (
    'Age', 'Weight', 'WBC', 'PlateletCount', 'Albumin', 'Potassium',
    'Glucose', 'AnionGap', 'pH', 'pO2', 'Lactate', 'PT',
    'SpO2', 'Temperature', 'APSIII',
)

# shap.TreeExplainer instances keyed by id(model). The model is stored with
# its explainer so its id cannot be recycled while the entry exists.
_TREE_EXPLAINER_CACHE = {}


def _get_tree_explainer(fitted_model):
    """Return a shap.TreeExplainer for *fitted_model*, building it only once."""
    key = id(fitted_model)
    entry = _TREE_EXPLAINER_CACHE.get(key)
    if entry is None:
        entry = (fitted_model, shap.TreeExplainer(fitted_model))
        _TREE_EXPLAINER_CACHE[key] = entry
    return entry[1]


def _positive_class_shap(explainer, scaled_row):
    """Return ``(base_value, shap_values)`` for class index 1.

    SHAP's output layout depends on its version and the model type. It may
    be a list with one array per class, a 3-D array shaped
    (samples, features, classes), or a single 2-D array. Likewise,
    ``expected_value`` may be a scalar, a list, or an array. All of these
    are reduced to class index 1 (or to the only output when there is one).
    """
    contributions = explainer.shap_values(scaled_row)
    if isinstance(contributions, list):
        contributions = contributions[1]
    contributions = np.asarray(contributions)
    if contributions.ndim == 3:
        contributions = contributions[..., 1]
    base_values = np.atleast_1d(np.asarray(explainer.expected_value, dtype=float)).ravel()
    base_value = float(base_values[1] if base_values.size > 1 else base_values[0])
    return base_value, contributions


def _fallback_figure():
    """Return a placeholder figure shown when the SHAP plot cannot be built."""
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.text(0.5, 0.5, '无法生成 SHAP 解释图 (Unable to generate SHAP plot)',
            ha='center', va='center')
    ax.axis('off')
    return fig


def _build_explanation_figure(scaled_row, raw_row):
    """Draw a SHAP force plot for one sample; fall back to a placeholder on failure.

    *scaled_row* (shape (1, n_features)) is what the model sees. *raw_row*
    (1-D, unscaled) is passed only as display values, so the plot labels show
    the clinical numbers the user entered rather than min-max-scaled ones.
    """
    figures_before = set(plt.get_fignums())
    try:
        explainer = _get_tree_explainer(model)
        base_value, contributions = _positive_class_shap(explainer, scaled_row)
        # With matplotlib=True, force_plot draws on a figure it creates itself.
        fig = shap.force_plot(
            base_value,
            contributions,
            raw_row,
            feature_names=list(_FEATURE_ORDER),
            matplotlib=True,
            show=False,
        )
        if fig is None:
            fig = plt.gcf()
    except Exception as e:
        # Best effort: the prediction is still returned when SHAP fails.
        print(f"SHAP visualization error: {str(e)}")
        # Drop any half-drawn figures this attempt left open.
        for number in set(plt.get_fignums()) - figures_before:
            plt.close(number)
        fig = _fallback_figure()
    # Unregister the figure from pyplot so figures don't pile up across
    # requests (pyplot keeps every open figure alive). The Figure object
    # itself remains usable for savefig-style rendering.
    plt.close(fig)
    return fig


def predict_mortality(Age, Weight, WBC, PlateletCount, Albumin, Potassium,
                      Glucose, AnionGap, pH, pO2, Lactate, PT,
                      SpO2, Temperature, APSIII):
    """Predict 28-day mortality risk and explain it with a SHAP force plot.

    Every argument is the raw value entered in the form. Each value is
    min-max scaled with ``var_ranges`` before being passed to the model.

    Returns:
        ``[summary_text, figure]`` for the Gradio Textbox and Plot outputs.

    Raises:
        gr.Error: if any input is empty (None), so Gradio shows which field
        is missing instead of a generic TypeError.
    """
    raw_values = [Age, Weight, WBC, PlateletCount, Albumin, Potassium,
                  Glucose, AnionGap, pH, pO2, Lactate, PT,
                  SpO2, Temperature, APSIII]
    missing = [name for name, value in zip(_FEATURE_ORDER, raw_values) if value is None]
    if missing:
        raise gr.Error(f"请填写所有字段 (Please fill in all fields): {', '.join(missing)}")

    scaled_values = [
        min_max_scale(value, var_ranges[name]['min'], var_ranges[name]['max'])
        for name, value in zip(_FEATURE_ORDER, raw_values)
    ]
    input_data = np.array([scaled_values])

    # Probability of class index 1, which this app reports as mortality risk.
    prediction = float(model.predict_proba(input_data)[0][1])
    # Distance from the 0.5 decision threshold, rescaled to [0, 1].
    confidence_score = abs(0.5 - prediction) * 2
    prediction_text = "高风险 (High Risk)" if prediction >= 0.5 else "低风险 (Low Risk)"
    # Always report P(death). Reporting 1 - p for low-risk cases would put
    # the survival probability under the "mortality risk" label.
    probability = prediction

    # Scaled values outside [0, 1] mean the input lies outside var_ranges,
    # so the model is extrapolating.
    out_of_range = [
        name for name, scaled in zip(_FEATURE_ORDER, scaled_values)
        if not 0.0 <= scaled <= 1.0
    ]

    fig = _build_explanation_figure(input_data, np.asarray(raw_values, dtype=float))

    prediction_result = f"""
预测结果 (Prediction Result): {prediction_text}
死亡风险概率 (Mortality Risk Probability): {probability:.2%}
预测置信度 (Prediction Confidence): {confidence_score:.2%}
"""
    if out_of_range:
        prediction_result += (
            "注意 (Note): 以下输入超出模型参考范围，结果可能不可靠 "
            f"(inputs outside the model's reference range; result may be unreliable): "
            f"{', '.join(out_of_range)}\n"
        )
    return [prediction_result, fig]
# Build the Gradio page: a three-column form of numeric inputs, a submit
# button, a text result, and a feature-contribution plot, then start the app.
_FORM_COLUMNS = (
    ("年龄 (Age)",
     "体重 (Weight, kg)",
     "白细胞计数 (White Blood Cell Count)",
     "血小板计数 (Platelet Count)",
     "白蛋白 (Albumin)"),
    ("钾离子 (Potassium)",
     "血糖 (Glucose)",
     "阴离子间隙 (Anion Gap)",
     "pH值 (pH)",
     "氧分压 (pO2)"),
    ("乳酸 (Lactate)",
     "凝血酶原时间 (Prothrombin Time)",
     "血氧饱和度 (SpO2)",
     "体温 (Temperature)",
     "APSIII评分 (APSIII Score)"),
)

with gr.Blocks() as iface:
    gr.Markdown("# 冠心病合并肺炎患者静脉置管术后28天死亡风险管理\n# 28-Day Mortality Risk Management Following Intravenous Catheterization in Patients with Coronary Heart Disease and Pneumonia")

    # Inputs are created column by column; their order is exactly the
    # positional-argument order of predict_mortality.
    number_inputs = []
    with gr.Row():
        for column_labels in _FORM_COLUMNS:
            with gr.Column():
                for field_label in column_labels:
                    number_inputs.append(gr.Number(label=field_label, value=0))
    (age, weight, wbc, platelet, albumin,
     potassium, glucose, anion_gap, ph, po2,
     lactate, pt, spo2, temperature, apsiii) = number_inputs

    with gr.Row():
        predict_btn = gr.Button("提交 (Submit)")
    with gr.Row():
        prediction_text = gr.Textbox(label="预测结果 (Prediction Result)")
    with gr.Row():
        plot_output = gr.Plot(label="特征贡献度图 (Feature Contribution Plot)")

    predict_btn.click(
        fn=predict_mortality,
        inputs=number_inputs,
        outputs=[prediction_text, plot_output],
    )

iface.launch()