freddyaboulton HF staff commited on
Commit
a02aab8
β€’
1 Parent(s): 5235215
Files changed (3) hide show
  1. README.md +3 -3
  2. app.py +176 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Xgboost Income Prediction With Explainability
3
- emoji: πŸ“ˆ
4
- colorFrom: yellow
5
- colorTo: gray
6
  sdk: gradio
7
  sdk_version: 3.1.7
8
  app_file: app.py
 
1
  ---
2
  title: Xgboost Income Prediction With Explainability
3
+ emoji: πŸ’°
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 3.1.7
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import gradio as gr
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ import pandas as pd
7
+ import shap
8
+ import xgboost as xgb
9
+ from datasets import load_dataset
10
+
11
+ matplotlib.use("Agg")
12
+
13
+ dataset = load_dataset("scikit-learn/adult-census-income")
14
+
15
+ X_train = dataset["train"].to_pandas()
16
+ _ = X_train.pop("fnlwgt")
17
+ _ = X_train.pop("race")
18
+
19
+ y_train = X_train.pop("income")
20
+ y_train = (y_train == ">50K").astype(int)
21
+ categorical_columns = [
22
+ "workclass",
23
+ "education",
24
+ "marital.status",
25
+ "occupation",
26
+ "relationship",
27
+ "sex",
28
+ "native.country",
29
+ ]
30
+ X_train = X_train.astype({col: "category" for col in categorical_columns})
31
+
32
+
33
+ data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
34
+ model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
35
+ explainer = shap.TreeExplainer(model)
36
+
37
+
38
+ def predict(*args):
39
+ df = pd.DataFrame([args], columns=X_train.columns)
40
+ df = df.astype({col: "category" for col in categorical_columns})
41
+ pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))
42
+ return {">50K": float(pos_pred[0]), "<=50K": 1 - float(pos_pred[0])}
43
+
44
+
45
+ def interpret(*args):
46
+ df = pd.DataFrame([args], columns=X_train.columns)
47
+ df = df.astype({col: "category" for col in categorical_columns})
48
+ shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
49
+ scores_desc = list(zip(shap_values[0], X_train.columns))
50
+ scores_desc = sorted(scores_desc)
51
+ fig_m = plt.figure(tight_layout=True)
52
+ plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
53
+ plt.title("Feature Shap Values")
54
+ plt.ylabel("Shap Value")
55
+ plt.xlabel("Feature")
56
+ plt.tight_layout()
57
+ return fig_m
58
+
59
+
60
+ unique_class = sorted(X_train["workclass"].unique())
61
+ unique_education = sorted(X_train["education"].unique())
62
+ unique_marital_status = sorted(X_train["marital.status"].unique())
63
+ unique_relationship = sorted(X_train["relationship"].unique())
64
+ unique_occupation = sorted(X_train["occupation"].unique())
65
+ unique_sex = sorted(X_train["sex"].unique())
66
+ unique_country = sorted(X_train["native.country"].unique())
67
+
68
+ with gr.Blocks() as demo:
69
+ gr.Markdown("""
70
+ ## Income Classification with XGBoost πŸ’°
71
+
72
+ This example shows how to load data from the hugging face hub to train an XGBoost classifier and
73
+ demo the predictions with gradio.
74
+
75
+ The source is [here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability).
76
+ """)
77
+ with gr.Row():
78
+ with gr.Column():
79
+ age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
80
+ work_class = gr.Dropdown(
81
+ label="Workclass",
82
+ choices=unique_class,
83
+ value=lambda: random.choice(unique_class),
84
+ )
85
+ education = gr.Dropdown(
86
+ label="Education Level",
87
+ choices=unique_education,
88
+ value=lambda: random.choice(unique_education),
89
+ )
90
+ years = gr.Slider(
91
+ label="Years of schooling",
92
+ minimum=1,
93
+ maximum=16,
94
+ step=1,
95
+ randomize=True,
96
+ )
97
+ marital_status = gr.Dropdown(
98
+ label="Marital Status",
99
+ choices=unique_marital_status,
100
+ value=lambda: random.choice(unique_marital_status),
101
+ )
102
+ occupation = gr.Dropdown(
103
+ label="Occupation",
104
+ choices=unique_occupation,
105
+ value=lambda: random.choice(unique_occupation),
106
+ )
107
+ relationship = gr.Dropdown(
108
+ label="Relationship Status",
109
+ choices=unique_relationship,
110
+ value=lambda: random.choice(unique_relationship),
111
+ )
112
+ sex = gr.Dropdown(
113
+ label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
114
+ )
115
+ capital_gain = gr.Slider(
116
+ label="Capital Gain",
117
+ minimum=0,
118
+ maximum=100000,
119
+ step=500,
120
+ randomize=True,
121
+ )
122
+ capital_loss = gr.Slider(
123
+ label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
124
+ )
125
+ hours_per_week = gr.Slider(
126
+ label="Hours Per Week Worked", minimum=1, maximum=99, step=1
127
+ )
128
+ country = gr.Dropdown(
129
+ label="Native Country",
130
+ choices=unique_country,
131
+ value=lambda: random.choice(unique_country),
132
+ )
133
+ with gr.Column():
134
+ label = gr.Label()
135
+ plot = gr.Plot()
136
+ with gr.Row():
137
+ predict_btn = gr.Button(value="Predict")
138
+ interpret_btn = gr.Button(value="Interpret")
139
+ predict_btn.click(
140
+ predict,
141
+ inputs=[
142
+ age,
143
+ work_class,
144
+ education,
145
+ years,
146
+ marital_status,
147
+ occupation,
148
+ relationship,
149
+ sex,
150
+ capital_gain,
151
+ capital_loss,
152
+ hours_per_week,
153
+ country,
154
+ ],
155
+ outputs=[label],
156
+ )
157
+ interpret_btn.click(
158
+ interpret,
159
+ inputs=[
160
+ age,
161
+ work_class,
162
+ education,
163
+ years,
164
+ marital_status,
165
+ occupation,
166
+ relationship,
167
+ sex,
168
+ capital_gain,
169
+ capital_loss,
170
+ hours_per_week,
171
+ country,
172
+ ],
173
+ outputs=[plot],
174
+ )
175
+
176
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ matplotlib
2
+ shap
3
+ xgboost
4
+ pandas
5
+ datasets