Spaces:
Runtime error
Runtime error
import os

import gradio as gr
import hopsworks
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
# Connect to Hopsworks. The API key is read from the environment (configure it
# as a secret of the hosting Space) instead of being committed to source.
# NOTE(review): the key that used to be hard-coded here has been published and
# should be revoked/rotated in Hopsworks.
_api_key = os.environ.get("HOPSWORKS_API_KEY")
if not _api_key:
    raise RuntimeError(
        "HOPSWORKS_API_KEY is not set; provide it as an environment variable "
        "(e.g. a Space secret)."
    )
project = hopsworks.login(
    project="SonyaStern_Lab1",
    api_key_value=_api_key,
)

# Fetch version 1 of "diabetes_model" from the project's model registry and
# download its artifacts to a local directory.
mr = project.get_model_registry()
model_entry = mr.get_model("diabetes_model", version=1)
model_dir = model_entry.download()

# joblib.load unpickles the file, which can execute arbitrary code; this is
# only acceptable because the artifact comes from our own model registry.
model = joblib.load(os.path.join(model_dir, "diabetes_model.pkl"))

# Final estimator of the loaded object. `.steps` implies a scikit-learn
# Pipeline — NOTE(review): confirm the last step is the tree model SHAP should
# explain.
rf_model = model.steps[-1][1]

# Single hard-coded sample row that the SHAP waterfall plot explains.
df = pd.DataFrame(
    [[20, 20, 30, 40]],
    columns=["age", "bmi", "hba1c_level", "blood_glucose_level"],
)
def _positive_class_shap(raw_values, expected_value, class_index=1):
    """Return ``(shap_values_of_row_0, base_value)`` for one model output.

    ``TreeExplainer.shap_values`` has two layouts for multi-output models:
    older shap returns a list holding one ``(n_samples, n_features)`` array
    per class, while newer shap returns a single
    ``(n_samples, n_features, n_classes)`` array. Both are handled here. A
    2-D array is treated as a single-output model, and ``class_index`` is
    then ignored.

    Returns a 1-D array of per-feature SHAP values and a Python float base
    value (``shap.plots.waterfall`` requires a scalar base value).
    """
    if isinstance(raw_values, list):
        row = np.asarray(raw_values[class_index])[0]
    else:
        arr = np.asarray(raw_values)
        row = arr[0, :, class_index] if arr.ndim == 3 else arr[0]
    base = np.atleast_1d(np.asarray(expected_value, dtype=float))
    base_value = base[class_index] if base.size > 1 else base[0]
    return row, float(base_value)


def generate_plots():
    """Build the two figures shown in the Gradio app.

    Returns:
        ``(fig1, fig2)``, where ``fig1`` is a static demo line plot and
        ``fig2`` is a SHAP waterfall plot. The waterfall explains
        ``rf_model``'s output for class index 1 on the first row of ``df``.
    """
    # Plot 1: static placeholder line plot.
    fig1, ax1 = plt.subplots()
    ax1.plot([1, 2, 3], [4, 5, 6])
    ax1.set_title("Plot 1")

    # Plot 2: SHAP waterfall for the first row of df.
    # NOTE(review): if `model` is a Pipeline with preprocessing steps before
    # rf_model, df probably needs to go through those steps first
    # (e.g. model[:-1].transform(df)) — confirm against the training code.
    explainer = shap.Explainer(rf_model)
    shap_row, base_value = _positive_class_shap(
        explainer.shap_values(df), explainer.expected_value
    )
    explanation = shap.Explanation(
        values=shap_row,
        base_values=base_value,
        data=df.iloc[0].to_numpy(),  # shown next to each feature label
        feature_names=list(df.columns),
    )

    # waterfall draws onto the *current* figure; without a fresh figure it
    # would paint over fig1.
    plt.figure()
    shap.plots.waterfall(explanation, show=False)
    fig2 = plt.gcf()
    return fig1, fig2
# Render both figures once. Calling generate_plots() separately for each row
# would recompute the SHAP explanation twice and leave unused figures open.
plot_fig, shap_fig = generate_plots()

with gr.Blocks() as demo:
    with gr.Row():
        gr.Plot(plot_fig)  # Display first plot in the first row
    with gr.Row():
        gr.Plot(shap_fig)  # Display SHAP waterfall plot in the second row

demo.launch()