saad177 committed on
Commit
bf3a8f9
1 Parent(s): 9637b10

fix shap plot

Browse files
Files changed (1) hide show
  1. app.py +44 -16
app.py CHANGED
@@ -5,8 +5,14 @@ import pandas as pd
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  import shap
 
8
 
9
- project = hopsworks.login(project="SonyaStern_Lab1")
 
 
 
 
 
10
  fs = project.get_feature_store()
11
 
12
  print("trying to dl model")
@@ -51,11 +57,14 @@ with gr.Blocks() as demo:
51
  )
52
  btn = gr.Button("Submit")
53
  with gr.Column():
54
- output = gr.Text(label="Model prediction")
55
- plot = gr.Plot()
 
 
56
  with gr.Row():
57
  with gr.Accordion("See model explanability", open=False):
58
- waterfall_plot = gr.Plot()
 
59
 
60
  def submit_inputs(
61
  age_input,
@@ -105,23 +114,42 @@ with gr.Blocks() as demo:
105
  ax.legend()
106
  ax.set_xlabel("Variables")
107
  ax.set_ylabel("Values")
108
- ax.set_title("Comparison with Mean values for your age")
109
  ax.set_xticks(indices + bar_width / 2)
110
  ax.set_xticklabels(categories)
111
 
112
  ## explainability plots
113
- rf_model = model.steps[-1][1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- explainer = shap.Explainer(rf_model)
116
- shap_values = explainer.shap_values(df)[1] # Select SHAP values for class 1
117
- print(shap_values.shape) # should show (1, 4)
118
- # Convert shap_values to Explanation object
119
- shap_values_exp = shap.Explanation(
120
- values=shap_values[0], base_values=explainer.expected_value[1]
121
  )
122
 
123
- # Plot the waterfall plot
124
- shap_waterfall_plot = shap.plots.waterfall(shap_values_exp)
 
 
 
125
 
126
  ## save user's data in hopsworks
127
  if consent_input == True:
@@ -135,7 +163,7 @@ with gr.Blocks() as demo:
135
  user_data_df["diabetes"] = existent_info_input
136
  user_data_fg.insert(user_data_df)
137
  print("inserted new user data to hopsworks", user_data_df)
138
- return res, fig, shap_waterfall_plot
139
 
140
  btn.click(
141
  submit_inputs,
@@ -147,7 +175,7 @@ with gr.Blocks() as demo:
147
  existent_info_input,
148
  consent_input,
149
  ],
150
- outputs=[output, plot, waterfall_plot],
151
  )
152
 
153
  demo.launch()
 
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  import shap
8
+ from sklearn.pipeline import make_pipeline
9
 
10
+ feature_names = ["Age", "BMI", "HbA1c", "Blood Glucose"]
11
+
12
+ project = hopsworks.login(
13
+ project="SonyaStern_Lab1",
14
+ api_key_value="c9StuuVQPoMUeXWe.jB2XeWcI8poKUN59W13MxAbMemzY7SChOnX151GtTFNhysBBUPMRuEp5IK7SE3i1",
15
+ )
16
  fs = project.get_feature_store()
17
 
18
  print("trying to dl model")
 
57
  )
58
  btn = gr.Button("Submit")
59
  with gr.Column():
60
+ with gr.Row():
61
+ output = gr.Text(label="Model prediction")
62
+ with gr.Row():
63
+ mean_plot = gr.Plot()
64
  with gr.Row():
65
  with gr.Accordion("See model explanability", open=False):
66
+ with gr.Column():
67
+ waterfall_plot = gr.Plot()
68
 
69
  def submit_inputs(
70
  age_input,
 
114
  ax.legend()
115
  ax.set_xlabel("Variables")
116
  ax.set_ylabel("Values")
117
+ ax.set_title("Comparison with average non-diabetic values for your age")
118
  ax.set_xticks(indices + bar_width / 2)
119
  ax.set_xticklabels(categories)
120
 
121
  ## explainability plots
122
+ rf_classifier = model.named_steps["randomforestclassifier"]
123
+ transformer_pipeline = make_pipeline(
124
+ *[
125
+ step
126
+ for name, step in model.named_steps.items()
127
+ if name != "randomforestclassifier"
128
+ ]
129
+ )
130
+ transformed_df = transformer_pipeline.transform(df)
131
+
132
+ # Generate the SHAP waterfall plot for fig2
133
+ explainer = shap.TreeExplainer(rf_classifier)
134
+ shap_values = explainer.shap_values(
135
+ transformed_df
136
+ ) # Compute SHAP values on the transformed features (output of transformer_pipeline), not on the raw DataFrame
137
+ predicted_class = rf_classifier.predict(transformed_df)[0]
138
+ shap_values_for_predicted_class = shap_values[predicted_class]
139
 
140
+ # Build the explanation from the SHAP values of the first instance for the predicted class
141
+ shap_explanation = shap.Explanation(
142
+ values=shap_values_for_predicted_class[0],
143
+ base_values=explainer.expected_value[predicted_class],
144
+ data=transformed_df[0],
145
+ feature_names=df.columns.tolist(),
146
  )
147
 
148
+ fig2 = plt.figure(figsize=(12, 6)) # Create a new figure for SHAP plot
149
+ shap.waterfall_plot(
150
+ shap_explanation
151
+ ) # NOTE(review): show=False is not actually passed here; confirm whether it is needed to prevent immediate display
152
+ plt.title("SHAP Waterfall Plot") # Optionally set a title for the SHAP plot
153
 
154
  ## save user's data in hopsworks
155
  if consent_input == True:
 
163
  user_data_df["diabetes"] = existent_info_input
164
  user_data_fg.insert(user_data_df)
165
  print("inserted new user data to hopsworks", user_data_df)
166
+ return res, fig, fig2
167
 
168
  btn.click(
169
  submit_inputs,
 
175
  existent_info_input,
176
  consent_input,
177
  ],
178
+ outputs=[output, mean_plot, waterfall_plot],
179
  )
180
 
181
  demo.launch()