File size: 6,579 Bytes
66495d3
 
 
 
 
 
 
1efdb42
d9fb06f
0777f12
 
61ee496
d9fb06f
66495d3
264aa55
66495d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12ef8a7
66495d3
 
 
a51075f
 
3c5b3ea
a51075f
66495d3
1d3f222
66495d3
 
0338011
 
66495d3
c39e389
 
66495d3
 
436e90c
7c82d4a
a4d0f08
1d3f222
294a2e8
66495d3
d56bd08
896cbd8
d56bd08
1d3f222
d56bd08
 
0338011
66495d3
 
 
 
 
70e6274
66495d3
 
 
4f026cd
66495d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c39e389
70e6274
66495d3
 
 
 
 
 
 
 
 
 
 
 
 
 
c39e389
896cbd8
 
c39e389
66495d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt


# Default Gradio theme with a blue accent, repainted so that both the page
# background and the component blocks share one light-grey fill.
_LIGHT_GREY = "#D3D3D3"
theme = gr.themes.Default(primary_hue="blue")
theme = theme.set(
    block_background_fill=_LIGHT_GREY,
    background_fill_primary=_LIGHT_GREY,
)


# Load the trained classifier from disk. The `with` block guarantees the file
# handle is closed once unpickling finishes (a bare open() would leak it).
# NOTE: pickle can execute arbitrary code while loading -- only ever point
# this at a trusted, locally produced model file.
with open("heart_xgbV2.pkl", "rb") as _model_file:
    loaded_model = pickle.load(_model_file)

# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.

def _ordinal_codes(labels, start=0):
    """Map each label to its position in *labels*, counting from *start*."""
    return {label: code for code, label in enumerate(labels, start)}


# Lookup tables translating the human-readable answers offered by the UI into
# the integer codes fed to the model.
# NOTE(review): these codes must match the encoding used when heart_xgbV2.pkl
# was trained -- confirm against the training pipeline before changing them.
gender_dict = _ordinal_codes(["Male", "Female"])
cp_dict = _ordinal_codes(["Typical Angina", "Atypical Angina", "Non-Anginal", "Asymptomatic"])
# The two yes/no questions use the same coding but keep independent dicts.
fbs_dict = {"Yes": 1, "No": 0}
exng_dict = {"Yes": 1, "No": 0}
restecg_dict = _ordinal_codes([
    "Normal",
    "Having ST-T abnormality",
    "Showing probable or definite left ventricular hypertrophy by Estes' Criteria",
])
thall_dict = _ordinal_codes(["Fixed Defect", "Normal Blood Flow", "Reversible Defect"], start=1)
slp_dict = _ordinal_codes(["Upsloping", "Flat", "Downsloping"], start=1)

# Create the main function for server
def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall):
    """Predict a subject's heart-attack likelihood and explain it with SHAP.

    Each argument is the current value of the matching Gradio input. The
    categorical answers (sex, cp, fbs, restecg, exng, slp, thall) arrive as the
    option labels shown in the UI and are translated to integer codes via the
    module-level ``*_dict`` tables; age, trtbps, chol, thalachh, oldpeak and
    caa are passed to the model unchanged.

    Returns:
        tuple: ``(scores, figure)``. ``scores`` maps the two outcome labels to
        probabilities (for ``gr.Label``). ``figure`` is the matplotlib figure
        holding a SHAP bar plot (``max_display=6``, ordered by absolute SHAP
        value).

    Raises:
        gr.Error: if any input is left empty, so the UI shows a readable
            message instead of failing with a bare ``KeyError`` in the label
            lookups (or silently predicting with a missing ``caa``).
    """
    answers = {
        "age": age, "sex": sex, "cp": cp, "trtbps": trtbps, "chol": chol,
        "fbs": fbs, "restecg": restecg, "thalachh": thalachh, "exng": exng,
        "oldpeak": oldpeak, "slp": slp, "caa": caa, "thall": thall,
    }
    missing = [name for name, value in answers.items() if value is None or value == ""]
    if missing:
        raise gr.Error("Please fill in all the options. Missing: " + ", ".join(missing))

    # One-row frame whose column names/order must match the model's training
    # features; from_dict(orient="index") + transpose turns the dict into a row.
    new_row = pd.DataFrame.from_dict({
        'age': age,
        'sex': gender_dict[sex],
        'cp': cp_dict[cp],
        'trtbps': trtbps,
        'chol': chol,
        'fbs': fbs_dict[fbs],
        'restecg': restecg_dict[restecg],
        'thalachh': thalachh,
        'exng': exng_dict[exng],
        'oldpeak': oldpeak,
        'slp': slp_dict[slp],
        'caa': caa,
        'thall': thall_dict[thall],
    }, orient='index').transpose()

    prob = loaded_model.predict_proba(new_row)
    # Column 0 of the single prediction row is reported as the "lower chance" class.
    lower = float(prob[0][0])

    shap_values = explainer(new_row)
    # shap.plots.bar draws onto the current pyplot figure; show=False leaves
    # it open so it can be captured just below.
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    # Remove the figure from pyplot's registry so figures do not accumulate
    # across requests; the Figure object itself is still handed to gr.Plot.
    plt.close(local_plot)

    return {"Lower Chance of a Heart Attack": lower, "Higher Chance of a Heart Attack": 1 - lower}, local_plot

# Create the UI
# Header/description strings.
# NOTE(review): only `title` is referenced by the Blocks layout (as the
# browser-tab title); the description* strings are not used there -- the
# layout inlines its own styled copies of this text.

# Plain text: a browser-tab title does not render markdown, so "**" markers
# would show up literally.
title = "Heart Attack Predictor & Interpreter πŸͺ"
description1 = "This app takes info from subjects and predicts their heart attack likelihood."

description_notmedical = "**Do not use for medical diagnosis.**"

# Bold markers balanced (a stray third "**" used to break the emphasis);
# wording matches the banner shown in the layout.
description2 = "**Fill all the options or no result will be generated!!!**"

description3 = "To use the app, please fill in all the options, and click on Analyze. 🀞"

descriptionExamples = "If you would like to see how the model works, please scroll down and try one of the examples!"

##Pinak
# Page layout: header text, the 13-question input form, the Analyze button,
# the prediction/SHAP outputs and two cached examples.
with gr.Blocks(title=title, theme=theme) as demo:

    # Header. Every inline <span> is closed explicitly so the HTML handed to
    # gr.Markdown is well-formed.
    gr.Markdown("<span style='color: #FF0000;font-size: 20px'> **Heart Attack Predictor & Interpreter** πŸͺ</span>")
    gr.Markdown("""---""")
    gr.Markdown("<span style='font-size: 20px;'> **Do not use for medical diagnosis.**</span>")
    gr.Markdown("""---""")
    gr.Markdown("<span style='font-size: 16px;'> If you would like to see how the model works, please scroll down and try one of the examples!</span>")
    gr.Markdown("""---""")
    gr.Markdown("<span style='font-size: 16px;'> This app takes info from subjects and predicts their heart attack likelihood.</span>")
    gr.Markdown("""---""")
    gr.Markdown("<span style='font-size: 16px;'> To use the app, please fill in all the options, and click on Analyze. 🀞</span>")
    gr.Markdown("<span style='font-size: 16px;'> **Fill all the options or no result will be generated!!!**</span>")
    gr.Markdown("""---""")

    # Input form: each component feeds one positional argument of main_func.
    with gr.Row():
        with gr.Column():
            age = gr.Number(label="What is your age?", value=40)
        with gr.Column():
            slp = gr.Dropdown(["Upsloping", "Flat", "Downsloping"], label="What was the slope of the peak exercise ST segment?")

    with gr.Row():
        with gr.Column():
            sex = gr.Radio(["Female", "Male"], label = "What is your sex?")
            cp = gr.Radio(["Typical Angina", "Atypical Angina", "Non-Anginal", "Asymptomatic"], label = "What kind of chest pain is it?")
        with gr.Column():
            restecg = gr.Radio(["Normal", "Having ST-T abnormality", "Showing probable or definite left ventricular hypertrophy by Estes' Criteria"],
                               label = "What is your resting ECG result?")

    with gr.Row():
        with gr.Column():
            fbs = gr.Radio(["Yes", "No"], label = "Is your fasting Blood Sugar >120 mg/dl?")
        with gr.Column():
            exng = gr.Radio(["Yes", "No"], label = "Do you have Exercise Induced Angina?")

    with gr.Row():
        with gr.Column():
            # 0 (no vessels colored) is a valid answer; without it such
            # subjects could not be entered at all.
            # NOTE(review): confirm the model's training data encodes caa from 0.
            caa = gr.Radio([0, 1, 2, 3], label="How many vessels were colored by the fluoroscopy?")
        with gr.Column():
            thall = gr.Radio(["Fixed Defect", "Normal Blood Flow", "Reversible Defect"],  label="What is your Thalassemia condition?")

    with gr.Row():
        with gr.Column():
            trtbps = gr.Slider(label = "What is your resting blood Pressure (in mm Hg)?", minimum = 10, maximum = 250, value = 100, step = 1)
        with gr.Column():
            chol = gr.Slider(label = "What is your cholesterol in mg/dl (via BMI sensor)?", minimum = 30, maximum = 300, value = 180, step = 1)

    with gr.Row():
        with gr.Column():
            oldpeak = gr.Slider(label = "What was the ST depression induced by exercise relative to rest?", minimum = 0, maximum = 6.2, step = 0.1)
        with gr.Column():
            thalachh = gr.Slider(label="What is your maximum heart rate?", minimum = 60, maximum = 250, value=100, step = 1)

    with gr.Row():
        submit_btn = gr.Button("Analyze")

##Do not need to touch
    # Outputs: class probabilities and the SHAP explanation figure.
    with gr.Column(visible=True) as output_col:
        label = gr.Label(label = "Predicted Label")
        local_plot = gr.Plot(label = 'Shap:')

    # Input order must match main_func's parameter order.
    submit_btn.click(
        main_func,
        [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall],
        [label, local_plot],
        api_name="Heart_Predictor",
    )

    # cache_examples=True runs main_func on these rows when the app is built.
    gr.Examples(
        [
            [24, "Male", "Typical Angina", 130, 150, "Yes", "Having ST-T abnormality", 170, "Yes", 5.1, "Flat", 2, "Normal Blood Flow"],
            [59, "Female", "Non-Anginal", 150, 170, "No", "Showing probable or definite left ventricular hypertrophy by Estes' Criteria", 190, "No", 6, "Upsloping", 3, "Reversible Defect"],
        ],
        [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall],
        [label, local_plot],
        main_func,
        cache_examples=True,
    )

# Start the web server only when this file is run as a script
# (e.g. `python app.py`), so importing the module to reuse `demo` or to test
# main_func does not block on a running server.
if __name__ == "__main__":
    demo.launch()