Spaces:
Sleeping
Sleeping
ashkanpakzad
committed on
Merge branch 'main' of https://huggingface.co/spaces/ashkanpakzad/cvdriskdemo
Browse files- .gitignore +1 -0
- app.py +68 -27
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.streamlit
|
app.py
CHANGED
@@ -1,23 +1,53 @@
|
|
1 |
import pandas as pd
|
2 |
-
from pathlib import Path
|
3 |
-
from sklearn.linear_model import LogisticRegression
|
4 |
import pickle
|
5 |
import streamlit as st
|
6 |
import numpy as np
|
|
|
|
|
|
|
7 |
|
8 |
-
|
9 |
-
|
10 |
|
11 |
-
def predict(input):
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
14 |
|
15 |
-
|
16 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
prob = model.predict_proba([inputdf])
|
19 |
|
20 |
-
st.subheader(f'Risk of CVD: {prob[0][1]*100:.2f}%')
|
21 |
|
22 |
st.title('Cardiovascular Disease Risk Prediction DEMO')
|
23 |
st.markdown('''
|
@@ -25,27 +55,36 @@ st.markdown('''
|
|
25 |
|
26 |
Output from a simple logistic regression model based on 1000 individuals in India.
|
27 |
Data source: [Cardiovascular Disease Dataset](https://www.kaggle.com/datasets/jocelyndumlao/cardiovascular-disease-dataset/)
|
|
|
|
|
28 |
''')
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
|
|
32 |
|
33 |
-
|
34 |
-
sex = st.radio('Sex', sex_options)
|
35 |
|
36 |
-
chestpain_options = ['typical angina', 'atypical angina', 'non-anginal pain', 'none']
|
37 |
-
chestpain= st.radio('Chest pain type', chestpain_options)
|
38 |
|
39 |
-
|
40 |
-
serumcholestrol = st.number_input('Serum Cholesterol in mg/dl', 0, 300)
|
|
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
fastingbloodsugar = st.radio('Fasting blood sugar', fastingbloodsugar_options)
|
44 |
|
45 |
-
|
|
|
46 |
|
47 |
-
|
48 |
-
exerciseangia = st.radio('Exercise induced angina', exerciseangia_options)
|
49 |
|
50 |
input={
|
51 |
'age': age,
|
@@ -58,18 +97,20 @@ input={
|
|
58 |
'exerciseangia': exerciseangia_options.index(exerciseangia)
|
59 |
}
|
60 |
|
61 |
-
|
62 |
|
63 |
-
with
|
64 |
but1 = st.empty()
|
65 |
|
66 |
-
with
|
67 |
but2 = st.empty()
|
68 |
|
69 |
-
|
|
|
|
|
70 |
predict(input)
|
71 |
|
72 |
-
if but2.button('Predict Random'):
|
73 |
predict({
|
74 |
'age': np.random.randint(35, 90),
|
75 |
'gender': np.random.randint(0, 2),
|
|
|
1 |
import pandas as pd
|
|
|
|
|
2 |
import pickle
|
3 |
import streamlit as st
|
4 |
import numpy as np
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
+
import shap
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
|
9 |
+
# Fetch the pickled model and SHAP explainer from the Hugging Face Hub repo
# named in the Streamlit secrets. The returned values are used below as local
# file paths.
model_path = hf_hub_download(repo_id=st.secrets["REPO_ID"], filename="model.pkl", token=st.secrets["HF_TOKEN"])
explainer_path = hf_hub_download(repo_id=st.secrets["REPO_ID"], filename="explainer.pkl", token=st.secrets["HF_TOKEN"])

# load model
# NOTE(review): unpickling executes code embedded in the file, so this is only
# safe while the Hub repo is trusted and access-controlled.
# The context manager closes the file handle as soon as the model is loaded.
with open(model_path, "rb") as model_file:
    model = pickle.load(model_file)
|
15 |
+
def model_proba(x):
    """Return column 1 of the model's ``predict_proba`` output for ``x``.

    Column 1 is the probability this app reports as CVD risk.
    """
    class_probabilities = model.predict_proba(x)
    return class_probabilities[:, 1]
|
17 |
+
# Load the pre-built SHAP explainer; the context manager closes the file
# handle instead of leaving it to the garbage collector.
# NOTE(review): unpickling executes code embedded in the file, so only load
# explainer files from a trusted source.
with open(explainer_path, "rb") as explainer_file:
    explainer = pickle.load(explainer_file)
|
18 |
|
19 |
+
def predict(input):
    """Render the model's CVD risk estimate and a SHAP explanation for one case.

    Args:
        input: Mapping of feature name to numeric value for one individual,
            e.g. ``{'age': 54, 'gender': 1, ...}``.
            NOTE(review): the name shadows the ``input`` builtin; it is kept
            so that keyword callers keep working.

    Side effects:
        Writes a two-column layout to the Streamlit page: the input table and
        the risk percentage on the left, the SHAP decision plot and reading
        notes on the right.
    """
    col1c, col2c = st.columns([0.3, 0.7])

    inputss = pd.Series(input)
    # Single-column frame (one feature per row) used for display.
    inputdf = pd.DataFrame(inputss)
    inputdf.rename(columns={0: 'value'}, inplace=True)

    with col1c:
        st.subheader('Input data')
        st.table(inputdf)

        # One-row batch; column 1 of the result is reported as the CVD risk.
        prob = model.predict_proba([inputss])

        st.header(f'CVD Risk: {prob[0][1]*100:.2f}%')

    with col2c:
        st.subheader('CVD Risk Explanation')
        # The explainer receives a single row with one column per feature.
        shap_value = explainer(inputdf.T)
        # show=False: the axes are relabelled below and the figure is rendered
        # by Streamlit, so SHAP must not call plt.show() itself.
        shap.decision_plot(
            shap_value.base_values,
            shap_value.values,
            feature_names=shap_value.feature_names,
            show=False,
        )
        fig = plt.gcf()
        ax = plt.gca()
        ax.set_xlabel('<-- Feature input decreases risk | Feature input increases risk -->')
        ax.set_ylabel('Feature impact -->')
        st.pyplot(fig)
        # Close the figure so the next run starts from a fresh pyplot figure
        # rather than drawing onto this one, and figures do not pile up.
        plt.close(fig)

        st.markdown('''
        * The effect of each input feature's value on the model's result shown relates to THIS instance only.
        * The straight vertical line is the expected (mean) value of the model.
        * The plotted line shows the effect of each feature in deviating from the expected value.
        ''')
|
48 |
+
|
49 |
|
|
|
50 |
|
|
|
51 |
|
52 |
st.title('Cardiovascular Disease Risk Prediction DEMO')
|
53 |
st.markdown('''
|
|
|
55 |
|
56 |
Output from a simple logistic regression model based on 1000 individuals in India.
|
57 |
Data source: [Cardiovascular Disease Dataset](https://www.kaggle.com/datasets/jocelyndumlao/cardiovascular-disease-dataset/)
|
58 |
+
|
59 |
+
The CVD risk model prediction is explained using [SHAP](https://shap.readthedocs.io/en/stable/) values.
|
60 |
''')
|
61 |
|
62 |
+
col1, col2= st.columns(2)
|
63 |
+
|
64 |
+
# Left-hand column of the input form: demographics, chest pain and blood pressure.
with col1:
    age = st.number_input('Age (years)', min_value=0, max_value=100)

    sex_options = ['Female', 'Male']
    sex = st.radio('Sex', options=sex_options)

    # NOTE(review): if the list position is what gets passed to the model (as
    # is done for exerciseangia_options), this order must match the encoding
    # used in training - confirm.
    chestpain_options = ['none', 'non-anginal pain', 'typical angina', 'atypical angina']
    chestpain = st.radio('Chest pain type', options=chestpain_options)

    restingBP = st.number_input(
        'Resting systolic blood pressure mm HG (94-200)',
        min_value=0,
        max_value=200,
    )
|
|
|
74 |
|
|
|
|
|
75 |
|
76 |
+
# Right-hand column of the input form: blood chemistry, heart rate and
# exercise-induced angina.
with col2:
    # The upper bound now matches the 126-564 range quoted in the label; the
    # previous limit of 300 made the upper part of that range impossible to enter.
    serumcholestrol = st.number_input('Serum Cholesterol in mg/dl (126-564)', 0, 564)

    fastingbloodsugar_options = ['LESS than 120 mg/dl', 'GREATER than or EQUAL 120 mg/dl']
    fastingbloodsugar = st.radio('Fasting blood sugar', fastingbloodsugar_options)

    maxheartrate = st.number_input('Maximum heart rate achieved BPM (71-202)', 0, 300)

    # The position in this list (0 = no, 1 = yes) is what is sent to the model.
    exerciseangia_options = ['no', 'yes']
    exerciseangia = st.radio('Exercise induced angina', exerciseangia_options)
|
86 |
|
87 |
+
st.divider()
|
|
|
88 |
|
89 |
input={
|
90 |
'age': age,
|
|
|
97 |
'exerciseangia': exerciseangia_options.index(exerciseangia)
|
98 |
}
|
99 |
|
100 |
+
# Two side-by-side placeholders that will hold the action buttons.
col1b, col2b = st.columns(2)
but1 = col1b.empty()
but2 = col2b.empty()

st.divider()

# Run the prediction on the values entered in the form above.
predict_input_clicked = but1.button('Predict Input', use_container_width=True)
if predict_input_clicked:
    predict(input)
|
112 |
|
113 |
+
if but2.button('Predict Random', use_container_width=True):
|
114 |
predict({
|
115 |
'age': np.random.randint(35, 90),
|
116 |
'gender': np.random.randint(0, 2),
|