File size: 6,457 Bytes
434fa88
 
 
 
bbdcf26
70b5416
28b88f4
434fa88
c762552
434fa88
8861adb
28b88f4
 
58dd3c9
19e9235
 
58dd3c9
 
 
28b88f4
58dd3c9
28b88f4
58dd3c9
 
28b88f4
58dd3c9
 
 
 
28b88f4
434fa88
28b88f4
 
434fa88
5314289
58dd3c9
 
5314289
28b88f4
 
 
 
 
 
 
 
 
58dd3c9
28b88f4
 
58dd3c9
 
 
 
28b88f4
 
 
5314289
 
d5022a9
5314289
 
da05610
434fa88
5314289
d5022a9
434fa88
da05610
28b88f4
da05610
 
28b88f4
da05610
 
28b88f4
da05610
 
 
 
 
 
 
58dd3c9
 
da05610
 
 
 
 
 
 
 
5314289
f73dc70
 
5314289
 
 
f73dc70
 
 
 
5314289
5e384e2
19e9235
5d8e98a
 
 
 
 
 
 
 
bc48994
5d8e98a
 
 
 
 
 
02b7ee6
5d8e98a
 
 
 
 
 
02b7ee6
bc48994
02b7ee6
 
5314289
 
bcc34fa
5314289
5e384e2
 
 
5314289
58dd3c9
 
5e384e2
 
ee77edf
5e384e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import gradio as gr
import pickle
import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt

# Load the pre-trained classifier used by predict_mortality. The path is
# relative, so it is resolved against the current working directory.
# NOTE(review): pickle.load can execute arbitrary code stored in the file —
# only deploy a best_model.pkl you produced or otherwise trust.
MODEL_PATH = r'best_model.pkl'
with open(MODEL_PATH, mode='rb') as model_file:
    model = pickle.load(model_file)

# Min/max bounds for every model feature, used for min-max scaling of the raw
# inputs. Insertion order of the keys is preserved from the table below.
# NOTE(review): these presumably reproduce the scaler fitted on the training
# data and must not be changed independently of best_model.pkl. Some bounds look
# non-physiological (Lactate min -0.20, SpO2 max 116.58), and the Temperature
# span 94.46-110.08 suggests degrees Fahrenheit — confirm units against the
# training pipeline.
_FEATURE_BOUNDS = (
    # (feature,       min,     max)
    ('Age',           21.00,   97.00),
    ('Weight',        20.50,  345.00),
    ('WBC',            0.41,  196.57),
    ('PlateletCount', 17.82, 1071.61),
    ('Albumin',        1.10,    4.90),
    ('Potassium',      2.87,    6.35),
    ('Glucose',       70.38,  327.00),
    ('AnionGap',       7.43,   40.71),
    ('pH',             7.13,    7.68),
    ('pO2',           27.00,  358.00),
    ('Lactate',       -0.20,   18.93),
    ('PT',            10.00,   71.80),
    ('SpO2',          87.60,  116.58),
    ('Temperature',   94.46,  110.08),
    ('APSIII',         9.00,  148.00),
)
var_ranges = {
    feature: {'min': low, 'max': high}
    for feature, low, high in _FEATURE_BOUNDS
}

def min_max_scale(value, min_val, max_val):
    """Linearly map *value* so that ``min_val`` -> 0.0 and ``max_val`` -> 1.0.

    Values outside ``[min_val, max_val]`` are not clipped and land outside
    ``[0, 1]``. Raises ZeroDivisionError when ``min_val == max_val``.
    """
    span = max_val - min_val
    offset = value - min_val
    return offset / span

def _shap_force_figure(input_data, raw_values, feature_names):
    """Return a matplotlib figure holding the SHAP force plot for one sample.

    SHAP values are computed on the scaled ``input_data`` (what the model sees);
    ``raw_values`` are used only as display labels so the plot shows the numbers
    the user typed. On any error a placeholder figure is returned instead.
    """
    # pyplot retains every figure until it is explicitly closed; without this,
    # each request leaks the figures created by the previous one.
    plt.close('all')
    try:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(input_data)
        if isinstance(shap_values, list):
            # One array per class: take the positive class (index 1).
            shap_values = shap_values[1]
        shap_values = np.asarray(shap_values)
        if shap_values.ndim == 3:
            # (n_samples, n_features, n_outputs) layout: take class 1.
            shap_values = shap_values[..., 1]

        # expected_value may be a scalar, a list or a numpy array with one entry
        # per output; force_plot needs a single base value.
        expected = np.ravel(np.asarray(explainer.expected_value, dtype=float))
        base_value = expected[1] if expected.size > 1 else expected[0]

        # force_plot(matplotlib=True) draws into a new figure that becomes
        # the current figure (returned via plt.gcf() below).
        shap.force_plot(
            base_value,
            shap_values,
            np.asarray(raw_values, dtype=float),
            feature_names=feature_names,
            matplotlib=True,
            show=False,
        )
    except Exception as e:
        # Best effort: the text prediction is still returned if SHAP fails.
        print(f"SHAP visualization error: {str(e)}")
        plt.close('all')
        _, ax = plt.subplots(figsize=(10, 3))
        ax.text(0.5, 0.5, '无法生成 SHAP 解释图 (Unable to generate SHAP plot)',
                ha='center', va='center')
        ax.axis('off')
    return plt.gcf()


def predict_mortality(Age, Weight, WBC, PlateletCount, Albumin, Potassium,
                     Glucose, AnionGap, pH, pO2, Lactate, PT,
                     SpO2, Temperature, APSIII):
    """Predict 28-day mortality risk for one patient and explain it with SHAP.

    Each argument is the raw (unscaled) value of the feature of the same name in
    ``var_ranges``. Values are min-max scaled with those bounds and fed to
    ``model.predict_proba``; column 1 of the result is treated as the
    probability of death.

    Returns:
        ``[prediction_result, figure]``: a bilingual text summary and a
        matplotlib figure (SHAP force plot or placeholder). If any input is
        empty (None), returns a message and ``None`` for the figure.
    """
    # Order in which features are passed to the model and shown in the plot.
    # NOTE(review): presumably the training column order — do not reorder.
    feature_names = ['Age', 'Weight', 'WBC', 'PlateletCount', 'Albumin', 'Potassium',
                     'Glucose', 'AnionGap', 'pH', 'pO2', 'Lactate', 'PT',
                     'SpO2', 'Temperature', 'APSIII']
    raw_values = [Age, Weight, WBC, PlateletCount, Albumin, Potassium,
                  Glucose, AnionGap, pH, pO2, Lactate, PT,
                  SpO2, Temperature, APSIII]

    # A cleared gr.Number field arrives as None, which scaling cannot handle.
    missing = [name for name, value in zip(feature_names, raw_values) if value is None]
    if missing:
        return [f"请填写所有字段 (Please fill in all fields): {', '.join(missing)}", None]

    scaled_input = [
        min_max_scale(value, var_ranges[name]['min'], var_ranges[name]['max'])
        for name, value in zip(feature_names, raw_values)
    ]
    input_data = np.array([scaled_input])

    # Probability of the positive class (death).
    prediction = model.predict_proba(input_data)[0][1]

    # Distance from the 0.5 decision threshold, rescaled to [0, 1].
    confidence_score = abs(0.5 - prediction) * 2

    prediction_text = "高风险 (High Risk)" if prediction >= 0.5 else "低风险 (Low Risk)"

    figure = _shap_force_figure(input_data, raw_values, feature_names)

    # Always report P(death) under the mortality label; the previous code showed
    # 1 - P(death) (the survival probability) for low-risk patients.
    prediction_result = f"""
    预测结果 (Prediction Result): {prediction_text}
    死亡风险概率 (Mortality Risk Probability): {prediction:.2%}
    预测置信度 (Prediction Confidence): {confidence_score:.2%}
    """

    # Inputs outside the scaling bounds scale outside [0, 1], so the model
    # has to extrapolate; flag them instead of silently predicting.
    out_of_range = [
        name for name, value in zip(feature_names, raw_values)
        if not var_ranges[name]['min'] <= value <= var_ranges[name]['max']
    ]
    if out_of_range:
        prediction_result += (
            "⚠ 以下输入超出模型参考范围，结果可能不可靠 "
            "(Inputs outside the model's reference range; result may be unreliable): "
            f"{', '.join(out_of_range)}\n"
        )

    return [prediction_result, figure]

def _number_input(label, feature):
    """Create a gr.Number whose label shows the feature's scaling range.

    The range is read from ``var_ranges`` so users can see what magnitude (and
    hence which units) the model expects. The field starts empty so that a
    prediction is never made on placeholder zeros, which lie outside every
    feature's range.
    """
    low = var_ranges[feature]['min']
    high = var_ranges[feature]['max']
    return gr.Number(label=f"{label} [{low:g} – {high:g}]", value=None)


# Gradio UI: 15 numeric inputs in three columns, a submit button, and two
# outputs (text summary and SHAP force plot).
with gr.Blocks() as iface:
    gr.Markdown("# 冠心病合并肺炎患者静脉置管术后28天死亡风险管理\n# 28-Day Mortality Risk Management Following Intravenous Catheterization in Patients with Coronary Heart Disease and Pneumonia")
    with gr.Row():
        with gr.Column():
            age = _number_input("年龄 (Age)", 'Age')
            weight = _number_input("体重 (Weight, kg)", 'Weight')
            wbc = _number_input("白细胞计数 (White Blood Cell Count)", 'WBC')
            platelet = _number_input("血小板计数 (Platelet Count)", 'PlateletCount')
            albumin = _number_input("白蛋白 (Albumin)", 'Albumin')

        with gr.Column():
            potassium = _number_input("钾离子 (Potassium)", 'Potassium')
            glucose = _number_input("血糖 (Glucose)", 'Glucose')
            anion_gap = _number_input("阴离子间隙 (Anion Gap)", 'AnionGap')
            ph = _number_input("pH值 (pH)", 'pH')
            po2 = _number_input("氧分压 (pO2)", 'pO2')

        with gr.Column():
            lactate = _number_input("乳酸 (Lactate)", 'Lactate')
            pt = _number_input("凝血酶原时间 (Prothrombin Time)", 'PT')
            spo2 = _number_input("血氧饱和度 (SpO2)", 'SpO2')
            temperature = _number_input("体温 (Temperature)", 'Temperature')
            apsiii = _number_input("APSIII评分 (APSIII Score)", 'APSIII')

    with gr.Row():
        predict_btn = gr.Button("提交 (Submit)")

    # Output sections
    with gr.Row():
        prediction_text = gr.Textbox(label="预测结果 (Prediction Result)")
    with gr.Row():
        plot_output = gr.Plot(label="特征贡献度图 (Feature Contribution Plot)")

    # Input order must match predict_mortality's parameter order.
    predict_btn.click(
        fn=predict_mortality,
        inputs=[age, weight, wbc, platelet, albumin, potassium,
                glucose, anion_gap, ph, po2, lactate, pt,
                spo2, temperature, apsiii],
        outputs=[prediction_text, plot_output]
    )

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    iface.launch()