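# Streamlit app: per-procedural TAVI model estimating the risk of pacemaker
# implantation (PMI) from clinical and CT-derived measurements, using a
# pickled ensemble of models shipped under models/.
# Run locally with `streamlit run <path to this file>`, assuming Streamlit is
# installed and the pickled artifacts are present.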
import streamlit as st
import streamlit.components.v1 as components
import numpy as np
import pickle
import pandas as pd
pd.set_option('display.max_columns', 500)
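# Helper: fill missing feature values with a per-model imputation dictionary
# (a column -> value mapping, as expected by DataFrame.fillna).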
def impute_missing_values(df_, dict_impute):
    return df_.fillna(dict_impute)

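# Feature groups: continuous inputs that are standardised with the fitted
# scalers (to_rescale_) and 0/1 indicator inputs used as-is (to_encode_);
# their concatenation (cols, below) presumably matches the training order.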
to_rescale_ = ['MS_d',
               'delta_msid_d',
               'pr_pre_tavi',
               'surface_systole',
               'MS_s',
               'calc_risque_n']
to_encode_ = ['syncope_Oui',
              'marque_ACURATE',
              'marque_COREVALVE',
              'marque_EDWARDS',
              'bloc_branche_pre_bbd']

st.title("Pacemaker Implantation (PMI) risk prediction")
st.subheader(':green[Per-procedural] TAVI implantation model')

cols = to_rescale_ + to_encode_

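# Input form: collects the clinical and CT-derived measurements that feed the
# model (ECG findings, calcification grades, annulus/septum measurements,
# valve type and implantation depth).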
with st.form(key='cols_form'):
    c1, _ = st.columns(2)
    c2, c3 = st.columns(2)
    c4, c5, c6 = st.columns(3)
    c7, c8, c9 = st.columns(3)
    c10, c11 = st.columns(2)
    with c1:
        syncope_Oui = st.selectbox("Syncope", key="syncope", options=["Yes", "No"])
    with c2:
        pr_pre_tavi = st.number_input('Pre-TAVI PR time (ms)', value=186)
    with c3:
        bloc_branche_pre_bbd = st.selectbox("Right bundle branch block", key="prebbd", options=["Yes", "No"])
    with c4:
        ncc_calcif_n = st.selectbox("NCC calcification grade", key='nccalcn', options=[0, 1, 2, 3])
    with c5:
        rcc_calcif_n = st.selectbox("RCC calcification grade", key='rccalcn', options=[0, 1, 2, 3])
    with c6:
        isc_calcif_n = st.selectbox("IS calcification", key='isccalcn', options=["Yes", "No"])
    with c7:
        surface_systole = st.number_input('Systolic aortic annular surface (mm²)', value=483)
    with c8:
        MS_s = st.number_input('Systolic membranous septal length (mm)', value=8)
    with c9:
        MS_d = st.number_input('Diastolic membranous septal length (mm)', value=8.9)
    with c10:
        marque_valve = st.selectbox("TAVI valve type", key="pmbrand", options=["Corevalve", "Edwards", "Acurate", "Portico"])
    with c11:
        IDepth = st.number_input('Implantation depth (mm)', value=5.3)
    submitButton = st.form_submit_button(label='Predict')

# Load model artifacts; cache to avoid reloading on every Streamlit rerun.
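# The three pickles are expected to hold parallel lists, one entry per
# ensemble member: fitted models, fitted scalers, and imputation dictionaries.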
@st.cache_resource()
def load_model():
    with open(r"models/ensemble_models_per.pkl", "rb") as input_file:
        models = pickle.load(input_file)
    return models

# @st.cache_resource()
def load_scaler():
    with open(r"models/ensemble_scaler_per.pkl", "rb") as input_file:
        scalers = pickle.load(input_file)
    return scalers

@st.cache_resource()
def load_impute():
    with open(r"models/ensemble_dict_impute_per.pkl", "rb") as input_file:
        dicts_impute = pickle.load(input_file)
    return dicts_impute

with st.spinner("Loading Model...."):
    models = load_model()
    dicts_impute = load_impute()
    scalers = load_scaler()

#st.write("Predicting Class...")
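# Derived features: Yes/No answers and the valve brand are mapped to 0/1
# indicator columns, calc_risque_n sums the three calcification grades, and
# delta_msid_d is the diastolic membranous septal length minus the
# implantation depth (presumably matching the training-time feature
# engineering).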
with st.spinner("Prediction..."):
    # preprocessing
    diff_MS = MS_d - MS_s
    syncope_Oui = 1 if syncope_Oui == "Yes" else 0
    isc_calcif_n = 1 if isc_calcif_n == "Yes" else 0
    bloc_branche_pre_bbd = 1 if bloc_branche_pre_bbd == "Yes" else 0
    calc_risque_n = ncc_calcif_n + rcc_calcif_n + isc_calcif_n
    delta_msid_d = MS_d - IDepth
    marque_COREVALVE = 1 if marque_valve == "Corevalve" else 0
    marque_ACURATE = 1 if marque_valve == "Acurate" else 0
    marque_EDWARDS = 1 if marque_valve == "Edwards" else 0

    bool_warning = False
    if MS_d < 1.5 or MS_d > 15.30:
        bool_warning = True
    if MS_s < 2.0 or MS_s > 14.61:
        bool_warning = True
    if pr_pre_tavi < 90 or pr_pre_tavi > 344:
        bool_warning = True
    if surface_systole < 300 or surface_systole > 724:
        bool_warning = True
    if IDepth < 0 or IDepth > 16.1:
        bool_warning = True

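    # Warn when any input falls outside the range observed in the study
    # population, since the model would then be extrapolating.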
    if bool_warning:
        st.write(":warning: Warning: results might not be reliable because:")
        if MS_d < 1.5 or MS_d > 15.30:
            st.write("Our study population had a dMS between 1.5 and 15.30 mm")
        if MS_s < 2.0 or MS_s > 14.61:
            st.write("Our study population had an sMS between 2 and 14.61 mm")
        if pr_pre_tavi < 90 or pr_pre_tavi > 344:
            st.write("Our study population had a pre-TAVI PR time between 90 and 344 ms")
        if surface_systole < 300 or surface_systole > 724:
            st.write("Our study population had a systolic aortic annular surface between 300 and 724 mm²")
        if IDepth < 0 or IDepth > 16.1:
            st.write("Our study population had an implantation depth between 0 and 16.1 mm")

    pred_arr = np.array([[MS_d, delta_msid_d, pr_pre_tavi, surface_systole, MS_s, calc_risque_n,
                          syncope_Oui, marque_ACURATE, marque_COREVALVE, marque_EDWARDS, bloc_branche_pre_bbd]])
    pred_df = pd.DataFrame(pred_arr, columns=cols)
    print(pred_df.head())
    pred_dfs = [impute_missing_values(pred_df, dict_impute) for dict_impute in dicts_impute]
    dfs_scaled_ = [scaler.transform(df_[to_rescale_]) for df_, scaler in zip(pred_dfs, scalers)]
    dfs_scaled = [pd.DataFrame(columns=to_rescale_, data=df_scaled_) for df_scaled_ in dfs_scaled_]
    dfs_cat = [pd.concat([df_scaled, pred_df[to_encode_]], axis=1) for df_scaled in dfs_scaled]
    pred = round(np.array([model.predict_proba(pred_df_)[0][1]
                           for pred_df_, model in zip(dfs_cat, models)]).mean() * 100, 2)

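    # Report the ensemble-averaged probability; extreme values are only
    # displayed as "<10 %" or ">90 %".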
    if pred > 90:
        st.write("PMI risk probability >90 %")
    elif pred < 10:
        st.write("PMI risk probability <10 %")
    else:
        st.write("PMI risk probability:", pred, "%")