ashkanpakzad commited on
Commit
f6764e8
2 Parent(s): ed6d7a6 25e3660

Merge branch 'main' of https://huggingface.co/spaces/ashkanpakzad/cvdriskdemo

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +68 -27
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .streamlit
app.py CHANGED
@@ -1,23 +1,53 @@
1
  import pandas as pd
2
- from pathlib import Path
3
- from sklearn.linear_model import LogisticRegression
4
  import pickle
5
  import streamlit as st
6
  import numpy as np
 
 
 
7
 
8
- # load model
9
- model = pickle.load(open(Path("model.pkl"), "rb"))
10
 
11
- def predict(input):
12
 
13
- inputdf = pd.Series(input)
 
 
 
 
14
 
15
- st.subheader('Input data')
16
- st.table(pd.DataFrame(inputdf))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- prob = model.predict_proba([inputdf])
19
 
20
- st.subheader(f'Risk of CVD: {prob[0][1]*100:.2f}%')
21
 
22
  st.title('Cardiovascular Disease Risk Prediction DEMO')
23
  st.markdown('''
@@ -25,27 +55,36 @@ st.markdown('''
25
 
26
  Output from a simple logistic regression model based on 1000 individuals in India.
27
  Data source: [Cardiovascular Disease Dataset](https://www.kaggle.com/datasets/jocelyndumlao/cardiovascular-disease-dataset/)
 
 
28
  ''')
29
 
 
 
 
 
 
 
 
30
 
31
- age = st.number_input('Age', 0, 100)
 
32
 
33
- sex_options = ['Female', 'Male']
34
- sex = st.radio('Sex', sex_options)
35
 
36
- chestpain_options = ['typical angina', 'atypical angina', 'non-anginal pain', 'none']
37
- chestpain= st.radio('Chest pain type', chestpain_options)
38
 
39
- restingBP = st.number_input('Resting systolic blood pressure mmHG', 0, 200)
40
- serumcholestrol = st.number_input('Serum Cholesterol in mg/dl', 0, 300)
 
 
 
41
 
42
- fastingbloodsugar_options = ['< than 120 mg/dl', '>= than 120 mg/dl']
43
- fastingbloodsugar = st.radio('Fasting blood sugar', fastingbloodsugar_options)
44
 
45
- maxheartrate = st.number_input('Maximum heart rate achieved', 0, 300)
 
46
 
47
- exerciseangia_options = ['no', 'yes']
48
- exerciseangia = st.radio('Exercise induced angina', exerciseangia_options)
49
 
50
  input={
51
  'age': age,
@@ -58,18 +97,20 @@ input={
58
  'exerciseangia': exerciseangia_options.index(exerciseangia)
59
  }
60
 
61
- col1, col2 = st.columns(2)
62
 
63
- with col1:
64
  but1 = st.empty()
65
 
66
- with col2:
67
  but2 = st.empty()
68
 
69
- if but1.button('Predict Input'):
 
 
70
  predict(input)
71
 
72
- if but2.button('Predict Random'):
73
  predict({
74
  'age': np.random.randint(35, 90),
75
  'gender': np.random.randint(0, 2),
 
1
  import pandas as pd
 
 
2
  import pickle
3
  import streamlit as st
4
  import numpy as np
5
+ from huggingface_hub import hf_hub_download
6
+ import shap
7
+ import matplotlib.pyplot as plt
8
 
9
+ model_path = hf_hub_download(repo_id=st.secrets["REPO_ID"], filename="model.pkl", token=st.secrets["HF_TOKEN"])
10
+ explainer_path = hf_hub_download(repo_id=st.secrets["REPO_ID"], filename="explainer.pkl", token=st.secrets["HF_TOKEN"])
11
 
 
12
 
13
+ # load model
14
+ model = pickle.load(open(model_path, "rb"))
15
+ def model_proba(x):
16
+ return model.predict_proba(x)[:, 1]
17
+ explainer = pickle.load(open(explainer_path, "rb"))
18
 
19
+ def predict(input):
20
+ col1c, col2c= st.columns([0.3, 0.7])
21
+
22
+ inputss = pd.Series(input)
23
+ inputdf = pd.DataFrame(inputss)
24
+ inputdf.rename(columns={0: 'value'}, inplace=True)
25
+
26
+ with col1c:
27
+ st.subheader('Input data')
28
+ st.table(inputdf)
29
+
30
+ prob = model.predict_proba([inputss])
31
+
32
+ st.header(f'CVD Risk: {prob[0][1]*100:.2f}%')
33
+
34
+ with col2c:
35
+ st.subheader('CVD Risk Explanation')
36
+ shap_value = explainer(pd.DataFrame(inputdf).T)
37
+ shap.decision_plot(shap_value.base_values, shap_value.values, feature_names=shap_value.feature_names)
38
+ ax = plt.gca()
39
+ ax.set_xlabel('<-- Feature input decreases risk | Feature input increases risk -->')
40
+ ax.set_ylabel('Feature impact -->')
41
+ st.pyplot(plt.gcf())
42
+
43
+ st.markdown('''
44
+ * The effect of each input feature's value on the model's result shown relates to THIS instance only.
45
+ * The straight vertical line is the expected (mean) value of the model.
46
+ * The plotted line shows the effect of each feature in deviating from the expected value.
47
+ ''')
48
+
49
 
 
50
 
 
51
 
52
  st.title('Cardiovascular Disease Risk Prediction DEMO')
53
  st.markdown('''
 
55
 
56
  Output from a simple logistic regression model based on 1000 individuals in India.
57
  Data source: [Cardiovascular Disease Dataset](https://www.kaggle.com/datasets/jocelyndumlao/cardiovascular-disease-dataset/)
58
+
59
+ The CVD risk model prediction is explained using [SHAP](https://shap.readthedocs.io/en/stable/) values.
60
  ''')
61
 
62
+ col1, col2= st.columns(2)
63
+
64
+ with col1:
65
+ age = st.number_input('Age (years)', 0, 100)
66
+
67
+ sex_options = ['Female', 'Male']
68
+ sex = st.radio('Sex', sex_options)
69
 
70
+ chestpain_options = ['none', 'non-anginal pain', 'typical angina', 'atypical angina']
71
+ chestpain= st.radio('Chest pain type', chestpain_options)
72
 
73
+ restingBP = st.number_input('Resting systolic blood pressure mm HG (94-200)', 0, 200)
 
74
 
 
 
75
 
76
+ with col2:
77
+ serumcholestrol = st.number_input('Serum Cholesterol in mg/dl (126-564)', 0, 300)
78
+
79
+ fastingbloodsugar_options = ['LESS than 120 mg/dl', 'GREATER than or EQUAL 120 mg/dl']
80
+ fastingbloodsugar = st.radio('Fasting blood sugar', fastingbloodsugar_options)
81
 
82
+ maxheartrate = st.number_input('Maximum heart rate achieved BPM (71-202)', 0, 300)
 
83
 
84
+ exerciseangia_options = ['no', 'yes']
85
+ exerciseangia = st.radio('Exercise induced angina', exerciseangia_options)
86
 
87
+ st.divider()
 
88
 
89
  input={
90
  'age': age,
 
97
  'exerciseangia': exerciseangia_options.index(exerciseangia)
98
  }
99
 
100
+ col1b, col2b = st.columns(2)
101
 
102
+ with col1b:
103
  but1 = st.empty()
104
 
105
+ with col2b:
106
  but2 = st.empty()
107
 
108
+ st.divider()
109
+
110
+ if but1.button('Predict Input', use_container_width=True):
111
  predict(input)
112
 
113
+ if but2.button('Predict Random', use_container_width=True):
114
  predict({
115
  'age': np.random.randint(35, 90),
116
  'gender': np.random.randint(0, 2),