Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,59 +1,79 @@
|
|
|
|
|
|
1 |
import pickle
|
2 |
import pandas as pd
|
3 |
import shap
|
4 |
import gradio as gr
|
|
|
5 |
import matplotlib.pyplot as plt
|
6 |
|
7 |
# Load the model from disk
|
8 |
-
loaded_model = pickle.load(open("
|
9 |
|
10 |
# Setup SHAP
|
11 |
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
|
12 |
|
13 |
-
#
|
14 |
-
generation_mapping = {
|
15 |
-
|
16 |
-
|
17 |
-
}
|
18 |
|
19 |
-
|
20 |
-
|
|
|
|
|
21 |
|
22 |
-
#
|
23 |
-
def main_func(Generation,
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
prob = loaded_model.predict_proba(new_row)
|
36 |
-
shap_values = explainer(new_row)
|
37 |
|
|
|
|
|
|
|
|
|
|
|
38 |
# Generate SHAP plot
|
39 |
-
|
40 |
-
|
41 |
plt.tight_layout()
|
42 |
local_plot = plt.gcf()
|
43 |
-
plt.rcParams['figure.figsize'] = 6, 4
|
44 |
plt.close()
|
45 |
-
|
46 |
-
return {"Leave": float(prob[0][0]), "Stay":
|
47 |
|
48 |
# Create the UI
|
49 |
title = "**Employee Turnover Predictor & Interpreter** πͺ"
|
50 |
description1 = """
|
51 |
-
This app takes six inputs about employees' satisfaction with different aspects of their work
|
|
|
|
|
|
|
52 |
"""
|
53 |
|
54 |
-
description2 = """
|
55 |
-
To use the app, click on one of the examples, or adjust the values of the six employee satisfaction factors, and click on Analyze. β¨
|
56 |
-
"""
|
57 |
|
58 |
with gr.Blocks(title=title) as demo:
|
59 |
gr.Markdown(f"## {title}")
|
@@ -61,31 +81,31 @@ with gr.Blocks(title=title) as demo:
|
|
61 |
gr.Markdown("""---""")
|
62 |
gr.Markdown(description2)
|
63 |
gr.Markdown("""---""")
|
|
|
64 |
with gr.Row():
|
65 |
with gr.Column():
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
74 |
submit_btn = gr.Button("Analyze")
|
|
|
75 |
with gr.Column(visible=True, scale=1, min_width=600) as output_col:
|
76 |
label = gr.Label(label="Predicted Label")
|
77 |
-
local_plot = gr.Plot(label=
|
78 |
|
79 |
-
# When the button is clicked, process the inputs
|
80 |
submit_btn.click(
|
81 |
main_func,
|
82 |
-
[Generation, Engage2,
|
83 |
[label, local_plot], api_name="Employee_Turnover"
|
84 |
)
|
85 |
-
|
86 |
-
gr.Markdown("### Click on any of the examples below to see how it works:")
|
87 |
-
gr.Examples([["Gen Z", 4, 4, 4, 5, 5], ["Millennials", 4, 5, 4, 4, 4]],
|
88 |
-
[Generation, Engage2, Tenure, WorkEnv3, GM3, WellBeing, SupportiveGM],
|
89 |
-
[label, local_plot], main_func, cache_examples=True)
|
90 |
|
91 |
-
demo.launch( )
|
|
|
1 |
+
#New code post Reza meeting - takes out the All option
|
2 |
+
|
3 |
import pickle
|
4 |
import pandas as pd
|
5 |
import shap
|
6 |
import gradio as gr
|
7 |
+
import numpy as np
|
8 |
import matplotlib.pyplot as plt
|
9 |
|
10 |
# Load the model from disk
|
11 |
+
loaded_model = pickle.load(open("h42_xgb.pkl", 'rb'))
|
12 |
|
13 |
# Setup SHAP
|
14 |
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
|
15 |
|
16 |
+
# Mapping categorical values to numerical
|
17 |
+
generation_mapping = {"Millennials": 5, "Gen Z": 6}
|
18 |
+
chain_scale_mapping = {"Upper Midscale": 3, "Upscale": 4, "Upper Upscale": 5}
|
19 |
+
region_mapping = {"Northeast": 1, "Midwest": 2, "South": 3, "West": 4}
|
20 |
+
recommend_mapping = {"1": 1, "2": 2, "3": 3, "4": 4, "5": 5}
|
21 |
|
22 |
+
generation_choices = list(generation_mapping.keys())
|
23 |
+
chain_scale_choices = list(chain_scale_mapping.keys())
|
24 |
+
region_choices = list(region_mapping.keys())
|
25 |
+
recommend_choices = list(recommend_mapping.keys())
|
26 |
|
27 |
+
# Define main function for prediction
|
28 |
+
def main_func(Generation, Engage2, RecommendToOthers, Voice, Merit, Workload, WellBeing, SupportiveGM, ChainScale, HotelRegion, ManagementLevel=2):
|
29 |
+
# Map inputs to numerical values
|
30 |
+
gen = generation_mapping[Generation]
|
31 |
+
cs = chain_scale_mapping[ChainScale]
|
32 |
+
reg = region_mapping[HotelRegion]
|
33 |
+
recommend = recommend_mapping[RecommendToOthers]
|
34 |
+
|
35 |
+
# Create DataFrame for model prediction
|
36 |
+
new_row = pd.DataFrame.from_dict({
|
37 |
+
'Generation': gen,
|
38 |
+
'ManagementLevel': ManagementLevel,
|
39 |
+
'Engage2': Engage2,
|
40 |
+
'RecommendToOthers': recommend,
|
41 |
+
'Voice': Voice,
|
42 |
+
'Merit': Merit,
|
43 |
+
'Workload': Workload,
|
44 |
+
'ChainScale': cs,
|
45 |
+
'WellBeing': WellBeing,
|
46 |
+
'SupportiveGM': SupportiveGM,
|
47 |
+
'HotelRegion': reg
|
48 |
+
}, orient = 'index').transpose()
|
49 |
+
|
50 |
+
# Predict probabilities
|
51 |
prob = loaded_model.predict_proba(new_row)
|
|
|
52 |
|
53 |
+
# Compute SHAP values
|
54 |
+
shap_values = explainer(new_row)
|
55 |
+
selected_features = ["Engage2", "RecommendToOthers", "Voice", "Merit", "Workload", "WellBeing", "SupportiveGM"]
|
56 |
+
shap_values_filtered = shap_values[:, selected_features]
|
57 |
+
|
58 |
# Generate SHAP plot
|
59 |
+
shap.plots.bar(shap_values_filtered[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
|
|
|
60 |
plt.tight_layout()
|
61 |
local_plot = plt.gcf()
|
62 |
+
plt.rcParams['figure.figsize'] = (6, 4)
|
63 |
plt.close()
|
64 |
+
|
65 |
+
return {"Leave": float(prob[0][0]), "Stay": float(prob[0][0])}, local_plot
|
66 |
|
67 |
# Create the UI
|
68 |
title = "**Employee Turnover Predictor & Interpreter** πͺ"
|
69 |
description1 = """
|
70 |
+
This app takes six inputs about employees' satisfaction with different aspects of their work and predicts whether the employee intends to stay with the employer or leave.
|
71 |
+
It provides:
|
72 |
+
1. The predicted probability of staying or leaving.
|
73 |
+
2. A SHAP interpretation plot showing feature importance.
|
74 |
"""
|
75 |
|
76 |
+
description2 = """Adjust the values of different employee satisfaction factors and click "Analyze" to get a prediction."""
|
|
|
|
|
77 |
|
78 |
with gr.Blocks(title=title) as demo:
|
79 |
gr.Markdown(f"## {title}")
|
|
|
81 |
gr.Markdown("""---""")
|
82 |
gr.Markdown(description2)
|
83 |
gr.Markdown("""---""")
|
84 |
+
|
85 |
with gr.Row():
|
86 |
with gr.Column():
|
87 |
+
Generation = gr.Dropdown(label="Select Generation π§π½", choices=generation_choices, value="Millennials")
|
88 |
+
ChainScale = gr.Dropdown(label="Select Chain Scale π§³", choices=chain_scale_choices, value="Upscale")
|
89 |
+
HotelRegion = gr.Dropdown(label="Select Hotel Region π¨", choices=region_choices, value="Midwest")
|
90 |
+
RecommendToOthers = gr.Dropdown(label="Recommend to Others", choices=recommend_choices, value="4")
|
91 |
+
|
92 |
+
Engage2 = gr.Slider(label="Engagement (Engage2)", minimum=1, maximum=5, value=4, step=0.1)
|
93 |
+
Voice = gr.Slider(label="Voice", minimum=1, maximum=5, value=4, step=0.1)
|
94 |
+
Merit = gr.Slider(label="Merit", minimum=1, maximum=5, value=4, step=0.1)
|
95 |
+
Workload = gr.Slider(label="Workload", minimum=1, maximum=5, value=4, step=0.1)
|
96 |
+
WellBeing = gr.Slider(label="Well-being", minimum=1, maximum=5, value=4, step=0.1)
|
97 |
+
SupportiveGM = gr.Slider(label="Supportive GM", minimum=1, maximum=5, value=4, step=0.1)
|
98 |
+
|
99 |
submit_btn = gr.Button("Analyze")
|
100 |
+
|
101 |
with gr.Column(visible=True, scale=1, min_width=600) as output_col:
|
102 |
label = gr.Label(label="Predicted Label")
|
103 |
+
local_plot = gr.Plot(label="SHAP Impact")
|
104 |
|
|
|
105 |
submit_btn.click(
|
106 |
main_func,
|
107 |
+
[Generation, Engage2, RecommendToOthers, Voice, Merit, Workload, WellBeing, SupportiveGM, ChainScale, HotelRegion],
|
108 |
[label, local_plot], api_name="Employee_Turnover"
|
109 |
)
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
+
demo.launch( )
|