File size: 7,986 Bytes
cbc5f50
8f5efcd
 
 
35e5b86
 
5d4b558
bf3a8f9
c567e40
cbc5f50
bf3a8f9
 
db83413
23a078b
cbc5f50
8f5efcd
 
350dd14
8f5efcd
89d1951
8f5efcd
 
4b8158a
23a078b
 
 
4b8158a
23a078b
 
 
 
 
 
 
35e5b86
 
 
 
 
 
 
 
 
 
 
 
 
4b8158a
 
 
221622c
217fd17
 
 
221622c
35e5b86
 
bf3a8f9
 
 
 
5d4b558
a459611
4be9ec7
 
ca9891a
8111dc3
015d55d
4be9ec7
 
c567e40
c81bbb9
c567e40
35e5b86
4b8158a
 
 
 
 
 
 
 
35e5b86
76f11b6
35e5b86
 
 
217fd17
c41db0a
217fd17
c41db0a
35e5b86
 
 
 
 
 
 
339157b
35e5b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb95dc6
35e5b86
 
bf3a8f9
35e5b86
 
 
5d4b558
bf3a8f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e086c8c
bf3a8f9
 
 
 
5ff3852
bd29ef5
9637b10
 
8e8061c
4aa5f86
6021f49
c14fc38
 
6034779
adc8fa2
5d4b558
015d55d
4a7aaea
015d55d
 
 
 
 
4a7aaea
adc8fa2
c567e40
 
 
 
 
217fd17
 
c567e40
217fd17
c567e40
c81bbb9
 
 
 
 
221622c
217fd17
350dd14
c734fd3
c41db0a
4b8158a
c41db0a
 
 
 
 
 
 
4b8158a
 
 
 
c41db0a
4b8158a
 
c41db0a
35e5b86
e681ef8
35e5b86
4b8158a
 
 
 
 
 
 
 
adc8fa2
 
 
 
 
 
 
 
8f5efcd
 
e681ef8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
import os

import gradio as gr
import hopsworks
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sklearn.pipeline import make_pipeline

# Human-readable names of the four model inputs.
# NOTE(review): not referenced anywhere else in this file; kept in case other
# code imports it.
feature_names = ["Age", "BMI", "HbA1c", "Blood Glucose"]

# Connect to the Hopsworks project holding the feature store and model registry.
project = hopsworks.login(project="SonyaStern_Lab1")
fs = project.get_feature_store()

print("trying to dl model")
mr = project.get_model_registry()
# Keep the registry handle separate from the loaded estimator so `model`
# always refers to the scikit-learn pipeline used for prediction.
model_entry = mr.get_model("diabetes_gan_model", version=2)
model_dir = model_entry.download()
# joblib.load unpickles the artifact: only load models from a registry you trust.
model = joblib.load(os.path.join(model_dir, "diabetes_gan_model.pkl"))
print("Model downloaded")

# Reference data: used to compute average values for non-diabetic people of
# the user's age.
diabetes_fg = fs.get_feature_group(name="diabetes_gan", version=1)
query = diabetes_fg.select_all()
# Ensures the feature view exists in the feature store; the returned object is
# not used further in this file.
feature_view = fs.get_or_create_feature_view(
    name="diabetes_gan",
    version=1,
    description="Read from Diabetes dataset",
    labels=["diabetes"],
    query=query,
)
diabetes_df = pd.DataFrame(diabetes_fg.read())

# Page layout: the input form sits on the left, the prediction text and
# comparison chart on the right, and four explainability plots live in a
# collapsed accordion underneath.
with gr.Blocks() as demo:
    with gr.Row():
        gr.HTML(value="<h1 style='text-align: center;'>Diabetes prediction</h1>")
    with gr.Row():
        # Left column: model inputs, optional survey answers and submit button.
        with gr.Column():
            age_input = gr.Number(label="age")
            bmi_input = gr.Slider(10, 100, label="bmi", info="Body Mass Index")
            hba1c_input = gr.Slider(
                3.5, 9, label="hba1c_level", info="Glycated Haemoglobin"
            )
            blood_glucose_input = gr.Slider(
                80, 300, label="blood_glucose_level", info="Blood Glucose Level"
            )
            # Only stored alongside the user's data (when consent is given);
            # it is not passed to the model.
            existent_info_input = gr.Radio(
                ["yes", "no", "Don't know"],
                label="Do you already know if you have diabetes? (This will not be used for the prediction)",
            )
            consent_input = gr.Radio(
                ["accept", "decline"],
                label="I consent that my personal data will be saved and potentially be used for the model training",
            )
            btn = gr.Button("Submit")
        # Right column: prediction sentence and the reference-comparison chart.
        with gr.Column():
            with gr.Row():
                output = gr.Text(label="Model prediction")
            with gr.Row():
                mean_plot = gr.Plot()
    with gr.Row():
        with gr.Accordion("See model explainability", open=False):
            with gr.Row():
                with gr.Column():
                    waterfall_plot = gr.Plot()
                with gr.Column():
                    summary_plot = gr.Plot()
            with gr.Row():
                with gr.Column():
                    importance_plot = gr.Plot()
                with gr.Column():
                    decision_plot = gr.Plot()

    def _prediction_message(label):
        """Turn a predicted class label into the sentence shown to the user."""
        if label == 0:
            return "the model prediction is: You don't have diabetes"
        if label == 1:
            return "the model prediction is: You have diabetes"
        # Any other label would otherwise leave the user without a message;
        # show it verbatim instead of failing.
        return f"the model prediction is: {label}"

    def _reference_means(age):
        """Return (mean bmi/HbA1c/glucose of non-diabetic rows, exact-age match?).

        Only the three plotted columns are averaged, so non-numeric columns in
        the feature group cannot make ``DataFrame.mean`` raise. When no
        non-diabetic row has exactly ``age``, all non-diabetic rows are used so
        the reference bars are not all NaN.
        """
        reference_columns = ["bmi", "hba1c_level", "blood_glucose_level"]
        non_diabetic = diabetes_df[diabetes_df["diabetes"] == 0]
        same_age = non_diabetic[non_diabetic["age"] == age]
        if same_age.empty:
            return non_diabetic[reference_columns].mean(), False
        return same_age[reference_columns].mean(), True

    def _comparison_figure(reference, user_values, title):
        """Draw grouped bars: reference means (blue) next to the user's values (red)."""
        categories = ["BMI", "HbA1c", "Blood Level"]
        bar_width = 0.35
        positions = np.arange(len(categories))
        fig, ax = plt.subplots()
        ax.bar(
            positions,
            [reference.bmi, reference.hba1c_level, reference.blood_glucose_level],
            bar_width,
            label="Reference",
            color="b",
            alpha=0.7,
        )
        ax.bar(
            positions + bar_width,
            user_values,
            bar_width,
            label="User",
            color="r",
            alpha=0.7,
        )
        ax.legend()
        ax.set_xlabel("Variables")
        ax.set_ylabel("Values")
        ax.set_title(title)
        ax.set_xticks(positions + bar_width / 2)
        ax.set_xticklabels(categories)
        return fig

    def _per_class_shap_values(raw_values):
        """Normalise TreeExplainer output to a list of per-class (samples, features) arrays.

        Older shap releases return such a list for classifiers; newer ones
        return one array shaped (samples, features, classes).
        """
        if isinstance(raw_values, list):
            return raw_values
        raw_values = np.asarray(raw_values)
        if raw_values.ndim == 3:
            return [raw_values[:, :, i] for i in range(raw_values.shape[2])]
        return [raw_values]

    def _explainability_figures(df, predicted_label):
        """Build the waterfall, summary, importance and decision figures for ``df``."""
        shap_feature_names = ["age", "bmi", "hba1c", "glucose"]
        rf_classifier = model.named_steps["randomforestclassifier"]
        # Run only the preprocessing steps so SHAP sees the features the forest
        # actually receives.
        preprocessing_steps = [
            step
            for name, step in model.named_steps.items()
            if name != "randomforestclassifier"
        ]
        transformed_df = (
            make_pipeline(*preprocessing_steps).transform(df)
            if preprocessing_steps
            else df
        )

        explainer = shap.TreeExplainer(rf_classifier)
        shap_values = _per_class_shap_values(explainer.shap_values(transformed_df))
        # predict() yields a class *label*; SHAP outputs are indexed by that
        # label's position in classes_.
        class_index = list(rf_classifier.classes_).index(predicted_label)
        class_shap_values = shap_values[class_index]
        base_value = explainer.expected_value[class_index]

        # Waterfall for the single submitted row, labelled with the raw inputs.
        shap_explanation = shap.Explanation(
            values=class_shap_values[0],
            base_values=base_value,
            data=df.iloc[0],
            feature_names=shap_feature_names,
        )
        waterfall_fig = plt.figure(figsize=(3, 3))
        waterfall_fig.tight_layout()
        plt.gca().set_position((0, 0, 1, 1))
        plt.title("SHAP Waterfall Plot")
        plt.tight_layout()
        plt.tick_params(axis="y", labelsize=3)
        # show=False: never call plt.show() from a web-server callback.
        shap.waterfall_plot(shap_explanation, show=False)

        summary_fig = plt.figure(figsize=(3, 3))
        plt.title("SHAP Summary Plot")
        shap.summary_plot(
            shap_values,
            features=transformed_df,
            feature_names=shap_feature_names,
            show=False,
        )

        importance_fig = plt.figure(figsize=(4, 3))
        plt.title("Feature Importances")
        sns.barplot(x=rf_classifier.feature_importances_, y=shap_feature_names)

        decision_fig = plt.figure(figsize=(3, 3))
        decision_fig.tight_layout()
        plt.gca().set_position((0, 0, 1, 1))
        plt.title("SHAP Decision Plot")
        plt.tight_layout()
        shap.decision_plot(base_value, class_shap_values, df.iloc[0], show=False)

        return waterfall_fig, summary_fig, importance_fig, decision_fig

    def _save_user_data(df, existent_info, prediction):
        """Best-effort insert of the submitted row into the user-data feature group.

        A storage failure is logged rather than raised, so the user still gets
        their prediction and plots.
        """
        print("user consented to save their data, now trying to save to hopsworks")
        user_data_df = df.copy()
        # "diabetes" is part of the primary key; an unanswered radio (None)
        # is recorded as "Don't know" rather than as a null key.
        user_data_df["diabetes"] = (
            existent_info if existent_info is not None else "Don't know"
        )
        user_data_df["model_prediction"] = prediction
        try:
            user_data_fg = fs.get_or_create_feature_group(
                name="diabetes_user_data",
                version=1,
                primary_key=[
                    "age",
                    "bmi",
                    "hba1c_level",
                    "blood_glucose_level",
                    "diabetes",
                ],
                description="Submitted user data",
            )
            user_data_fg.insert(user_data_df)
        except Exception as err:
            print("could not save user data to hopsworks:", err)
            return
        print("inserted new user data to hopsworks", user_data_df)

    def submit_inputs(
        age_input,
        bmi_input,
        hba1c_input,
        blood_glucose_input,
        existent_info_input,
        consent_input,
    ):
        """Predict diabetes for the submitted values and build all output figures.

        Args:
            age_input, bmi_input, hba1c_input, blood_glucose_input: model inputs.
            existent_info_input: the user's own answer; stored only, never predicted on.
            consent_input: "accept" stores the submission in Hopsworks.

        Returns:
            (prediction text, comparison figure, waterfall, summary, importance
            and decision figures) in the order of the Gradio outputs.

        Raises:
            gr.Error: when no age was entered.
        """
        if age_input is None:
            raise gr.Error("Please enter your age.")
        # pyplot keeps every figure alive until closed; release the previous
        # request's figures so memory does not grow with each submission.
        # NOTE(review): assumes callbacks are not run concurrently (pyplot
        # state is global) — confirm the Gradio queue concurrency setting.
        plt.close("all")

        df = pd.DataFrame(
            [[age_input, float(bmi_input), hba1c_input, blood_glucose_input]],
            columns=["age", "bmi", "hba1c_level", "blood_glucose_level"],
        )
        res = model.predict(df)
        predicted_label = res[0]
        res_str = _prediction_message(predicted_label)

        mean_for_age, matched_age = _reference_means(age_input)
        print(
            "your bmi is:", bmi_input, "the mean for ur age is :", mean_for_age["bmi"]
        )
        title = (
            "Comparison with average non-diabetic values for your age"
            if matched_age
            else "Comparison with average non-diabetic values (no reference data for your age)"
        )
        fig = _comparison_figure(
            mean_for_age, [bmi_input, hba1c_input, blood_glucose_input], title
        )

        fig2, fig3, fig4, fig5 = _explainability_figures(df, predicted_label)

        if consent_input == "accept":
            _save_user_data(df, existent_info_input, predicted_label)
        return res_str, fig, fig2, fig3, fig4, fig5

    btn.click(
        submit_inputs,
        inputs=[
            age_input,
            bmi_input,
            hba1c_input,
            blood_glucose_input,
            existent_info_input,
            consent_input,
        ],
        outputs=[
            output,
            mean_plot,
            waterfall_plot,
            summary_plot,
            importance_plot,
            decision_plot,
        ],
    )

demo.launch()