File size: 7,438 Bytes
cbc5f50
8f5efcd
 
 
35e5b86
 
5d4b558
bf3a8f9
c567e40
cbc5f50
bf3a8f9
 
db83413
23a078b
cbc5f50
8f5efcd
 
0433697
8f5efcd
 
 
 
4b8158a
23a078b
 
 
4b8158a
23a078b
 
 
 
 
 
 
35e5b86
 
 
 
 
 
 
 
 
 
 
 
 
4b8158a
 
 
221622c
4b8158a
9ab4c69
fb95dc6
221622c
35e5b86
 
bf3a8f9
 
 
 
5d4b558
a459611
4be9ec7
 
ca9891a
8111dc3
015d55d
4be9ec7
 
c567e40
c81bbb9
c567e40
35e5b86
4b8158a
 
 
 
 
 
 
 
35e5b86
 
 
 
 
 
 
 
 
 
 
 
 
339157b
35e5b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb95dc6
35e5b86
 
bf3a8f9
35e5b86
 
 
5d4b558
bf3a8f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e086c8c
bf3a8f9
 
 
 
5ff3852
bd29ef5
9637b10
 
8e8061c
4aa5f86
6021f49
c14fc38
 
e80d1dd
8e8061c
6034779
bf3a8f9
 
 
5d4b558
015d55d
4a7aaea
015d55d
 
 
 
 
4a7aaea
c81bbb9
c567e40
 
 
 
 
 
 
c81bbb9
 
 
 
 
221622c
4b8158a
c734fd3
4b8158a
 
 
 
 
 
 
 
 
c567e40
35e5b86
e681ef8
35e5b86
4b8158a
 
 
 
 
 
 
 
c567e40
8f5efcd
 
e681ef8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import os

import gradio as gr
import hopsworks
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sklearn.pipeline import make_pipeline

# NOTE(review): not referenced elsewhere in this file — confirm it is still needed.
feature_names = ["Age", "BMI", "HbA1c", "Blood Glucose"]

# Authenticate against Hopsworks; the feature store and the model registry are
# both obtained from this project handle.
project = hopsworks.login(project="SonyaStern_Lab1")
fs = project.get_feature_store()

print("trying to dl model")
mr = project.get_model_registry()
# `model_entry` is the registry entry; `model` (used by the UI callbacks) is the
# deserialised pipeline loaded from the downloaded artifact directory.
model_entry = mr.get_model("diabetes_model", version=1)
model_dir = model_entry.download()
# NOTE(review): joblib.load unpickles arbitrary objects; acceptable only because
# the artifact comes from this project's own model registry.
model = joblib.load(os.path.join(model_dir, "diabetes_model.pkl"))
print("Model downloaded")

diabetes_fg = fs.get_feature_group(name="diabetes_gan", version=1)
query = diabetes_fg.select_all()
feature_view = fs.get_or_create_feature_view(
    name="diabetes_gan",
    version=1,
    description="Read from Diabetes dataset",
    labels=["diabetes"],
    query=query,
)
# Full feature-group contents; the UI filters this by age and label to build
# the per-age reference averages.
diabetes_df = pd.DataFrame(diabetes_fg.read())

# Display labels for the model's four input features, in the column order of
# the DataFrame built in submit_inputs; shared by every explainability plot.
PLOT_FEATURE_LABELS = ["age", "bmi", "hba1c", "glucose"]

# Split the fitted pipeline once at start-up instead of on every click: the
# random-forest step is what SHAP explains, and the remaining steps reproduce
# the input transformation the forest is applied to.
rf_classifier = model.named_steps["randomforestclassifier"]
transformer_pipeline = make_pipeline(
    *[
        step
        for name, step in model.named_steps.items()
        if name != "randomforestclassifier"
    ]
)
explainer = shap.TreeExplainer(rf_classifier)


def _per_class_shap(shap_values):
    """Return TreeExplainer output as a list with one (n_samples, n_features) array per class.

    Depending on the installed shap version, ``TreeExplainer.shap_values`` for a
    classifier returns either such a list or a single array with the class on
    the last axis; normalising here keeps the class indexing version-agnostic.
    """
    if isinstance(shap_values, list):
        return shap_values
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return [values[..., k] for k in range(values.shape[-1])]
    # Single-output explainer: only one "class" to report.
    return [values]


def _comparison_plot(age, bmi, hba1c, blood_glucose):
    """Plot the user's inputs next to non-diabetic averages for exactly their age.

    If the reference data has no non-diabetic row with this exact age, the
    averages are NaN and no reference bars will be visible.
    """
    mean_for_age = diabetes_df[
        (diabetes_df["diabetes"] == 0) & (diabetes_df["age"] == age)
    ].mean(numeric_only=True)  # pandas >= 2 raises on non-numeric columns otherwise

    print("your bmi is:", bmi, "the mean for ur age is :", mean_for_age["bmi"])
    categories = ["BMI", "HbA1c", "Blood Level"]
    fig, ax = plt.subplots()
    bar_width = 0.35
    indices = np.arange(len(categories))
    ax.bar(
        indices,
        [
            mean_for_age.bmi,
            mean_for_age.hba1c_level,
            mean_for_age.blood_glucose_level,
        ],
        bar_width,
        label="Reference",
        color="b",
        alpha=0.7,
    )
    ax.bar(
        indices + bar_width,
        [bmi, hba1c, blood_glucose],
        bar_width,
        label="User",
        color="r",
        alpha=0.7,
    )
    ax.legend()
    ax.set_xlabel("Variables")
    ax.set_ylabel("Values")
    ax.set_title("Comparison with average non-diabetic values for your age")
    ax.set_xticks(indices + bar_width / 2)
    ax.set_xticklabels(categories)
    return fig


def _explainability_plots(df):
    """Build the waterfall, summary, importance and decision figures for the single row in df.

    Returns the four figures in that order.
    """
    transformed = transformer_pipeline.transform(df)
    per_class = _per_class_shap(explainer.shap_values(transformed))

    # predict() yields a class *label*; per-class SHAP outputs follow the
    # classifier's predict_proba column order, i.e. the order of classes_.
    predicted_label = rf_classifier.predict(transformed)[0]
    class_index = list(rf_classifier.classes_).index(predicted_label)
    class_shap = per_class[class_index]
    base_value = np.atleast_1d(explainer.expected_value)[class_index]

    # SHAP values are for the transformed row, but the raw inputs are shown as data.
    explanation = shap.Explanation(
        values=class_shap[0],
        base_values=base_value,
        data=df.iloc[0],
        feature_names=PLOT_FEATURE_LABELS,
    )

    waterfall_fig = plt.figure(figsize=(3, 3))
    waterfall_fig.tight_layout()
    plt.gca().set_position((0, 0, 1, 1))
    plt.title("SHAP Waterfall Plot")
    plt.tight_layout()
    plt.tick_params(axis="y", labelsize=3)
    shap.waterfall_plot(explanation, show=False)
    # shap draws on the current figure; re-read it in case it opened its own.
    waterfall_fig = plt.gcf()

    plt.figure(figsize=(3, 3))
    plt.title("SHAP Summary Plot")
    shap.summary_plot(
        per_class,
        features=transformed,
        feature_names=PLOT_FEATURE_LABELS,
        show=False,
    )
    summary_fig = plt.gcf()

    plt.figure(figsize=(3, 3))
    plt.title("Feature Importances")
    sns.barplot(x=rf_classifier.feature_importances_, y=PLOT_FEATURE_LABELS)
    importance_fig = plt.gcf()

    plt.figure(figsize=(3, 3))
    plt.title("SHAP Decision Plot")
    shap.decision_plot(base_value, class_shap, df.iloc[0], show=False)
    decision_fig = plt.gcf()

    return waterfall_fig, summary_fig, importance_fig, decision_fig


def _save_user_data(df, existent_info):
    """Insert the submitted inputs plus the user's self-reported answer into Hopsworks.

    The "diabetes" column holds the raw radio answer ("yes" / "no" /
    "Don't know"; presumably None when nothing was selected — confirm).
    """
    user_data_fg = fs.get_or_create_feature_group(
        name="user_diabetes_data",
        version=1,
        primary_key=["age", "bmi", "hba1c_level", "blood_glucose_level"],
        description="Submitted user data",
    )
    user_data_df = df.copy()
    user_data_df["diabetes"] = existent_info
    user_data_fg.insert(user_data_df)
    print("inserted new user data to hopsworks", user_data_df)


with gr.Blocks() as demo:
    with gr.Row():
        gr.HTML(value="<h1 style='text-align: center;'>Diabetes prediction</h1>")
    with gr.Row():
        with gr.Column():
            age_input = gr.Number(label="age")
            bmi_input = gr.Slider(10, 100, label="bmi", info="Body Mass Index")
            hba1c_input = gr.Slider(
                3.5, 9, label="hba1c_level", info="Glycated Haemoglobin"
            )
            blood_glucose_input = gr.Slider(
                80, 300, label="blood_glucose_level", info="Blood Glucose Level"
            )
            existent_info_input = gr.Radio(
                ["yes", "no", "Don't know"],
                label="Do you already know if you have diabetes? (This will not be used for the prediction)",
            )
            consent_input = gr.Checkbox(
                info="I consent that my personal data will be saved and potentially be used for the model training",
                label="accept",
            )
            btn = gr.Button("Submit")
        with gr.Column():
            with gr.Row():
                output = gr.Text(label="Model prediction")
            with gr.Row():
                mean_plot = gr.Plot()
    with gr.Row():
        with gr.Accordion("See model explainability", open=False):
            with gr.Row():
                with gr.Column():
                    waterfall_plot = gr.Plot()
                with gr.Column():
                    summary_plot = gr.Plot()
            with gr.Row():
                with gr.Column():
                    importance_plot = gr.Plot()
                with gr.Column():
                    decision_plot = gr.Plot()

    def submit_inputs(
        age_input,
        bmi_input,
        hba1c_input,
        blood_glucose_input,
        existent_info_input,
        consent_input,
    ):
        """Predict for one user and build every figure shown in the UI.

        Args (positional, in the order of btn.click's inputs):
            age_input, bmi_input, hba1c_input, blood_glucose_input: model features.
            existent_info_input: self-reported diabetes status (not used to predict).
            consent_input: whether the inputs may be stored in Hopsworks.

        Returns:
            (prediction, comparison, waterfall, summary, importance, decision),
            matching btn.click's outputs one-to-one.
        """
        df = pd.DataFrame(
            [[age_input, bmi_input, hba1c_input, blood_glucose_input]],
            columns=["age", "bmi", "hba1c_level", "blood_glucose_level"],
        )
        res = model.predict(df)

        figures = (
            _comparison_plot(age_input, bmi_input, hba1c_input, blood_glucose_input),
            *_explainability_plots(df),
        )
        # Unregister the figures from pyplot so a long-running server does not
        # accumulate them; the returned Figure objects can still be rendered.
        for figure in figures:
            plt.close(figure)

        if consent_input:
            _save_user_data(df, existent_info_input)
        return (res, *figures)

    btn.click(
        submit_inputs,
        inputs=[
            age_input,
            bmi_input,
            hba1c_input,
            blood_glucose_input,
            existent_info_input,
            consent_input,
        ],
        # One component per returned value, in the same order.
        outputs=[
            output,
            mean_plot,
            waterfall_plot,
            summary_plot,
            importance_plot,
            decision_plot,
        ],
    )

demo.launch()