freddyaboulton HF staff commited on
Commit
288ef76
1 Parent(s): 5942545

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import random
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ import pandas as pd
6
+ import shap
7
+ import xgboost as xgb
8
+ from datasets import load_dataset
9
+
10
+
11
+ matplotlib.use("Agg")
12
+ dataset = load_dataset("scikit-learn/adult-census-income")
13
+ X_train = dataset["train"].to_pandas()
14
+ _ = X_train.pop("fnlwgt")
15
+ _ = X_train.pop("race")
16
+ y_train = X_train.pop("income")
17
+ y_train = (y_train == ">50K").astype(int)
18
+ categorical_columns = [
19
+ "workclass",
20
+ "education",
21
+ "marital.status",
22
+ "occupation",
23
+ "relationship",
24
+ "sex",
25
+ "native.country",
26
+ ]
27
+ X_train = X_train.astype({col: "category" for col in categorical_columns})
28
+ data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
29
+ model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
30
+ explainer = shap.TreeExplainer(model)
31
+
32
+ def predict(*args):
33
+ df = pd.DataFrame([args], columns=X_train.columns)
34
+ df = df.astype({col: "category" for col in categorical_columns})
35
+ pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))
36
+ return {">50K": float(pos_pred[0]), "<=50K": 1 - float(pos_pred[0])}
37
+
38
+
39
+ def interpret(*args):
40
+ df = pd.DataFrame([args], columns=X_train.columns)
41
+ df = df.astype({col: "category" for col in categorical_columns})
42
+ shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
43
+ scores_desc = list(zip(shap_values[0], X_train.columns))
44
+ scores_desc = sorted(scores_desc)
45
+ fig_m = plt.figure(tight_layout=True)
46
+ plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
47
+ plt.title("Feature Shap Values")
48
+ plt.ylabel("Shap Value")
49
+ plt.xlabel("Feature")
50
+ plt.tight_layout()
51
+ return fig_m
52
+
53
+
54
+ unique_class = sorted(X_train["workclass"].unique())
55
+ unique_education = sorted(X_train["education"].unique())
56
+ unique_marital_status = sorted(X_train["marital.status"].unique())
57
+ unique_relationship = sorted(X_train["relationship"].unique())
58
+ unique_occupation = sorted(X_train["occupation"].unique())
59
+ unique_sex = sorted(X_train["sex"].unique())
60
+ unique_country = sorted(X_train["native.country"].unique())
61
+
62
+ with gr.Blocks() as demo:
63
+ gr.Markdown("""
64
+ **Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier predicts income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
65
+ """)
66
+ with gr.Row():
67
+ with gr.Column():
68
+ age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
69
+ work_class = gr.Dropdown(
70
+ label="Workclass",
71
+ choices=unique_class,
72
+ value=lambda: random.choice(unique_class),
73
+ )
74
+ education = gr.Dropdown(
75
+ label="Education Level",
76
+ choices=unique_education,
77
+ value=lambda: random.choice(unique_education),
78
+ )
79
+ years = gr.Slider(
80
+ label="Years of schooling",
81
+ minimum=1,
82
+ maximum=16,
83
+ step=1,
84
+ randomize=True,
85
+ )
86
+ marital_status = gr.Dropdown(
87
+ label="Marital Status",
88
+ choices=unique_marital_status,
89
+ value=lambda: random.choice(unique_marital_status),
90
+ )
91
+ occupation = gr.Dropdown(
92
+ label="Occupation",
93
+ choices=unique_occupation,
94
+ value=lambda: random.choice(unique_occupation),
95
+ )
96
+ relationship = gr.Dropdown(
97
+ label="Relationship Status",
98
+ choices=unique_relationship,
99
+ value=lambda: random.choice(unique_relationship),
100
+ )
101
+ sex = gr.Dropdown(
102
+ label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
103
+ )
104
+ capital_gain = gr.Slider(
105
+ label="Capital Gain",
106
+ minimum=0,
107
+ maximum=100000,
108
+ step=500,
109
+ randomize=True,
110
+ )
111
+ capital_loss = gr.Slider(
112
+ label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
113
+ )
114
+ hours_per_week = gr.Slider(
115
+ label="Hours Per Week Worked", minimum=1, maximum=99, step=1
116
+ )
117
+ country = gr.Dropdown(
118
+ label="Native Country",
119
+ choices=unique_country,
120
+ value=lambda: random.choice(unique_country),
121
+ )
122
+ with gr.Column():
123
+ label = gr.Label()
124
+ plot = gr.Plot()
125
+ with gr.Row():
126
+ predict_btn = gr.Button(value="Predict")
127
+ interpret_btn = gr.Button(value="Explain")
128
+ predict_btn.click(
129
+ predict,
130
+ inputs=[
131
+ age,
132
+ work_class,
133
+ education,
134
+ years,
135
+ marital_status,
136
+ occupation,
137
+ relationship,
138
+ sex,
139
+ capital_gain,
140
+ capital_loss,
141
+ hours_per_week,
142
+ country,
143
+ ],
144
+ outputs=[label],
145
+ )
146
+ interpret_btn.click(
147
+ interpret,
148
+ inputs=[
149
+ age,
150
+ work_class,
151
+ education,
152
+ years,
153
+ marital_status,
154
+ occupation,
155
+ relationship,
156
+ sex,
157
+ capital_gain,
158
+ capital_loss,
159
+ hours_per_week,
160
+ country,
161
+ ],
162
+ outputs=[plot],
163
+ )
164
+
165
+ demo.launch()