awacke1 committed
Commit d79a601
1 Parent(s): c068b13

Upload 2 files

Files changed (2)
  1. app.py +166 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,166 @@
+ import gradio as gr
+ import random
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import shap
+ import xgboost as xgb
+ from datasets import load_dataset
+
+
+ matplotlib.use("Agg")
+ dataset = load_dataset("scikit-learn/adult-census-income")
+ X_train = dataset["train"].to_pandas()
+ _ = X_train.pop("fnlwgt")
+ _ = X_train.pop("race")
+ y_train = X_train.pop("income")
+ y_train = (y_train == ">50K").astype(int)
+ categorical_columns = [
+     "workclass",
+     "education",
+     "marital.status",
+     "occupation",
+     "relationship",
+     "sex",
+     "native.country",
+ ]
+ X_train = X_train.astype({col: "category" for col in categorical_columns})
+ data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
+ model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
+ explainer = shap.TreeExplainer(model)
+
+ def predict(*args):
+     df = pd.DataFrame([args], columns=X_train.columns)
+     df = df.astype({col: "category" for col in categorical_columns})
+     pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))
+     return {">50K": float(pos_pred[0]), "<=50K": 1 - float(pos_pred[0])}
+
+
+ def interpret(*args):
+     df = pd.DataFrame([args], columns=X_train.columns)
+     df = df.astype({col: "category" for col in categorical_columns})
+     shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
+     scores_desc = list(zip(shap_values[0], X_train.columns))
+     scores_desc = sorted(scores_desc)
+     fig_m = plt.figure(tight_layout=True)
+     plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
+     plt.title("Feature Shap Values")
+     plt.ylabel("Feature")
+     plt.xlabel("Shap Value")
+     plt.tight_layout()
+     return fig_m
+
+
+ unique_class = sorted(X_train["workclass"].unique())
+ unique_education = sorted(X_train["education"].unique())
+ unique_marital_status = sorted(X_train["marital.status"].unique())
+ unique_relationship = sorted(X_train["relationship"].unique())
+ unique_occupation = sorted(X_train["occupation"].unique())
+ unique_sex = sorted(X_train["sex"].unique())
+ unique_country = sorted(X_train["native.country"].unique())
+
+ with gr.Blocks() as demo:
+     gr.Markdown("""
+     **Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier to predict income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
+     """)
+     with gr.Row():
+         with gr.Column():
+             age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
+             work_class = gr.Dropdown(
+                 label="Workclass",
+                 choices=unique_class,
+                 value=lambda: random.choice(unique_class),
+             )
+             education = gr.Dropdown(
+                 label="Education Level",
+                 choices=unique_education,
+                 value=lambda: random.choice(unique_education),
+             )
+             years = gr.Slider(
+                 label="Years of schooling",
+                 minimum=1,
+                 maximum=16,
+                 step=1,
+                 randomize=True,
+             )
+             marital_status = gr.Dropdown(
+                 label="Marital Status",
+                 choices=unique_marital_status,
+                 value=lambda: random.choice(unique_marital_status),
+             )
+             occupation = gr.Dropdown(
+                 label="Occupation",
+                 choices=unique_occupation,
+                 value=lambda: random.choice(unique_occupation),
+             )
+         with gr.Column():
+             relationship = gr.Dropdown(
+                 label="Relationship Status",
+                 choices=unique_relationship,
+                 value=lambda: random.choice(unique_relationship),
+             )
+             sex = gr.Dropdown(
+                 label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
+             )
+             capital_gain = gr.Slider(
+                 label="Capital Gain",
+                 minimum=0,
+                 maximum=100000,
+                 step=500,
+                 randomize=True,
+             )
+             capital_loss = gr.Slider(
+                 label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
+             )
+             hours_per_week = gr.Slider(
+                 label="Hours Per Week Worked", minimum=1, maximum=99, step=1
+             )
+             country = gr.Dropdown(
+                 label="Native Country",
+                 choices=unique_country,
+                 value=lambda: random.choice(unique_country),
+             )
+         with gr.Column():
+             label = gr.Label()
+             plot = gr.Plot()
+             with gr.Row():
+                 predict_btn = gr.Button(value="Predict")
+                 interpret_btn = gr.Button(value="Explain")
+     predict_btn.click(
+         predict,
+         inputs=[
+             age,
+             work_class,
+             education,
+             years,
+             marital_status,
+             occupation,
+             relationship,
+             sex,
+             capital_gain,
+             capital_loss,
+             hours_per_week,
+             country,
+         ],
+         outputs=[label],
+     )
+     interpret_btn.click(
+         interpret,
+         inputs=[
+             age,
+             work_class,
+             education,
+             years,
+             marital_status,
+             occupation,
+             relationship,
+             sex,
+             capital_gain,
+             capital_loss,
+             hours_per_week,
+             country,
+         ],
+         outputs=[plot],
+     )
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ matplotlib
+ shap
+ xgboost
+ pandas
+ datasets
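
For reference, a minimal sketch of how the two callbacks added in app.py could be exercised directly, outside the Gradio UI. This is illustrative only and not part of the commit; it assumes the definitions above have been executed (for example in a notebook) before demo.launch() is reached, and the output file name is a placeholder.

# Illustrative sketch, not part of this commit.
sample = X_train.iloc[0]          # one training row with the 12 features the UI collects
scores = predict(*sample)         # dict of class probabilities for ">50K" and "<=50K"
fig = interpret(*sample)          # matplotlib Figure with one SHAP value per feature
fig.savefig("shap_values.png")    # placeholder file name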