Yassirabbas75 committed on
Commit bc495ba
1 Parent(s): 05c4fca

Create app.py

Files changed (1)
  1. app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
+ import random
+
+ import gradio as gr
+ import matplotlib
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import shap
+ import xgboost as xgb
+ from datasets import load_dataset
+
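+ # Use matplotlib's non-interactive "Agg" backend so figures render off-screen on the server.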
+ matplotlib.use("Agg")
+
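+ # Load the Adult Census Income dataset from the Hugging Face Hub.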
+ dataset = load_dataset("scikit-learn/adult-census-income")
+
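+ # Drop the sampling-weight ("fnlwgt") and "race" columns; "income" becomes the binary target.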
+ X_train = dataset["train"].to_pandas()
+ _ = X_train.pop("fnlwgt")
+ _ = X_train.pop("race")
+
+ y_train = X_train.pop("income")
+ y_train = (y_train == ">50K").astype(int)
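+ # Mark the string-valued columns as pandas categoricals so XGBoost can split on them natively.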
+ categorical_columns = [
+     "workclass",
+     "education",
+     "marital.status",
+     "occupation",
+     "relationship",
+     "sex",
+     "native.country",
+ ]
+ X_train = X_train.astype({col: "category" for col in categorical_columns})
+
+
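+ # Train an XGBoost model with the binary logistic objective and build a SHAP TreeExplainer for it.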
+ data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
+ model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
+ explainer = shap.TreeExplainer(model)
+
+
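+ # Build a one-row DataFrame from the widget values (in training-column order) and return
+ # the predicted probability of each income class.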
+ def predict(*args):
+     df = pd.DataFrame([args], columns=X_train.columns)
+     df = df.astype({col: "category" for col in categorical_columns})
+     pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))
+     return {">50K": float(pos_pred[0]), "<=50K": 1 - float(pos_pred[0])}
+
+
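+ # Compute SHAP values for the single input row and plot each feature's contribution as a
+ # horizontal bar chart, sorted by SHAP value.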
+ def interpret(*args):
+     df = pd.DataFrame([args], columns=X_train.columns)
+     df = df.astype({col: "category" for col in categorical_columns})
+     shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
+     scores_desc = list(zip(shap_values[0], X_train.columns))
+     scores_desc = sorted(scores_desc)
+     fig_m = plt.figure(tight_layout=True)
+     plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
+     plt.title("Feature Shap Values")
+     plt.ylabel("Feature")
+     plt.xlabel("Shap Value")
+     plt.tight_layout()
+     return fig_m
+
+
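+ # Category choices that populate the dropdown widgets below.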
+ unique_class = sorted(X_train["workclass"].unique())
+ unique_education = sorted(X_train["education"].unique())
+ unique_marital_status = sorted(X_train["marital.status"].unique())
+ unique_relationship = sorted(X_train["relationship"].unique())
+ unique_occupation = sorted(X_train["occupation"].unique())
+ unique_sex = sorted(X_train["sex"].unique())
+ unique_country = sorted(X_train["native.country"].unique())
+
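+ # Gradio UI: input widgets in the left column; prediction label and SHAP plot in the right column.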
+ with gr.Blocks() as demo:
+     gr.Markdown("""
+     **Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier to predict income based on demographic factors, along with Shapley-value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
+     """)
+     with gr.Row():
+         with gr.Column():
+             age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
+             work_class = gr.Dropdown(
+                 label="Workclass",
+                 choices=unique_class,
+                 value=lambda: random.choice(unique_class),
+             )
+             education = gr.Dropdown(
+                 label="Education Level",
+                 choices=unique_education,
+                 value=lambda: random.choice(unique_education),
+             )
+             years = gr.Slider(
+                 label="Years of schooling",
+                 minimum=1,
+                 maximum=16,
+                 step=1,
+                 randomize=True,
+             )
+             marital_status = gr.Dropdown(
+                 label="Marital Status",
+                 choices=unique_marital_status,
+                 value=lambda: random.choice(unique_marital_status),
+             )
+             occupation = gr.Dropdown(
+                 label="Occupation",
+                 choices=unique_occupation,
+                 value=lambda: random.choice(unique_occupation),
+             )
+             relationship = gr.Dropdown(
+                 label="Relationship Status",
+                 choices=unique_relationship,
+                 value=lambda: random.choice(unique_relationship),
+             )
+             sex = gr.Dropdown(
+                 label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
+             )
+             capital_gain = gr.Slider(
+                 label="Capital Gain",
+                 minimum=0,
+                 maximum=100000,
+                 step=500,
+                 randomize=True,
+             )
+             capital_loss = gr.Slider(
+                 label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
+             )
+             hours_per_week = gr.Slider(
+                 label="Hours Per Week Worked", minimum=1, maximum=99, step=1
+             )
+             country = gr.Dropdown(
+                 label="Native Country",
+                 choices=unique_country,
+                 value=lambda: random.choice(unique_country),
+             )
+         with gr.Column():
+             label = gr.Label()
+             plot = gr.Plot()
+             with gr.Row():
+                 predict_btn = gr.Button(value="Predict")
+                 interpret_btn = gr.Button(value="Explain")
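+     # Both buttons take the same input widgets: "Predict" returns the class probabilities,
+     # "Explain" returns the SHAP bar chart.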
+     predict_btn.click(
+         predict,
+         inputs=[
+             age,
+             work_class,
+             education,
+             years,
+             marital_status,
+             occupation,
+             relationship,
+             sex,
+             capital_gain,
+             capital_loss,
+             hours_per_week,
+             country,
+         ],
+         outputs=[label],
+     )
+     interpret_btn.click(
+         interpret,
+         inputs=[
+             age,
+             work_class,
+             education,
+             years,
+             marital_status,
+             occupation,
+             relationship,
+             sex,
+             capital_gain,
+             capital_loss,
+             hours_per_week,
+             country,
+         ],
+         outputs=[plot],
+     )
+
+ demo.launch()