Spaces:
Sleeping
Sleeping
import pandas as pd | |
import pickle | |
import streamlit as st | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
import shap | |
import matplotlib.pyplot as plt | |
# Fetch the pickled model and explainer files from the Hugging Face Hub repo
# named in the Streamlit secrets; each call yields a path that is opened below.
model_path = hf_hub_download(repo_id=st.secrets["REPO_ID"], filename="model.pkl", token=st.secrets["HF_TOKEN"])
explainer_path = hf_hub_download(repo_id=st.secrets["REPO_ID"], filename="explainer.pkl", token=st.secrets["HF_TOKEN"])

# load model
# NOTE(security): pickle.load can execute arbitrary code embedded in the file,
# so this is only safe while the configured repo is trusted.
# The context manager closes the file handle, which the previous bare
# open(...) call left to the garbage collector.
with open(model_path, "rb") as model_file:
    model = pickle.load(model_file)
def model_proba(x):
    """Return column 1 of ``model.predict_proba(x)`` for every row of ``x``.

    Column 1 is presumably the positive (CVD) class probability -- confirm
    against the trained model's class order.

    NOTE(review): nothing in the visible file calls this function directly.
    It is defined just before the explainer is unpickled, so the pickled
    explainer presumably looks it up by name -- confirm before renaming or
    removing it.
    """
    class_probabilities = model.predict_proba(x)
    return class_probabilities[:, 1]
# Load the pre-built explainer object used to explain individual predictions.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file,
# so this is only safe while the source of explainer_path is trusted.
# NOTE(review): this runs after model_proba is defined; presumably the pickle
# refers to that function by name -- confirm before reordering.
# The context manager closes the file handle, which the previous bare
# open(...) call left to the garbage collector.
with open(explainer_path, "rb") as explainer_file:
    explainer = pickle.load(explainer_file)
def predict(input):
    """Render the predicted CVD risk and a SHAP decision plot for one input.

    Reads the module-level ``model`` and ``explainer`` and writes directly to
    the Streamlit page in two columns: the input table and risk percentage on
    the left, the explanation plot and its legend text on the right.

    Args:
        input: Mapping of feature name to value for a single individual (the
            callers in this file pass a dict). The name shadows the builtin
            but is kept so existing callers are unaffected.

    Returns:
        None.
    """
    col1c, col2c = st.columns([0.3, 0.7])
    inputss = pd.Series(input)
    # One-column frame (features as rows) for display; it is transposed below
    # into the single-row layout handed to the explainer.
    inputdf = inputss.to_frame(name='value')
    with col1c:
        st.subheader('Input data')
        st.table(inputdf)
        prob = model.predict_proba([inputss])
        st.header(f'CVD Risk: {prob[0][1]*100:.2f}%')
    with col2c:
        st.subheader('CVD Risk Explanation')
        shap_value = explainer(inputdf.T)
        # pyplot keeps figures alive in global state until they are closed.
        # Start from a fresh figure so nothing drawn by an earlier call can
        # show up in this plot, and close it afterwards so figures do not
        # accumulate across Streamlit reruns.
        fig = plt.figure()
        try:
            # show=False: the axes are relabelled below and the figure is
            # rendered by Streamlit rather than by pyplot.
            shap.decision_plot(
                shap_value.base_values,
                shap_value.values,
                feature_names=shap_value.feature_names,
                show=False,
            )
            ax = fig.gca()
            ax.set_xlabel('<-- Feature input decreases risk | Feature input increases risk -->')
            ax.set_ylabel('Feature impact -->')
            st.pyplot(fig)
        finally:
            plt.close(fig)
        st.markdown('''
* The effect of each input feature's value on the model's result shown relates to THIS instance only.
* The straight vertical line is the expected (mean) value of the model.
* The plotted line shows the effect of each feature in deviating from the expected value.
''')
# Page header: app title followed by a short disclaimer, the data source and a
# pointer to the SHAP documentation.
intro_markdown = '''
This is a CVD risk prediction app for demonstration only, **therefore not for clinical use**.
Output from a simple logistic regression model based on 1000 individuals in India.
Data source: [Cardiovascular Disease Dataset](https://www.kaggle.com/datasets/jocelyndumlao/cardiovascular-disease-dataset/)
The CVD risk model prediction is explained using [SHAP](https://shap.readthedocs.io/en/stable/) values.
'''
st.title('Cardiovascular Disease Risk Prediction DEMO')
st.markdown(intro_markdown)
# Input form: two side-by-side columns of widgets. Categorical answers are
# later converted to integers by their position in the *_options lists.
col1, col2 = st.columns(2)
with col1:
    age = st.number_input('Age (years)', 0, 100)
    sex_options = ['Female', 'Male']
    sex = st.radio('Sex', sex_options)
    # NOTE(review): the list position becomes the encoded feature value;
    # confirm this order matches the encoding the model was trained with.
    chestpain_options = ['none', 'non-anginal pain', 'typical angina', 'atypical angina']
    chestpain = st.radio('Chest pain type', chestpain_options)
    restingBP = st.number_input('Resting systolic blood pressure mm HG (94-200)', 0, 200)
with col2:
    # The upper bound used to be 300, which made most of the 126-564 range
    # advertised in the label impossible to enter. 600 matches the upper
    # limit that the "Predict Random" button draws from.
    serumcholestrol = st.number_input('Serum Cholesterol in mg/dl (126-564)', 0, 600)
    fastingbloodsugar_options = ['LESS than 120 mg/dl', 'GREATER than or EQUAL 120 mg/dl']
    fastingbloodsugar = st.radio('Fasting blood sugar', fastingbloodsugar_options)
    maxheartrate = st.number_input('Maximum heart rate achieved BPM (71-202)', 0, 300)
    exerciseangia_options = ['no', 'yes']
    exerciseangia = st.radio('Exercise induced angina', exerciseangia_options)
st.divider()
# Feature values collected from the form. Categorical answers are encoded as
# their index in the corresponding options list.
# NOTE(review): predict() hands these values to the model as a plain row, so
# the key order presumably has to match the training column order -- confirm.
# Named user_input rather than input so the builtin is not shadowed at module
# scope.
user_input = {
    'age': age,
    'gender': sex_options.index(sex),
    'chestpain': chestpain_options.index(chestpain),
    'restingBP': restingBP,
    'serumcholestrol': serumcholestrol,
    'fastingbloodsugar': fastingbloodsugar_options.index(fastingbloodsugar),
    'maxheartrate': maxheartrate,
    'exerciseangia': exerciseangia_options.index(exerciseangia),
}

# Reserve two side-by-side placeholders so the buttons sit above the divider
# while the prediction output is rendered below it.
col1b, col2b = st.columns(2)
with col1b:
    but1 = st.empty()
with col2b:
    but2 = st.empty()
st.divider()

if but1.button('Predict Input', use_container_width=True):
    predict(user_input)

if but2.button('Predict Random', use_container_width=True):
    # np.random.randint excludes the upper bound, so e.g. randint(0, 2)
    # yields 0 or 1 and randint(0, 4) yields 0..3.
    predict({
        'age': np.random.randint(35, 90),
        'gender': np.random.randint(0, 2),
        'chestpain': np.random.randint(0, 4),
        'restingBP': np.random.randint(80, 200),
        'serumcholestrol': np.random.randint(100, 600),
        'fastingbloodsugar': np.random.randint(0, 2),
        'maxheartrate': np.random.randint(70, 220),
        'exerciseangia': np.random.randint(0, 2),
    })