cvdriskdemo / app.py
ashkanpakzad's picture
Merge branch 'main' of https://huggingface.co/spaces/ashkanpakzad/cvdriskdemo
f6764e8
import pandas as pd
import pickle
import streamlit as st
import numpy as np
from huggingface_hub import hf_hub_download
import shap
import matplotlib.pyplot as plt
# Download the pickled model and explainer from the Hugging Face repo named in
# the app secrets (REPO_ID / HF_TOKEN must be configured for the Space).
model_path = hf_hub_download(repo_id=st.secrets["REPO_ID"], filename="model.pkl", token=st.secrets["HF_TOKEN"])
explainer_path = hf_hub_download(repo_id=st.secrets["REPO_ID"], filename="explainer.pkl", token=st.secrets["HF_TOKEN"])

# load model
# The context manager closes the file handle deterministically instead of
# leaving it open until garbage collection.
# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# this is only acceptable because the artifact comes from a trusted repo.
with open(model_path, "rb") as model_file:
    model = pickle.load(model_file)
def model_proba(x):
    """Return column 1 of ``model.predict_proba(x)`` for every row of *x*.

    For a binary 0/1 classifier this is the probability of class 1 (the
    "CVD" class, as ``predict`` labels it).

    NOTE(review): not called anywhere else in this file - presumably kept
    for building the SHAP explainer offline; confirm before removing.
    """
    class_probabilities = model.predict_proba(x)
    return class_probabilities[:, 1]
# Load the explainer. predict() calls it on a one-row DataFrame and reads
# .base_values / .values / .feature_names from the result (presumably a SHAP
# Explanation). The context manager closes the file handle deterministically.
# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only load artifacts from a trusted repo.
with open(explainer_path, "rb") as explainer_file:
    explainer = pickle.load(explainer_file)
def predict(input):
    """Render a CVD risk prediction and its SHAP explanation for one individual.

    The left column shows the input values and the model's class-1
    probability as a percentage; the right column shows a SHAP decision plot
    for this single instance plus notes on how to read it.

    Args:
        input: Mapping of feature name -> value for one individual. The model
            receives the values positionally (``predict_proba([series])``), so
            the key order presumably has to match the training column order -
            TODO confirm.
    """
    col1c, col2c = st.columns([0.3, 0.7])
    inputss = pd.Series(input)
    # Features as rows with a single 'value' column, for the input table.
    inputdf = inputss.to_frame(name='value')
    with col1c:
        st.subheader('Input data')
        st.table(inputdf)
        prob = model.predict_proba([inputss])
        st.header(f'CVD Risk: {prob[0][1]*100:.2f}%')
    with col2c:
        st.subheader('CVD Risk Explanation')
        # Transpose to one row with features as columns so the explainer
        # sees named features.
        shap_value = explainer(inputdf.T)
        # show=False: skip plt.show(); the figure is relabelled and handed to
        # Streamlit explicitly below.
        shap.decision_plot(
            shap_value.base_values,
            shap_value.values,
            feature_names=shap_value.feature_names,
            show=False,
        )
        fig = plt.gcf()
        ax = plt.gca()
        ax.set_xlabel('<-- Feature input decreases risk | Feature input increases risk -->')
        ax.set_ylabel('Feature impact -->')
        st.pyplot(fig)
        # Streamlit reruns this script inside a long-lived process and does not
        # clear an explicitly passed figure. Close it so open figures do not
        # accumulate and the next prediction starts from a fresh pyplot figure
        # instead of drawing over this one.
        plt.close(fig)
        st.markdown('''
* The effect of each input feature's value on the model's result shown relates to THIS instance only.
* The straight vertical line is the expected (mean) value of the model.
* The plotted line shows the effect of each feature in deviating from the expected value.
''')
# --- Page header --------------------------------------------------------------
# Disclaimer, model provenance and explanation method, shown above the form.
_INTRO_MARKDOWN = '''
This is a CVD risk prediction app for demonstration only, **therefore not for clinical use**.
Output from a simple logistic regression model based on 1000 individuals in India.
Data source: [Cardiovascular Disease Dataset](https://www.kaggle.com/datasets/jocelyndumlao/cardiovascular-disease-dataset/)
The CVD risk model prediction is explained using [SHAP](https://shap.readthedocs.io/en/stable/) values.
'''

st.title('Cardiovascular Disease Risk Prediction DEMO')
st.markdown(_INTRO_MARKDOWN)
# --- Input form ---------------------------------------------------------------
# Categorical answers are later encoded as their index in the matching
# *_options list, so the order of every option list is significant.
col1, col2 = st.columns(2)
with col1:
    age = st.number_input('Age (years)', 0, 100)
    sex_options = ['Female', 'Male']
    sex = st.radio('Sex', sex_options)
    # NOTE(review): confirm this order reproduces the training data's
    # chestpain codes 0..3 - the index of the chosen option is what the
    # model receives.
    chestpain_options = ['none', 'non-anginal pain', 'typical angina', 'atypical angina']
    chestpain = st.radio('Chest pain type', chestpain_options)
    restingBP = st.number_input('Resting systolic blood pressure mm HG (94-200)', 0, 200)
with col2:
    # Upper bound must cover the labelled 126-564 range (a cap of 300 made
    # most of that range impossible to enter).
    serumcholestrol = st.number_input('Serum Cholesterol in mg/dl (126-564)', 0, 600)
    fastingbloodsugar_options = ['LESS than 120 mg/dl', 'GREATER than or EQUAL 120 mg/dl']
    fastingbloodsugar = st.radio('Fasting blood sugar', fastingbloodsugar_options)
    maxheartrate = st.number_input('Maximum heart rate achieved BPM (71-202)', 0, 300)
    exerciseangia_options = ['no', 'yes']
    exerciseangia = st.radio('Exercise induced angina', exerciseangia_options)
st.divider()

# Model input built from the form. Keys are the model's feature names (order
# preserved, since predict() passes values positionally) and categorical
# answers are encoded as their index in the option lists. Named user_input so
# the module does not shadow the builtin input().
user_input = {
    'age': age,
    'gender': sex_options.index(sex),
    'chestpain': chestpain_options.index(chestpain),
    'restingBP': restingBP,
    'serumcholestrol': serumcholestrol,
    'fastingbloodsugar': fastingbloodsugar_options.index(fastingbloodsugar),
    'maxheartrate': maxheartrate,
    'exerciseangia': exerciseangia_options.index(exerciseangia),
}

# --- Actions ------------------------------------------------------------------
# Buttons are rendered directly into their columns; their results are acted on
# after the divider so the prediction output appears below it.
col1b, col2b = st.columns(2)
predict_input_clicked = col1b.button('Predict Input', use_container_width=True)
predict_random_clicked = col2b.button('Predict Random', use_container_width=True)
st.divider()
if predict_input_clicked:
    predict(user_input)
if predict_random_clicked:
    # Random individual. Numeric ranges follow the ranges stated in the form
    # labels (94-200, 126-564, 71-202); randint's upper bound is exclusive.
    predict({
        'age': np.random.randint(35, 90),
        'gender': np.random.randint(0, 2),
        'chestpain': np.random.randint(0, 4),
        'restingBP': np.random.randint(94, 201),
        'serumcholestrol': np.random.randint(126, 565),
        'fastingbloodsugar': np.random.randint(0, 2),
        'maxheartrate': np.random.randint(71, 203),
        'exerciseangia': np.random.randint(0, 2),
    })