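# Diabetes-Prediction / testcorcel.py
# Author: saad177 — last commit 4aa5f86 ("fix layout size"), 2.26 kB.
# (Metadata copied from the hosting site's file-viewer page; kept as comments
# so that the file parses as Python.)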
import os

import gradio as gr
import hopsworks
import joblib
import matplotlib.pyplot as plt
import pandas as pd
import shap
from sklearn.pipeline import make_pipeline
# Single example row to explain, using the four feature columns the model
# pipeline is fed with.
df = pd.DataFrame(
    [[20, 20, 30, 40]],
    columns=["age", "bmi", "hba1c_level", "blood_glucose_level"],
)

# Log in to Hopsworks. The API key is taken from the environment instead of
# being committed to source control; the key that used to be hard-coded here
# was exposed and should be revoked/rotated.
# NOTE(review): when HOPSWORKS_API_KEY is unset, None is passed and hopsworks
# presumably falls back to its own credential lookup — confirm for the
# installed hopsworks version.
project = hopsworks.login(
    project="SonyaStern_Lab1",
    api_key_value=os.environ.get("HOPSWORKS_API_KEY"),
)
mr = project.get_model_registry()
model = mr.get_model("diabetes_model", version=1)
model_dir = model.download()
# From here on `model` is the unpickled scikit-learn Pipeline, not the
# registry entry. joblib.load unpickles, so only load trusted artifacts.
model = joblib.load(os.path.join(model_dir, "diabetes_model.pkl"))
print("printing model pipeline:", model)

# Split the pipeline: the random forest is what SHAP explains, and every other
# step (assumed to be fitted preprocessing transformers — confirm) is replayed
# in order to produce the forest's input features.
rf_classifier = model.named_steps["randomforestclassifier"]
transformer_pipeline = make_pipeline(
    *[
        step
        for name, step in model.named_steps.items()
        if name != "randomforestclassifier"
    ]
)
transformed_df = transformer_pipeline.transform(df)
def generate_plots():
    """Build the two figures displayed by the Gradio demo.

    Uses the module-level ``rf_classifier``, ``transformed_df`` and ``df``.

    Returns:
        tuple: ``(fig1, fig2)`` — ``fig1`` is a placeholder line plot and
        ``fig2`` is a SHAP waterfall plot explaining the random forest's
        prediction for the first row of ``transformed_df``.
    """
    # Placeholder line plot for the first row of the UI.
    fig1, ax1 = plt.subplots()
    ax1.plot([1, 2, 3], [4, 5, 6])
    ax1.set_title("Plot 1")

    explainer = shap.TreeExplainer(rf_classifier)
    shap_values = explainer.shap_values(transformed_df)
    predicted_label = rf_classifier.predict(transformed_df)[0]
    # predict() returns a class *label*; SHAP outputs and expected_value are
    # ordered like classes_, so translate the label into its column position.
    class_idx = list(rf_classifier.classes_).index(predicted_label)

    # Older SHAP releases return a list with one (n_samples, n_features) array
    # per class; newer ones return a single (n_samples, n_features, n_classes)
    # array. Either way, take the first sample's contributions for class_idx.
    if isinstance(shap_values, list):
        sample_values = shap_values[class_idx][0]
    else:
        sample_values = shap_values[0, :, class_idx]

    # waterfall_plot draws onto the current figure, so create fig2 first.
    fig2 = plt.figure()
    shap_explanation = shap.Explanation(
        values=sample_values,
        base_values=explainer.expected_value[class_idx],
        data=transformed_df[0],
        feature_names=df.columns.tolist(),
    )
    # show=False: do not call plt.show() in a server process; it can block or
    # close the figure, after which plt.title would target a new figure.
    shap.waterfall_plot(shap_explanation, show=False)
    plt.title("SHAP Waterfall Plot")
    return fig1, fig2
# Render both figures a single time at startup; the page shows them statically.
fig1, fig2 = generate_plots()

with gr.Blocks() as demo:
    # Vertical layout: one figure per row.
    with gr.Row():
        gr.Plot(value=fig1)  # placeholder line chart
    with gr.Row():
        gr.Plot(value=fig2)  # SHAP waterfall for the example row

demo.launch()