import streamlit as st
import streamlit.components.v1 as components
import numpy as np
import pickle
import pandas as pd

pd.set_option('display.max_columns', 500)


def impute_missing_values(df_, dict_impute):
    """Return ``df_`` with missing values filled from ``dict_impute``.

    ``dict_impute`` is passed straight to ``DataFrame.fillna`` (presumably a
    column -> fill-value mapping fitted at training time; verify against the
    pickled ``ensemble_dict_impute_post.pkl``).
    """
    return df_.fillna(dict_impute)


# Continuous features, rescaled by the fitted scalers (order matters: it is
# the column order passed to ``scaler.transform``).
to_rescale_ = ['diff_MS', 'MS_d', 'pr_pre_tavi', 'pr_j0', 'qrs_pre_tavi', 'qrs_j0',
               'petit_diametre_anneau', 'surface_systole', 'MS_s', 'delta_msid',
               'fraction_ejection_post_tavi', 'nominal', 'delta_qrs', 'age',
               'ncc_calcif_n', 'lcc_calc_n', 'calc_risque_n']
# Binary / one-hot features, passed to the models unscaled.
to_encode_ = ['syncope_Oui', 'marque_ACURATE', 'marque_COREVALVE', 'marque_EDWARDS',
              'bloc_branche_post_bbd', 'bloc_branche_post_non']
cols = to_rescale_ + to_encode_

st.title("Pacemaker Implantation (IPM) risk prediction")
st.subheader(':orange[Post procedural] Tavi implantation model')

# --- Input form -------------------------------------------------------------
with st.form(key='cols_form'):
    c1, c2 = st.columns(2)
    c3, c4 = st.columns(2)
    c5, c6, c7 = st.columns(3)
    c8, c9, c10, c11 = st.columns(4)
    c12, c13, c14 = st.columns(3)
    c15, c16, c17 = st.columns(3)
    c18, c19 = st.columns(2)
    with c1:
        age = st.number_input('Age (year)', value=83)
    with c2:
        syncope_Oui = st.selectbox("Syncope", key="syncope", options=["Yes", "No"])
    with c3:
        pr_pre_tavi = st.number_input('Pre-TAVI PR time (ms)', value=186)
    with c4:
        qrs_pre_tavi = st.number_input('Pre-TAVI QRS time (ms)', value=104)
    with c5:
        pr_j0 = st.number_input("Post-TAVI PR time (ms)", value=194)
    with c6:
        qrs_j0 = st.number_input("Post-TAVI QRS time (ms)", value=121)
    with c7:
        bloc_branche_post_bbd = st.selectbox("Post-TAVI RBBB", key="rbbb", options=["Yes", "No"])
    with c8:
        lcc_calc_n = st.selectbox("LCC calcification grade", key='lcccalc1', options=[0, 1, 2, 3])
    with c9:
        ncc_calcif_n = st.selectbox("NCC calcification grade", key='nccalcn', options=[0, 1, 2, 3])
    with c10:
        rcc_calcif_n = st.selectbox("RCC calcification grade", key='rccalcn', options=[0, 1, 2, 3])
    with c11:
        isc_calcif_n = st.selectbox("IS calcification", key='isccalcn', options=["Yes", "No"])
    with c12:
        MS_s = st.number_input('Systolic membranous septal length(mm)', value=8.0)
    with c13:
        MS_d = st.number_input('Diastolic membranous septal length (mm)', value=8.9)
    with c14:
        surface_systole = st.number_input('Systolic aortic annular surface (mm²)', value=483)
    with c15:
        petit_diametre_anneau = st.number_input('Minimal aortic annulus diameter (mm)', value=21)
    with c16:
        marque_valve = st.selectbox("TAVI valve type", key="pmbrand",
                                    options=["Corevalve", "Edwards", "Acurate", "Portico"])
    with c17:
        IDepth = st.number_input('Implantation depth (mm)', value=5.3)
    with c18:
        nominal = st.number_input('Transcatheter valve nominal area (mm²)', value=490)
    with c19:
        fraction_ejection_post_tavi = st.number_input('Post TAVI LVEF (%)', value=59)
    # NOTE: the prediction below is not gated on this button, so it also runs
    # on the first page load with the default values.
    submitButton = st.form_submit_button(label='Predict')


# --- Model artefacts ----------------------------------------------------------
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    These are bundled model files, not user uploads; never point this at
    untrusted data, since unpickling can execute arbitrary code.
    """
    with open(path, "rb") as input_file:
        return pickle.load(input_file)


# All three loaders are cached so the pickles are read once per process
# instead of on every Streamlit rerun.
@st.cache_resource()
def load_model():
    """Return the list of fitted ensemble models."""
    return _load_pickle(r"models/ensemble_models_post.pkl")


@st.cache_resource()
def load_scaler():
    """Return the list of fitted scalers (one per ensemble member)."""
    return _load_pickle(r"models/ensemble_scaler_post.pkl")


@st.cache_resource()
def load_impute():
    """Return the list of imputation mappings (one per ensemble member)."""
    return _load_pickle(r"models/ensemble_dict_impute_post.pkl")


with st.spinner("Loading Model...."):
    models = load_model()
    dicts_impute = load_impute()
    scalers = load_scaler()

with st.spinner("Prediction..."):
    # --- Feature engineering ----------------------------------------------
    # Age is used in days both as a model feature and in the study range below.
    age = int(age) * 365
    diff_MS = MS_d - MS_s
    delta_msid = MS_s - IDepth
    delta_qrs = qrs_j0 - qrs_pre_tavi
    syncope_Oui = 1 if syncope_Oui == "Yes" else 0
    isc_calcif_n = 1 if isc_calcif_n == "Yes" else 0
    bloc_branche_post_bbd = 1 if bloc_branche_post_bbd == "Yes" else 0
    calc_risque_n = ncc_calcif_n + rcc_calcif_n + isc_calcif_n
    # "No post-TAVI bundle branch block": post-TAVI QRS < 120 ms and no RBBB.
    # Previously derived from delta_qrs (the QRS *change*), which is almost
    # always < 120 and could flag "non" together with RBBB, producing
    # contradictory one-hot inputs.
    bloc_branche_post_non = 1 if (qrs_j0 < 120 and not bloc_branche_post_bbd) else 0
    # Portico maps to all-zero brand flags (presumably the reference level).
    marque_COREVALVE = 1 if marque_valve == "Corevalve" else 0
    marque_ACURATE = 1 if marque_valve == "Acurate" else 0
    marque_EDWARDS = 1 if marque_valve == "Edwards" else 0

    # --- Study-population range warnings ----------------------------------
    # (value, low, high, message): warn when value lies outside [low, high].
    study_ranges = [
        (MS_d, 1.5, 15.30, "Our study population had a dMS (mm) between 1.5 and 15.30 (mm)"),
        (MS_s, 2.0, 14.61, "Our study population had a sMS (mm) between 2 and 14.61 (mm) "),
        (pr_pre_tavi, 90, 344, "Our study population had a pre-TAVI PR time between 90 and 344 (ms)"),
        (pr_j0, 110, 370, "Our study population had a Post-TAVI PR time between 110 and 370 (ms)"),
        (qrs_pre_tavi, 68, 176, "Our study population had a Pre-TAVI QRS time between 68 and 176 (ms)"),
        (qrs_j0, 74, 195, "Our study population had a Post-TAVI QRS time between 74 and 195 (ms)"),
        (petit_diametre_anneau, 11, 27,
         "Our study population had a Minimal aortic annulus diameter between 11 and 27 (mm)"),
        (fraction_ejection_post_tavi, 21, 84,
         "Our study population had a Post TAVI LVEF between 21 and 84 (%)"),
        (nominal, 282, 648,
         "Our study population had a Transcatheter valve nominal area between 282 and 648 (mm²)"),
        (age, 21156, 34939, "Our study population had an age between 58 and 95 (year)"),
        (surface_systole, 300, 724,
         "Our study population had a systolic aortic annular surface between 300 and 724 (mm²) "),
        (IDepth, 0, 16.1, "Our study population had an implantation depth between 0 and 16.1 (mm) "),
    ]
    out_of_range = [msg for value, low, high, msg in study_ranges if value < low or value > high]
    if out_of_range:
        st.write(":warning: Warning Results might not be reliable because:")
        for msg in out_of_range:
            st.write(msg)

    # --- Model input --------------------------------------------------------
    # Keyed by column name so the vector cannot drift out of order with `cols`.
    features = {
        'diff_MS': diff_MS, 'MS_d': MS_d, 'pr_pre_tavi': pr_pre_tavi, 'pr_j0': pr_j0,
        'qrs_pre_tavi': qrs_pre_tavi, 'qrs_j0': qrs_j0,
        'petit_diametre_anneau': petit_diametre_anneau, 'surface_systole': surface_systole,
        'MS_s': MS_s, 'delta_msid': delta_msid,
        'fraction_ejection_post_tavi': fraction_ejection_post_tavi, 'nominal': nominal,
        'delta_qrs': delta_qrs, 'age': age, 'ncc_calcif_n': ncc_calcif_n,
        'lcc_calc_n': lcc_calc_n, 'calc_risque_n': calc_risque_n,
        'syncope_Oui': syncope_Oui, 'marque_ACURATE': marque_ACURATE,
        'marque_COREVALVE': marque_COREVALVE, 'marque_EDWARDS': marque_EDWARDS,
        'bloc_branche_post_bbd': bloc_branche_post_bbd,
        'bloc_branche_post_non': bloc_branche_post_non,
    }
    pred_df = pd.DataFrame(np.array([[features[c] for c in cols]], dtype=float), columns=cols)

    # --- Ensemble prediction ------------------------------------------------
    # Each ensemble member has its own imputer mapping, scaler and model; the
    # reported risk is the mean positive-class probability across members.
    pred_dfs = [impute_missing_values(pred_df, dict_impute) for dict_impute in dicts_impute]
    probas = []
    for df_, scaler, model in zip(pred_dfs, scalers, models):
        df_scaled = pd.DataFrame(scaler.transform(df_[to_rescale_]),
                                 columns=to_rescale_, index=df_.index)
        model_input = pd.concat([df_scaled, df_[to_encode_]], axis=1)
        probas.append(model.predict_proba(model_input)[0][1])
    pred = round(np.mean(probas) * 100, 2)

    if pred > 90:
        st.write("PMI risk probability >90 %")
    elif pred < 10:
        st.write("PMI risk probability <10 %")
    else:
        st.write("PMI risk probability :", pred, ' %')