# PinkPonyClub / app.py
# Last updated by gkenwashington ("Update app.py", commit a370df5, verified).
import pickle
import pandas as pd
import shap
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors
# Load the trained classifier from disk. The context manager guarantees the
# file handle is closed once unpickling finishes.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only ever
# point this at a trusted, locally produced model file, never at user input.
with open("huggingface_final.sav", "rb") as model_file:
    loaded_model = pickle.load(model_file)
# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
# Hilton Color Palette (hex strings usable by matplotlib and CSS).
hilton_blue = "#0057B8"
hilton_gold = "#A28F65"
hilton_gray = "#B1B3B3"
# Custom colormap for SHAP plots, interpolating gold -> blue.
# NOTE(review): hilton_cmap is not passed to any plot in this file as seen;
# confirm whether it is still needed.
hilton_cmap = mcolors.LinearSegmentedColormap.from_list("HiltonCmap", [hilton_gold, hilton_blue])
def main_func(Employee, WorkEnvironment, Voice, LearningDevelopment, WellBeing, SupportiveGM):
    """Predict stay/leave for one employee and build a SHAP bar plot.

    Args:
        Employee: Free-text label for the employee (shown in the UI and the
            examples table). It is not included in the model input.
        WorkEnvironment, Voice, LearningDevelopment, WellBeing, SupportiveGM:
            Satisfaction scores coming from the UI sliders (range 1-5 there).

    Returns:
        A tuple ``(probabilities, figure)``: ``probabilities`` is a dict with
        keys "Leave" and "Stay" suitable for ``gr.Label``; ``figure`` is the
        matplotlib Figure containing the SHAP bar plot, for ``gr.Plot``.
    """
    # One-row frame; the column names/order are presumably those the model was
    # trained on -- confirm against the training pipeline before changing them.
    new_row = pd.DataFrame([{
        'WorkEnvironment': WorkEnvironment,
        'Voice': Voice,
        'LearningDevelopment': LearningDevelopment,
        'WellBeing': WellBeing,
        'SupportiveGM': SupportiveGM,
    }])

    # Make prediction. Column 0 of predict_proba is reported as "Leave";
    # NOTE(review): this assumes loaded_model.classes_[0] is the leave class.
    prob = loaded_model.predict_proba(new_row)
    prob_leave = float(prob[0][0])

    shap_values = explainer(new_row)

    # Generate SHAP plot on a fresh figure.
    plt.figure(figsize=(10, 5))
    shap.plots.bar(shap_values[0], max_display=6, show=False)

    # Apply Hilton style. shap.plots.bar draws horizontal bars: feature names
    # are on the y-axis and SHAP values on the x-axis, so label them that way.
    plt.xticks(color="black")
    plt.yticks(color="black")
    plt.xlabel("SHAP Value", fontsize=12, color="black")
    plt.ylabel("Feature", fontsize=12, color="black")
    plt.title("SHAP Analysis - Feature Importance", fontsize=14, color=hilton_blue)
    plt.tight_layout()

    # Grab the figure shap drew on, then detach it from pyplot's global state so
    # repeated requests don't accumulate open figures; Gradio renders it anyway.
    plot = plt.gcf()
    plt.close(plot)

    return {"Leave": prob_leave, "Stay": 1 - prob_leave}, plot
# Custom CSS: full-page background image plus Hilton-themed range sliders.
# The background must use url(...) -- a bare ("...") value is invalid CSS and
# the browser silently drops it. Gradio serves local files referenced in CSS
# via the "file=" prefix (Gradio 5 uses "/gradio_api/file=" and may also need
# the path listed in launch(allowed_paths=...); confirm for the installed version).
custom_css = """
body {
background: url("file=AppPicture.jpg") no-repeat center center fixed;
background-size: cover;
}
/* Hilton-Themed Sliders */
input[type="range"] {
accent-color: #0057B8 !important; /* Hilton Blue */
background: #A28F65 !important; /* Hilton Gold */
}
/* Slider Track */
input[type="range"]::-webkit-slider-runnable-track {
background: #0057B8 !important;
height: 6px; /* Adjust track thickness */
border-radius: 5px;
}
/* Slider Thumb */
input[type="range"]::-webkit-slider-thumb {
background: #A28F65 !important; /* Hilton Gold */
border: 2px solid #B1B3B3 !important; /* Hilton Gray */
width: 16px;
height: 16px;
border-radius: 50%;
}
"""
# Create the UI
# Plain-text title: it becomes the browser tab title via gr.Blocks(title=...),
# which does not render Markdown, and is also used inside a Markdown heading
# (headings are already rendered bold, so no ** markers are needed).
title = "MSBA Team 2 Employee Intent to Stay Predictor"
# Markdown shown under the heading: what the app does and what it outputs.
description1 = """
This app takes five inputs about employees' satisfaction with different aspects of their work and predicts whether the employee intends to stay with the employer or leave.
There are two outputs from the app: 1) the predicted probability of stay or leave, 2) SHAP's bar plot which visualizes the extent to which each factor impacts the stay/leave prediction.
"""
# Markdown usage instructions.
description2 = """
To use the app, adjust the values of the five employee satisfaction factors, and click on Analyze.
"""
with gr.Blocks(title=title, css=custom_css) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")
    with gr.Row():
        with gr.Column():
            # main_func's first parameter is Employee, so every event that calls
            # it must supply a component for it; without this the button passed
            # only five values and main_func failed with a TypeError.
            Employee = gr.Textbox(label="Employee")
            WorkEnvironment = gr.Slider(label="Work Environment Score", minimum=1, maximum=5, value=4, step=0.1)
            Voice = gr.Slider(label="Voice Score", minimum=1, maximum=5, value=4, step=0.1)
            LearningDevelopment = gr.Slider(label="Learning Development Score", minimum=1, maximum=5, value=4, step=0.1)
            WellBeing = gr.Slider(label="Well Being Score", minimum=1, maximum=5, value=4, step=0.1)
            SupportiveGM = gr.Slider(label="Supportive GM Score", minimum=1, maximum=5, value=4, step=0.1)
            submit_btn = gr.Button("Analyze")
        with gr.Column(visible=True, scale=1, min_width=600):
            label = gr.Label(label="Predicted Label")
            local_plot = gr.Plot(label='SHAP Analysis')

    # Single source of truth for the inputs, in main_func's parameter order,
    # shared by the Analyze button and the examples table.
    model_inputs = [Employee, WorkEnvironment, Voice, LearningDevelopment, WellBeing, SupportiveGM]
    model_outputs = [label, local_plot]

    submit_btn.click(
        main_func,
        model_inputs,
        model_outputs,
        api_name="Employee_Turnover",
    )
    gr.Markdown("### Adjust the sliders above and click 'Analyze' to see the prediction and SHAP analysis.")
    gr.Markdown("### Click on any of the examples below to see how it works:")
    # cache_examples=True runs main_func on each example when the app starts.
    gr.Examples(
        [["Median Negative", 3.8, 3.5, 3.6, 3.9, 3.7], ["Goal Negative", 4.8, 3.5, 3.6, 4.9, 4.7]],
        model_inputs,
        model_outputs,
        main_func,
        cache_examples=True,
    )
demo.launch(share=True)