Spaces:
Runtime error
Runtime error
fix shap plot
Browse files
app.py
CHANGED
@@ -5,8 +5,14 @@ import pandas as pd
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
7 |
import shap
|
|
|
8 |
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
10 |
fs = project.get_feature_store()
|
11 |
|
12 |
print("trying to dl model")
|
@@ -51,11 +57,14 @@ with gr.Blocks() as demo:
|
|
51 |
)
|
52 |
btn = gr.Button("Submit")
|
53 |
with gr.Column():
|
54 |
-
|
55 |
-
|
|
|
|
|
56 |
with gr.Row():
|
57 |
with gr.Accordion("See model explanability", open=False):
|
58 |
-
|
|
|
59 |
|
60 |
def submit_inputs(
|
61 |
age_input,
|
@@ -105,23 +114,42 @@ with gr.Blocks() as demo:
|
|
105 |
ax.legend()
|
106 |
ax.set_xlabel("Variables")
|
107 |
ax.set_ylabel("Values")
|
108 |
-
ax.set_title("Comparison with
|
109 |
ax.set_xticks(indices + bar_width / 2)
|
110 |
ax.set_xticklabels(categories)
|
111 |
|
112 |
## explainability plots
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
)
|
122 |
|
123 |
-
#
|
124 |
-
|
|
|
|
|
|
|
125 |
|
126 |
## save user's data in hopsworks
|
127 |
if consent_input == True:
|
@@ -135,7 +163,7 @@ with gr.Blocks() as demo:
|
|
135 |
user_data_df["diabetes"] = existent_info_input
|
136 |
user_data_fg.insert(user_data_df)
|
137 |
print("inserted new user data to hopsworks", user_data_df)
|
138 |
-
return res, fig,
|
139 |
|
140 |
btn.click(
|
141 |
submit_inputs,
|
@@ -147,7 +175,7 @@ with gr.Blocks() as demo:
|
|
147 |
existent_info_input,
|
148 |
consent_input,
|
149 |
],
|
150 |
-
outputs=[output,
|
151 |
)
|
152 |
|
153 |
demo.launch()
|
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
import numpy as np
|
7 |
import shap
|
8 |
+
from sklearn.pipeline import make_pipeline
|
9 |
|
10 |
+
feature_names = ["Age", "BMI", "HbA1c", "Blood Glucose"]
|
11 |
+
|
12 |
+
project = hopsworks.login(
|
13 |
+
project="SonyaStern_Lab1",
|
14 |
+
api_key_value="c9StuuVQPoMUeXWe.jB2XeWcI8poKUN59W13MxAbMemzY7SChOnX151GtTFNhysBBUPMRuEp5IK7SE3i1",
|
15 |
+
)
|
16 |
fs = project.get_feature_store()
|
17 |
|
18 |
print("trying to dl model")
|
|
|
57 |
)
|
58 |
btn = gr.Button("Submit")
|
59 |
with gr.Column():
|
60 |
+
with gr.Row():
|
61 |
+
output = gr.Text(label="Model prediction")
|
62 |
+
with gr.Row():
|
63 |
+
mean_plot = gr.Plot()
|
64 |
with gr.Row():
|
65 |
with gr.Accordion("See model explanability", open=False):
|
66 |
+
with gr.Column():
|
67 |
+
waterfall_plot = gr.Plot()
|
68 |
|
69 |
def submit_inputs(
|
70 |
age_input,
|
|
|
114 |
ax.legend()
|
115 |
ax.set_xlabel("Variables")
|
116 |
ax.set_ylabel("Values")
|
117 |
+
ax.set_title("Comparison with average non-diabetic values for your age")
|
118 |
ax.set_xticks(indices + bar_width / 2)
|
119 |
ax.set_xticklabels(categories)
|
120 |
|
121 |
## explainability plots
|
122 |
+
rf_classifier = model.named_steps["randomforestclassifier"]
|
123 |
+
transformer_pipeline = make_pipeline(
|
124 |
+
*[
|
125 |
+
step
|
126 |
+
for name, step in model.named_steps.items()
|
127 |
+
if name != "randomforestclassifier"
|
128 |
+
]
|
129 |
+
)
|
130 |
+
transformed_df = transformer_pipeline.transform(df)
|
131 |
+
|
132 |
+
# Generate the SHAP waterfall plot for fig2
|
133 |
+
explainer = shap.TreeExplainer(rf_classifier)
|
134 |
+
shap_values = explainer.shap_values(
|
135 |
+
transformed_df
|
136 |
+
) # Compute SHAP values directly on the DataFrame
|
137 |
+
predicted_class = rf_classifier.predict(transformed_df)[0]
|
138 |
+
shap_values_for_predicted_class = shap_values[predicted_class]
|
139 |
|
140 |
+
# Select the SHAP values for the first instance and the positive class
|
141 |
+
shap_explanation = shap.Explanation(
|
142 |
+
values=shap_values_for_predicted_class[0],
|
143 |
+
base_values=explainer.expected_value[predicted_class],
|
144 |
+
data=transformed_df[0],
|
145 |
+
feature_names=df.columns.tolist(),
|
146 |
)
|
147 |
|
148 |
+
fig2 = plt.figure(figsize=(12, 6)) # Create a new figure for SHAP plot
|
149 |
+
shap.waterfall_plot(
|
150 |
+
shap_explanation
|
151 |
+
) # Set show=False to prevent immediate display
|
152 |
+
plt.title("SHAP Waterfall Plot") # Optionally set a title for the SHAP plot
|
153 |
|
154 |
## save user's data in hopsworks
|
155 |
if consent_input == True:
|
|
|
163 |
user_data_df["diabetes"] = existent_info_input
|
164 |
user_data_fg.insert(user_data_df)
|
165 |
print("inserted new user data to hopsworks", user_data_df)
|
166 |
+
return res, fig, fig2
|
167 |
|
168 |
btn.click(
|
169 |
submit_inputs,
|
|
|
175 |
existent_info_input,
|
176 |
consent_input,
|
177 |
],
|
178 |
+
outputs=[output, mean_plot, waterfall_plot],
|
179 |
)
|
180 |
|
181 |
demo.launch()
|