manisha56 commited on
Commit
d863f1f
Β·
verified Β·
1 Parent(s): 95af023

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -48
app.py CHANGED
@@ -1,59 +1,79 @@
 
 
1
  import pickle
2
  import pandas as pd
3
  import shap
4
  import gradio as gr
 
5
  import matplotlib.pyplot as plt
6
 
7
  # Load the model from disk
8
- loaded_model = pickle.load(open("h37_xgb.pkl", 'rb'))
9
 
10
  # Setup SHAP
11
  explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
12
 
13
- # Map the Generation selection to numeric values (5 for Millennials, 6 for Gen Z)
14
- generation_mapping = {
15
- "Millennials": 5,
16
- "Gen Z": 6
17
- }
18
 
19
- def process_generation(selected_value):
20
- return f"You selected: {selected_value}"
 
 
21
 
22
- # Create the main function for server
23
- def main_func(Generation, Tenure, Engage2, WorkEnv3, GM3, WellBeing, SupportiveGM,
24
- ManagementLevel=2, RecommendToOthers=4, ChainScale=3, HotelRegion=2):
25
-
26
- # Convert selected Generation string to the corresponding numeric value
27
- Generation_numeric = generation_mapping[Generation]
28
-
29
- # Create the dataframe with the mapped generation value
30
- new_row = pd.DataFrame.from_dict({'Generation': Generation_numeric, 'Tenure': Tenure, 'ManagementLevel': ManagementLevel,
31
- 'Engage2': Engage2, 'RecommendToOthers': RecommendToOthers, 'WorkEnv3': WorkEnv3,
32
- 'GM3': GM3, 'ChainScale': ChainScale, 'WellBeing': WellBeing, 'SupportiveGM': SupportiveGM,
33
- 'HotelRegion': HotelRegion}, orient='index').transpose()
34
-
 
 
 
 
 
 
 
 
 
 
 
35
  prob = loaded_model.predict_proba(new_row)
36
- shap_values = explainer(new_row)
37
 
 
 
 
 
 
38
  # Generate SHAP plot
39
- plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
40
-
41
  plt.tight_layout()
42
  local_plot = plt.gcf()
43
- plt.rcParams['figure.figsize'] = 6, 4
44
  plt.close()
45
-
46
- return {"Leave": float(prob[0][0]), "Stay": 1 - float(prob[0][0])}, local_plot
47
 
48
  # Create the UI
49
  title = "**Employee Turnover Predictor & Interpreter** πŸͺ"
50
  description1 = """
51
- This app takes six inputs about employees' satisfaction with different aspects of their work (such as work-life balance, ...) and predicts whether the employee intends to stay with the employer or leave. There are two outputs from the app: 1- the predicted probability of stay or leave, 2- Shapley's force-plot which visualizes the extent to which each factor impacts the stay/leave prediction.
 
 
 
52
  """
53
 
54
- description2 = """
55
- To use the app, click on one of the examples, or adjust the values of the six employee satisfaction factors, and click on Analyze. ✨
56
- """
57
 
58
  with gr.Blocks(title=title) as demo:
59
  gr.Markdown(f"## {title}")
@@ -61,31 +81,31 @@ with gr.Blocks(title=title) as demo:
61
  gr.Markdown("""---""")
62
  gr.Markdown(description2)
63
  gr.Markdown("""---""")
 
64
  with gr.Row():
65
  with gr.Column():
66
- # Dropdown to select Generation (Millennials or Gen Z)
67
- Generation = gr.Dropdown(label="Select Generation", choices=["Millennials", "Gen Z"], value="Millennials")
68
- Engage2 = gr.Slider(label="Engage2", minimum=1, maximum=5, value=4, step=.1)
69
- Tenure = gr.Slider(label="Tenure", minimum=1, maximum=5, value=4, step=.1)
70
- WorkEnv3 = gr.Slider(label="WorkEnv3", minimum=1, maximum=5, value=4, step=.1)
71
- GM3 = gr.Slider(label="GM3", minimum=1, maximum=5, value=4, step=.1)
72
- WellBeing = gr.Slider(label="Well Being", minimum=1, maximum=5, value=4, step=.1)
73
- SupportiveGM = gr.Slider(label="SupportiveGM", minimum=1, maximum=5, value=4, step=.1)
 
 
 
 
74
  submit_btn = gr.Button("Analyze")
 
75
  with gr.Column(visible=True, scale=1, min_width=600) as output_col:
76
  label = gr.Label(label="Predicted Label")
77
- local_plot = gr.Plot(label='Shap:')
78
 
79
- # When the button is clicked, process the inputs
80
  submit_btn.click(
81
  main_func,
82
- [Generation, Engage2, Tenure, WorkEnv3, GM3, WellBeing, SupportiveGM],
83
  [label, local_plot], api_name="Employee_Turnover"
84
  )
85
-
86
- gr.Markdown("### Click on any of the examples below to see how it works:")
87
- gr.Examples([["Gen Z", 4, 4, 4, 5, 5], ["Millennials", 4, 5, 4, 4, 4]],
88
- [Generation, Engage2, Tenure, WorkEnv3, GM3, WellBeing, SupportiveGM],
89
- [label, local_plot], main_func, cache_examples=True)
90
 
91
- demo.launch( )
 
1
+ #New code post Reza meeting - takes out the All option
2
+
3
  import pickle
4
  import pandas as pd
5
  import shap
6
  import gradio as gr
7
+ import numpy as np
8
  import matplotlib.pyplot as plt
9
 
10
  # Load the model from disk
11
+ loaded_model = pickle.load(open("h42_xgb.pkl", 'rb'))
12
 
13
  # Setup SHAP
14
  explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
15
 
16
+ # Mapping categorical values to numerical
17
+ generation_mapping = {"Millennials": 5, "Gen Z": 6}
18
+ chain_scale_mapping = {"Upper Midscale": 3, "Upscale": 4, "Upper Upscale": 5}
19
+ region_mapping = {"Northeast": 1, "Midwest": 2, "South": 3, "West": 4}
20
+ recommend_mapping = {"1": 1, "2": 2, "3": 3, "4": 4, "5": 5}
21
 
22
+ generation_choices = list(generation_mapping.keys())
23
+ chain_scale_choices = list(chain_scale_mapping.keys())
24
+ region_choices = list(region_mapping.keys())
25
+ recommend_choices = list(recommend_mapping.keys())
26
 
27
+ # Define main function for prediction
28
+ def main_func(Generation, Engage2, RecommendToOthers, Voice, Merit, Workload, WellBeing, SupportiveGM, ChainScale, HotelRegion, ManagementLevel=2):
29
+ # Map inputs to numerical values
30
+ gen = generation_mapping[Generation]
31
+ cs = chain_scale_mapping[ChainScale]
32
+ reg = region_mapping[HotelRegion]
33
+ recommend = recommend_mapping[RecommendToOthers]
34
+
35
+ # Create DataFrame for model prediction
36
+ new_row = pd.DataFrame.from_dict({
37
+ 'Generation': gen,
38
+ 'ManagementLevel': ManagementLevel,
39
+ 'Engage2': Engage2,
40
+ 'RecommendToOthers': recommend,
41
+ 'Voice': Voice,
42
+ 'Merit': Merit,
43
+ 'Workload': Workload,
44
+ 'ChainScale': cs,
45
+ 'WellBeing': WellBeing,
46
+ 'SupportiveGM': SupportiveGM,
47
+ 'HotelRegion': reg
48
+ }, orient = 'index').transpose()
49
+
50
+ # Predict probabilities
51
  prob = loaded_model.predict_proba(new_row)
 
52
 
53
+ # Compute SHAP values
54
+ shap_values = explainer(new_row)
55
+ selected_features = ["Engage2", "RecommendToOthers", "Voice", "Merit", "Workload", "WellBeing", "SupportiveGM"]
56
+ shap_values_filtered = shap_values[:, selected_features]
57
+
58
  # Generate SHAP plot
59
+ shap.plots.bar(shap_values_filtered[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
 
60
  plt.tight_layout()
61
  local_plot = plt.gcf()
62
+ plt.rcParams['figure.figsize'] = (6, 4)
63
  plt.close()
64
+
65
+ return {"Leave": float(prob[0][0]), "Stay": float(prob[0][0])}, local_plot
66
 
67
  # Create the UI
68
  title = "**Employee Turnover Predictor & Interpreter** πŸͺ"
69
  description1 = """
70
+ This app takes six inputs about employees' satisfaction with different aspects of their work and predicts whether the employee intends to stay with the employer or leave.
71
+ It provides:
72
+ 1. The predicted probability of staying or leaving.
73
+ 2. A SHAP interpretation plot showing feature importance.
74
  """
75
 
76
+ description2 = """Adjust the values of different employee satisfaction factors and click "Analyze" to get a prediction."""
 
 
77
 
78
  with gr.Blocks(title=title) as demo:
79
  gr.Markdown(f"## {title}")
 
81
  gr.Markdown("""---""")
82
  gr.Markdown(description2)
83
  gr.Markdown("""---""")
84
+
85
  with gr.Row():
86
  with gr.Column():
87
+ Generation = gr.Dropdown(label="Select Generation 🧍🏽", choices=generation_choices, value="Millennials")
88
+ ChainScale = gr.Dropdown(label="Select Chain Scale 🧳", choices=chain_scale_choices, value="Upscale")
89
+ HotelRegion = gr.Dropdown(label="Select Hotel Region 🏨", choices=region_choices, value="Midwest")
90
+ RecommendToOthers = gr.Dropdown(label="Recommend to Others", choices=recommend_choices, value="4")
91
+
92
+ Engage2 = gr.Slider(label="Engagement (Engage2)", minimum=1, maximum=5, value=4, step=0.1)
93
+ Voice = gr.Slider(label="Voice", minimum=1, maximum=5, value=4, step=0.1)
94
+ Merit = gr.Slider(label="Merit", minimum=1, maximum=5, value=4, step=0.1)
95
+ Workload = gr.Slider(label="Workload", minimum=1, maximum=5, value=4, step=0.1)
96
+ WellBeing = gr.Slider(label="Well-being", minimum=1, maximum=5, value=4, step=0.1)
97
+ SupportiveGM = gr.Slider(label="Supportive GM", minimum=1, maximum=5, value=4, step=0.1)
98
+
99
  submit_btn = gr.Button("Analyze")
100
+
101
  with gr.Column(visible=True, scale=1, min_width=600) as output_col:
102
  label = gr.Label(label="Predicted Label")
103
+ local_plot = gr.Plot(label="SHAP Impact")
104
 
 
105
  submit_btn.click(
106
  main_func,
107
+ [Generation, Engage2, RecommendToOthers, Voice, Merit, Workload, WellBeing, SupportiveGM, ChainScale, HotelRegion],
108
  [label, local_plot], api_name="Employee_Turnover"
109
  )
 
 
 
 
 
110
 
111
+ demo.launch( )