import joblib
import pandas as pd
import streamlit as st
from streamlit_shap import st_shap
import shap
import numpy as np
import math


@st.cache_resource
def _load_model(path='saved_model/model1.pkl'):
    """Load the pickled classifier once and reuse it across reruns.

    Streamlit re-executes this entire script on every widget interaction;
    without caching, the model file would be read and unpickled each time.

    Args:
        path: Location of the joblib/pickle file containing the model.

    Returns:
        The deserialized model object.

    Note:
        ``joblib.load`` unpickles arbitrary objects, so ``path`` must only
        ever point at a trusted, locally shipped file.
    """
    return joblib.load(path)


# Load Model
model1 = _load_model()
predictor = model1

st.title('Prediction of postoperative AKI')
st.write('For geriatric patients (>=65 years) undergoing noncardiac surgery')

# Each *_options dict maps the label shown in a selectbox to the WOE-encoded
# value that is fed to the model for that category.
# NOTE(review): "Ⅳ" and "Ⅴ" share a WOE value, presumably merged bins — confirm.
asa_options = {
    "Ⅰ": -1.38814621350096,
    "Ⅱ": 0.511717257330668,
    "Ⅲ": -0.395963387679586,
    "Ⅳ": -1.4681889211745,
    "Ⅴ": -1.4681889211745,
}
hypertension_options = {
    "No": 3.68035036162672,
    "Ⅰ": -0.544235316554776,
    "Ⅱ": -0.844034612101507,
    "Ⅲ": -0.943664453050348,
}
dm_options = {
    "No": 0.122040527036218,
    "Non-insulin dependent": -0.638909566289975,
    "Insulin dependent": -1.7072632945609,
}
illness_options = {
    "Yes": 1,
    "No": 0,
}
# The "" (blank) entries below presumably encode the "not recorded"
# category — confirm against the WOE mapping file used in training.
os_options = {
    "": 2.70805020110221,
    "Dermatological surgery": 2.22692108269007,
    "Limb surgery": 0.752566153571913,
    "Arthroplasty": 0.532164318353033,
    "Spinal surgery": 1.6127011607198,
    "Head and neck surgery": 2.52541407124607,
    "Upper abdomen surgery": -0.649878597660549,
    "Lower abdomen surgery": -0.464370149938076,
    "Abdomen surgery": -2.20389571615324,
    "Thoracic surgery": 0.416954872664665,
    "cranial surgery": 0.0899556968720493,
    "Thoracoabdominal surgery": 0.0899556968720493,
    "Vascular surgery": 0.281010933634758,
}
es_options = {
    "No": 0.0175056050012399,
    "yes": -1.97901454494049,
}
# NOTE(review): the "1h" bin presumably means "< 1h" — confirm.
eot_options = {
    "": 2.944439,
    "1h": 0.768487407184522,
    "1h-1.5h": 0.523271451886264,
    "1.5-2h": -0.170387892577323,
    "2h-3h": -0.345346138774783,
    ">=3h": -0.749296055496119,
}
up_options = {
    "": 3.59497984361465,
    "No": 0.0374632790768398,
    "+/-": -0.465955759499621,
    "1+": -1.18050684872272,
    "2+": -1.18050684872272,
    "3+": -1.4681889211745,
}

# Feature (column) names passed to the model. These are the original Chinese
# column names and presumably must match the training data exactly.
sccl_zh = '血清胱抑素C测定'  # serum cystatin C measurement
asa_zh = 'ASA分级'  # ASA grade
hypertension_zh = '高血压_分级'  # hypertension grade
dm_zh = '糖尿病_控制方式'  # diabetes control method
os_zh = '手术风险评估_手术部位'  # surgical risk assessment: operation site
es_zh = '急诊'  # emergency
eot_zh = '手术时长_分钟'  # operation duration (minutes)
up_zh = '尿蛋白定性'  # qualitative urine protein
coovorbcdw_zh = 'RBC分布宽度CV'  # RBC distribution width CV
inr_zh = '国际标准化比值'  # international normalized ratio

# English display names used in the UI and in the SHAP plot.
sccl_en = 'Serum cystatin C level (mg/L)'
asa_en = 'ASA classification'
hypertension_en = 'Hypertension'
dm_en = 'Diabetes mellitus'
os_en = 'Operation site'
es_en = 'Emergency surgery'
eot_en = 'Estimated operation time'
up_en = 'Urine protein'
coovorbcdw_en = 'RDW-CV (%)'
inr_en = 'International normalized ratio'


def _woe_for_value(value, edges, bin_woes, missing_woe):
    """Return the WOE of the bin that a continuous lab value falls into.

    Args:
        value: The entered number, or ``None`` when the field was left blank.
        edges: Ascending cut points. Bin ``i`` covers
            ``edges[i-1] <= value < edges[i]``; the final bin covers
            ``value >= edges[-1]``.
        bin_woes: WOE per bin; must have ``len(edges) + 1`` entries.
        missing_woe: WOE used when ``value`` is ``None``.

    Returns:
        The WOE value for ``value``.
    """
    if value is None:
        return missing_woe
    value = float(value)
    for edge, woe in zip(edges, bin_woes):
        if value < edge:
            return woe
    return bin_woes[-1]


# Form
with st.form(key='form_parameters'):
    col1, col2 = st.columns(2)  # Create two columns

    # value=None leaves a numeric field empty (returning None) until the user
    # types a number (Streamlit >= 1.26). With the old default of 0.0, a blank
    # field could not be told apart from a measured value, so the "missing"
    # WOE bins were unreachable.

    # First column
    with col1:
        sccl = st.number_input('Serum cystatin C level (mg/L)', value=None)
        asa = st.selectbox('ASA classification', list(asa_options.keys()))
        hypertension = st.selectbox('Hypertension', list(hypertension_options.keys()))
        dm = st.selectbox('Diabetes mellitus', list(dm_options.keys()))
        # Named op_site (not `os`) to avoid shadowing the stdlib module name.
        op_site = st.selectbox('Operation site', list(os_options.keys()))

    # Second column
    with col2:
        eot = st.selectbox('Estimated operation time', list(eot_options.keys()))
        up = st.selectbox('Urine protein', list(up_options.keys()))
        es = st.selectbox('Emergency surgery', list(es_options.keys()))
        coovorbcdw = st.number_input('RDW-CV (%)', value=None)
        inr = st.number_input('International normalized ratio', value=None)

    st.markdown('---')
    submitted = st.form_submit_button('Predict')

# WOE encoding of the categorical inputs.
asa_woe = asa_options.get(asa, 0)
hypertension_woe = hypertension_options.get(hypertension, 0)
dm_woe = dm_options.get(dm, 0)
os_woe = os_options.get(op_site, 0)
eot_woe = eot_options.get(eot, 0)
up_woe = up_options.get(up, 0)
es_woe = es_options.get(es, 0)

# WOE encoding of the continuous lab values (computed before the raw values
# are converted below, so a blank field maps to its "missing" WOE).
sccl_woe = _woe_for_value(
    sccl,
    (0.87, 0.95, 1.03, 1.13),
    (0.645372775418255, 0.369025188040581, -0.10610555179212,
     0.0898098195273885, -0.925496780781878),
    2.99885240089024,
)
coovorbcdw_woe = _woe_for_value(
    coovorbcdw,
    (12.7, 13.3, 13.7, 14.6),
    (0.398231895412054, 0.048813472291148, -0.104034700025362,
     -0.0861214293356946, -0.646032944021757),
    3.34041249059303,
)
inr_woe = _woe_for_value(
    inr,
    (0.92, 0.95, 0.98, 1.03),
    (-0.000131790754127941, 0.383156538484155, 0.282044225138759,
     -0.104889279197154, -0.769604383885953),
    3.38627173144548,
)

# Raw values kept for display; a blank field becomes NaN.
sccl = np.nan if sccl is None else float(sccl)
coovorbcdw = np.nan if coovorbcdw is None else float(coovorbcdw)
inr = np.nan if inr is None else float(inr)

with st.container():
    st.write('Abbreviation: AKI, acute kidney injury;ASA, American Society of Anesthesiologists Physical Status Classification score;RDW-CV, Coefficient of variation of red blood cell distribution width.')
    st.write(
        'This risk calculator predicts the risk of acute kidney injury following '
        'noncardiac surgery in geriatric patients. It was created utilizing data '
        'from 10,561 geriatric patients received noncardiac surgery at the West '
        'China Hospital of Sichuan University. The area under the receiver '
        'operating characteristic curve is 0.806, and area under the '
        'precision-recall curve is 0.505. This tool could be used prospectively '
        'to identify patients at high risk of acute kidney injury following '
        'noncardiac surgery, and facilitate perioperative medical decision-making.'
    )

# Build the prediction sample (raw values, used only for display).
data1 = {
    hypertension_zh: hypertension,
    up_zh: up,
    dm_zh: dm,
    os_zh: op_site,
    asa_zh: asa,
    eot_zh: eot,
    sccl_zh: sccl,
    coovorbcdw_zh: coovorbcdw,
    inr_zh: inr,
    es_zh: es,
}
X_train_1 = pd.DataFrame(data1, index=[0])

# The same sample, WOE-encoded; this is what the model actually consumes.
data2 = {
    hypertension_zh: hypertension_woe,
    up_zh: up_woe,
    dm_zh: dm_woe,
    os_zh: os_woe,
    asa_zh: asa_woe,
    eot_zh: eot_woe,
    sccl_zh: sccl_woe,
    coovorbcdw_zh: coovorbcdw_woe,
    inr_zh: inr_woe,
    es_zh: es_woe,
}
X_train = pd.DataFrame(data2, index=[0])

# Model column name -> English display name (explicit, so it does not
# depend on the column order of data1).
zh_to_en = {
    hypertension_zh: hypertension_en,
    up_zh: up_en,
    dm_zh: dm_en,
    os_zh: os_en,
    asa_zh: asa_en,
    eot_zh: eot_en,
    sccl_zh: sccl_en,
    coovorbcdw_zh: coovorbcdw_en,
    inr_zh: inr_en,
    es_zh: es_en,
}

if submitted:
    # Predict: column 1 of predict_proba is the positive (AKI) class.
    y_prob_train = model1.predict_proba(X_train)
    value = y_prob_train[0][1]

    # SHAP values are computed on the WOE-encoded features the model sees.
    explainer = shap.TreeExplainer(model1)
    shap_values = explainer.shap_values(X_train)
    base_value = explainer.expected_value
    if isinstance(shap_values, list):
        # Some model / SHAP-version combinations return one array per class
        # for a binary classifier; keep the positive class so the plot
        # explains `value` above.
        shap_values = shap_values[1]
        base_value = base_value[1]

    # Raw (human-readable) values, labelled with English names, for display.
    X1_ = X_train_1.fillna(0).rename(columns=zh_to_en)

    with st.container():
        st.write('### The possibility of Acute kidney injury is: \n' + str(value))
        st.markdown('---')

    with st.container():
        st.write('Submitted parameters were: \n')
        st.write('ASA classification: {} / Hypertension: {} / Diabetes mellitus: {} / Operation site: {} / Emergency surgery: {} / Estimated operation time: {} / Urine protein: {} / Serum cystatin C level (mg/L): {} / Coefficient of variation of red blood cell distribution width (%): {} / International normalized ratio: {} \n'.format(
            X1_[asa_en][0], X1_[hypertension_en][0], X1_[dm_en][0],
            X1_[os_en][0], X1_[es_en][0], X1_[eot_en][0], X1_[up_en][0],
            X1_[sccl_en][0], X1_[coovorbcdw_en][0], X1_[inr_en][0]))

    with st.container():
        st.write('## SHAP force plot for this patient')
        # Visualize SHAP values; link="logit" maps the additive log-odds
        # contributions onto the probability scale.
        st_shap(
            shap.force_plot(base_value, shap_values[0], X1_.iloc[0], link="logit"),
            height=400,
            width=1000,
        )