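# Page-header residue from the hosting site's file viewer, kept as comments so
# the file parses as Python:
#   author: saad177 (avatar "saad177's picture")
#   last commit: c14fc38 "fix layout size"
#   links: raw | history | blame
#   scan status: "No virus"; file size: 6.26 kB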
import os

import gradio as gr
import hopsworks
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.pipeline import make_pipeline
# Human-readable names for the four model features.
feature_names = ["Age", "BMI", "HbA1c", "Blood Glucose"]

# The Hopsworks API key is read from the environment (e.g. a deployment
# secret) instead of being committed to source. Any key that was ever
# committed to the repository should be treated as leaked and revoked.
HOPSWORKS_API_KEY = os.environ.get("HOPSWORKS_API_KEY")
if not HOPSWORKS_API_KEY:
    raise RuntimeError(
        "Set the HOPSWORKS_API_KEY environment variable to a valid Hopsworks API key."
    )
project = hopsworks.login(
    project="SonyaStern_Lab1",
    api_key_value=HOPSWORKS_API_KEY,
)
fs = project.get_feature_store()

print("trying to dl model")
mr = project.get_model_registry()
model_meta = mr.get_model("diabetes_model", version=1)
model_dir = model_meta.download()
# NOTE: joblib.load unpickles arbitrary objects; only load artifacts that come
# from this project's own model registry.
model = joblib.load(os.path.join(model_dir, "diabetes_model.pkl"))
print("Model downloaded")

diabetes_fg = fs.get_feature_group(name="diabetes_gan", version=1)
query = diabetes_fg.select_all()
feature_view = fs.get_or_create_feature_view(
    name="diabetes_gan",
    version=1,
    description="Read from Diabetes dataset",
    labels=["diabetes"],
    query=query,
)
# Full feature-group contents, used by submit_inputs as reference data for the
# "average non-diabetic values for your age" comparison chart.
diabetes_df = pd.DataFrame(diabetes_fg.read())
with gr.Blocks() as demo:
    # Page title.
    with gr.Row():
        gr.HTML(value="<h1 style='text-align: center;'>Diabetes prediction</h1>")
    with gr.Row():
        # Left column: user inputs. Only age, BMI, HbA1c and blood glucose feed
        # the model; the diabetes answer and the consent box are used only when
        # a submission is saved.
        with gr.Column():
            age_input = gr.Number(label="age")
            bmi_input = gr.Slider(10, 100, label="bmi", info="Body Mass Index")
            hba1c_input = gr.Slider(
                3.5, 9, label="hba1c_level", info="Glycated Haemoglobin"
            )
            blood_glucose_input = gr.Slider(
                80, 300, label="blood_glucose_level", info="Blood Glucose Level"
            )
            existent_info_input = gr.Radio(
                ["yes", "no", "Don't know"],
                label="Do you already know if you have diabetes? (This will not be used for the prediction)",
            )
            consent_input = gr.Checkbox(
                info="I consent that my personal data will be saved and potentially be used for the model training",
                label="accept",
            )
            btn = gr.Button("Submit")
        # Right column: prediction text, comparison chart and a collapsed
        # section holding the SHAP explanation plot.
        with gr.Column():
            with gr.Row():
                output = gr.Text(label="Model prediction")
            with gr.Row():
                mean_plot = gr.Plot()
            with gr.Row():
                with gr.Accordion("See model explainability", open=False):
                    with gr.Column():
                        waterfall_plot = gr.Plot()
def submit_inputs(
age_input,
bmi_input,
hba1c_input,
blood_glucose_input,
existent_info_input,
consent_input,
):
df = pd.DataFrame(
[[age_input, bmi_input, hba1c_input, blood_glucose_input]],
columns=["age", "bmi", "hba1c_level", "blood_glucose_level"],
)
res = model.predict(df)
mean_for_age = diabetes_df[
(diabetes_df["diabetes"] == 0) & (diabetes_df["age"] == age_input)
].mean()
print(
"your bmi is:", bmi_input, "the mean for ur age is :", mean_for_age["bmi"]
)
categories = ["BMI", "HbA1c", "Blood Level"]
fig, ax = plt.subplots()
bar_width = 0.35
indices = np.arange(len(categories))
ax.bar(
indices,
[
mean_for_age.bmi,
mean_for_age.hba1c_level,
mean_for_age.blood_glucose_level,
],
bar_width,
label="Reference",
color="b",
alpha=0.7,
)
ax.bar(
indices + bar_width,
[bmi_input, hba1c_input, blood_glucose_input],
bar_width,
label="User",
color="r",
alpha=0.7,
)
ax.legend()
ax.set_xlabel("Variables")
ax.set_ylabel("Values")
ax.set_title("Comparison with average non-diabetic values for your age")
ax.set_xticks(indices + bar_width / 2)
ax.set_xticklabels(categories)
## explainability plots
rf_classifier = model.named_steps["randomforestclassifier"]
transformer_pipeline = make_pipeline(
*[
step
for name, step in model.named_steps.items()
if name != "randomforestclassifier"
]
)
transformed_df = transformer_pipeline.transform(df)
# Generate the SHAP waterfall plot for fig2
explainer = shap.TreeExplainer(rf_classifier)
shap_values = explainer.shap_values(
transformed_df
) # Compute SHAP values directly on the DataFrame
predicted_class = rf_classifier.predict(transformed_df)[0]
shap_values_for_predicted_class = shap_values[predicted_class]
# Select the SHAP values for the first instance and the positive class
shap_explanation = shap.Explanation(
values=shap_values_for_predicted_class[0],
base_values=explainer.expected_value[predicted_class],
data=transformed_df[0],
feature_names=df.columns.tolist(),
)
fig2 = plt.figure(figsize=(12, 6)) # Create a new figure for SHAP plot
plt.title("SHAP Waterfall Plot") # Optionally set a title for the SHAP plot
plt.tight_layout()
shap.waterfall_plot(
shap_explanation
) # Set show=False to prevent immediate display
## save user's data in hopsworks
if consent_input == True:
user_data_fg = fs.get_or_create_feature_group(
name="user_diabetes_data",
version=1,
primary_key=["age", "bmi", "hba1c_level", "blood_glucose_level"],
description="Submitted user data",
)
user_data_df = df.copy()
user_data_df["diabetes"] = existent_info_input
user_data_fg.insert(user_data_df)
print("inserted new user data to hopsworks", user_data_df)
return res, fig, fig2
btn.click(
submit_inputs,
inputs=[
age_input,
bmi_input,
hba1c_input,
blood_glucose_input,
existent_info_input,
consent_input,
],
outputs=[output, mean_plot, waterfall_plot],
)
demo.launch()