"""Gradio app that predicts diabetes likelihood and explains it with SHAP.

The model is loaded from ``cdc_diabetes_health_indicators.pkl`` and fed a
single row built from the form answers; a SHAP bar chart of the largest
feature contributions is shown next to the predicted probabilities.
"""
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# load the model from disk
# NOTE(review): pickle.load can execute arbitrary code; only ship a model file
# you produced or otherwise trust. The `with` block closes the file handle.
with open("cdc_diabetes_health_indicators.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP
explainer = shap.Explainer(loaded_model)

# PLEASE DO NOT CHANGE THIS.
age_d = {"18-24":1,"25-29":2,"30-34":3,"35-39":4,"40-44":5,"45-49":6,"50-54":7,"55-59":8,"60-64":9,"65-69":10,"70-74":11,"75-79":12,"80 and older":13}
education_d = {"Never attended school or only kindergarten":1,"Grades 1 through 8 (Elementary)":2,"Grades 9 through 11 (Some high school)":3,"Grade 12 or GED (High school graduate)":4,"College 1 year to 3 years (Some college or technical school)":5,"College 4 years or more (College graduated)":6}
income_d = {"Less than $10,000":1,"Less than $16,250":2,"Less than $22,500":3,"Less than $28,750":4,"Less than $35,000":5,"Less than $48,500":6,"Less than $61,500":7,"$75,000 or more":8}


def _encode_choice(value, mapping, field):
    """Translate a dropdown option text into the numeric code the model uses.

    Args:
        value: The selected option text, e.g. ``"65-69"``.
        mapping: One of ``age_d``, ``education_d`` or ``income_d``.
        field: Human-readable field name used in the error message.

    Returns:
        ``mapping[value]``.

    Raises:
        gr.Error: If ``value`` is not one of the mapping's options, so the
            user sees a message instead of a stack trace.
    """
    try:
        return mapping[value]
    except (KeyError, TypeError):  # TypeError: unhashable input such as a list
        raise gr.Error(f"Unrecognized {field} option: {value!r}") from None


# Create the main function for server
def main_func(HighBP, HighChol, CholCheck, BMI, Smoker, Stroke,
              HeartDiseaseorAttack, PhysActivity, Fruits, Veggies,
              HvyAlcoholConsump, AnyHealthcare, NoDocbcCost, GenHlth,
              MentHlth, PhysHlth, DiffWalk, Sex, Age, Education, Income):
    """Predict diabetes likelihood for one respondent and explain it.

    Yes/No questions and ``Sex`` arrive as option indices (their radios use
    ``type="index"``); ``BMI``, ``GenHlth``, ``MentHlth`` and ``PhysHlth``
    come from sliders; ``Age``, ``Education`` and ``Income`` arrive as option
    text and are converted with ``age_d`` / ``education_d`` / ``income_d``.

    Returns:
        A ``(probabilities, figure)`` pair: a dict with ``"Low Chance"`` and
        ``"High Chance"`` for ``gr.Label``, and the matplotlib figure holding
        the SHAP bar plot for ``gr.Plot``.

    Raises:
        gr.Error: If a question was left unanswered or a dropdown value is
            not one of the known options.
    """
    # Column order here is the order the row is sent to the model.
    answers = {
        "HighBP": HighBP,
        "HighChol": HighChol,
        "CholCheck": CholCheck,
        "BMI": BMI,
        "Smoker": Smoker,
        "Stroke": Stroke,
        "HeartDiseaseorAttack": HeartDiseaseorAttack,
        "PhysActivity": PhysActivity,
        "Fruits": Fruits,
        "Veggies": Veggies,
        "HvyAlcoholConsump": HvyAlcoholConsump,
        "AnyHealthcare": AnyHealthcare,
        "NoDocbcCost": NoDocbcCost,
        "GenHlth": GenHlth,
        "MentHlth": MentHlth,
        "PhysHlth": PhysHlth,
        "DiffWalk": DiffWalk,
        "Sex": Sex,
        "Age": Age,
        "Education": Education,
        "Income": Income,
    }
    # An unselected radio/dropdown yields None, which the model cannot score.
    missing = [name for name, value in answers.items() if value is None]
    if missing:
        raise gr.Error(
            "Please answer every question before clicking Analyze. "
            "Missing: " + ", ".join(missing)
        )
    answers["Age"] = _encode_choice(Age, age_d, "Age")
    answers["Education"] = _encode_choice(Education, education_d, "Education Level")
    answers["Income"] = _encode_choice(Income, income_d, "Income Level")

    new_row = pd.DataFrame([answers])
    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)

    # shap draws onto the current pyplot figure; grab it, then close it so
    # figures do not accumulate across requests.
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs,
                   show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    plt.close()

    low_chance = float(prob[0][0])
    return {"Low Chance": low_chance, "High Chance": 1 - low_chance}, local_plot


# Create the UI
title = "Diabetes Predictor & Interpreter"
description1 = """This app takes info from subjects and predicts their diabetes likelihood. Do not use for medical diagnosis."""
description2 = """
To use the app, pick the most applicable option for you, or adjust the values of the factors, and click on Analyze.
"""


def _yes_no(label):
    """Create a No/Yes radio whose value is the option index (No=0, Yes=1)."""
    return gr.Radio(["No", "Yes"], label=label, type="index")


with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## **{title}**")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")

    HighBP = _yes_no("Do you have high blood pressure?")
    HighChol = _yes_no("Do you have high cholesterol?")
    CholCheck = _yes_no("Have you had your cholesterol checked within the past 5 years?")
    BMI = gr.Slider(label="BMI Score", minimum=12, maximum=98, value=49, step=1)
    Smoker = _yes_no("Have you smoked at least 100 cigarettes in your entire life? Note: 5 packs = 100 cigarettes")
    Stroke = _yes_no("Have you ever had a stroke?")
    HeartDiseaseorAttack = _yes_no("Do you have either a Coronary Heart Disease (CHD) or a Myocardial Infarction (heart attack)?")
    PhysActivity = _yes_no("Have you done any physical activity in the past 30 days (not including your job)?")
    Fruits = _yes_no("Do you eat fruits 1 or more times per day?")
    Veggies = _yes_no("Do you eat vegetables 1 or more times per day?")
    HvyAlcoholConsump = _yes_no("Are you a heavy drinker? Note: Adult men = more than 14 drinks per week, Adult women = more than 7 drinks per week")
    AnyHealthcare = _yes_no("Do you have any kind of healthcare coverage, including health insurance, prepaid plans such as HMO, etc.?")
    NoDocbcCost = _yes_no("Was there a time in the past 12 months when you needed to see a doctor but could not because of cost?")
    GenHlth = gr.Slider(label="How would you rate your general health? Note: 1 = excellent, 2 = very good, 3 = good, 4 = fair, 5 = poor", minimum=1, maximum=5, value=3, step=1)
    MentHlth = gr.Slider(label="How many days was your mental health not good in the past 30 days? This includes stress, depression, and problems with emotions.", minimum=0, maximum=30, value=15, step=1)
    PhysHlth = gr.Slider(label="How many days was your physical health not good in the past 30 days? This includes physical illness and injuries.", minimum=0, maximum=30, value=15, step=1)
    DiffWalk = _yes_no("Do you have serious difficulty walking or climbing stairs?")
    # The CDC dataset codes Sex as 0 = female, 1 = male, and type="index"
    # returns the option position, so "Female" must be listed first.
    Sex = gr.Radio(["Female", "Male"], label="Sex?", type="index")
    # Dropdowns return the option text; main_func converts it to the model's
    # 1-based code via age_d / education_d / income_d.
    Age = gr.Dropdown(list(age_d), label="Age")
    Education = gr.Dropdown(list(education_d), label="Education Level")
    Income = gr.Dropdown(list(income_d), label="Income Level")
    submit_btn = gr.Button("Analyze")

    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.Plot(label='Shap:')

    # Same order as main_func's parameters.
    inputs = [HighBP, HighChol, CholCheck, BMI, Smoker, Stroke,
              HeartDiseaseorAttack, PhysActivity, Fruits, Veggies,
              HvyAlcoholConsump, AnyHealthcare, NoDocbcCost, GenHlth,
              MentHlth, PhysHlth, DiffWalk, Sex, Age, Education, Income]
    outputs = [label, local_plot]

    submit_btn.click(main_func, inputs, outputs, api_name="Diabetes_Predictor")

    gr.Markdown("### Click on any of the examples below to see how it works:")
    gr.Examples(
        [
            ["No", "No", "No", 23, "No", "No", "No", "Yes", "Yes", "Yes", "No", "Yes", "No",
             1, 2, 4, "No", "Female", "65-69",
             "College 4 years or more (College graduated)", "$75,000 or more"],
            ["Yes", "Yes", "Yes", 32, "Yes", "Yes", "Yes", "No", "No", "No", "Yes", "No", "Yes",
             5, 15, 20, "Yes", "Male", "50-54",
             "Grade 12 or GED (High school graduate)", "Less than $35,000"],
        ],
        inputs,
        outputs,
        main_func,
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()