Jurk06 committed
Commit 1dbf177
1 Parent(s): 888c446

Create app.py

Files changed (1):
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
import gradio as gr
import random
import matplotlib.pyplot as plt
import pandas as pd
import shap
import xgboost as xgb
from datasets import load_dataset


# Load the Adult Census Income dataset and drop columns that are not used as features.
dataset = load_dataset("scikit-learn/adult-census-income")
X_train = dataset["train"].to_pandas()
_ = X_train.pop("fnlwgt")
_ = X_train.pop("race")

# The target is binary: does income exceed $50K per year?
y_train = X_train.pop("income")
y_train = (y_train == ">50K").astype(int)

categorical_columns = [
    "workclass",
    "education",
    "marital.status",
    "occupation",
    "relationship",
    "sex",
    "native.country",
]
X_train = X_train.astype({col: "category" for col in categorical_columns})

# Train an XGBoost binary classifier with native categorical support, then build
# a SHAP TreeExplainer so each prediction can be explained feature by feature.
data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)
model = xgb.train(params={"objective": "binary:logistic"}, dtrain=data)
explainer = shap.TreeExplainer(model)

def predict(*args):
    # Assemble a single-row DataFrame from the UI inputs (same column order as
    # X_train) and return class probabilities for the Label component.
    df = pd.DataFrame([args], columns=X_train.columns)
    df = df.astype({col: "category" for col in categorical_columns})
    pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))
    return {">50K": float(pos_pred[0]), "<=50K": 1 - float(pos_pred[0])}

def interpret(*args):
    # Compute SHAP values for the single input row and plot them as a horizontal
    # bar chart, sorted by contribution.
    df = pd.DataFrame([args], columns=X_train.columns)
    df = df.astype({col: "category" for col in categorical_columns})
    shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))
    scores_desc = list(zip(shap_values[0], X_train.columns))
    scores_desc = sorted(scores_desc)
    fig_m = plt.figure(tight_layout=True)
    plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])
    plt.title("Feature Shap Values")
    plt.ylabel("Feature")
    plt.xlabel("Shap Value")
    plt.tight_layout()
    return fig_m

# Distinct values for each categorical feature, used to populate the dropdowns below.
unique_class = sorted(X_train["workclass"].unique())
unique_education = sorted(X_train["education"].unique())
unique_marital_status = sorted(X_train["marital.status"].unique())
unique_relationship = sorted(X_train["relationship"].unique())
unique_occupation = sorted(X_train["occupation"].unique())
unique_sex = sorted(X_train["sex"].unique())
unique_country = sorted(X_train["native.country"].unique())

with gr.Blocks() as demo:
    gr.Markdown("""
    **Income Classification with XGBoost 💰**: This demo uses an XGBoost classifier to predict income based on demographic factors, along with Shapley value-based *explanations*. The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).
    """)
    with gr.Row():
        with gr.Column():
            # Input components, one per model feature; the dropdowns and most
            # sliders are randomized on page load so each visit shows a new example.
            age = gr.Slider(label="Age", minimum=17, maximum=90, step=1, randomize=True)
            work_class = gr.Dropdown(
                label="Workclass",
                choices=unique_class,
                value=lambda: random.choice(unique_class),
            )
            education = gr.Dropdown(
                label="Education Level",
                choices=unique_education,
                value=lambda: random.choice(unique_education),
            )
            years = gr.Slider(
                label="Years of schooling",
                minimum=1,
                maximum=16,
                step=1,
                randomize=True,
            )
            marital_status = gr.Dropdown(
                label="Marital Status",
                choices=unique_marital_status,
                value=lambda: random.choice(unique_marital_status),
            )
            occupation = gr.Dropdown(
                label="Occupation",
                choices=unique_occupation,
                value=lambda: random.choice(unique_occupation),
            )
            relationship = gr.Dropdown(
                label="Relationship Status",
                choices=unique_relationship,
                value=lambda: random.choice(unique_relationship),
            )
            sex = gr.Dropdown(
                label="Sex", choices=unique_sex, value=lambda: random.choice(unique_sex)
            )
            capital_gain = gr.Slider(
                label="Capital Gain",
                minimum=0,
                maximum=100000,
                step=500,
                randomize=True,
            )
            capital_loss = gr.Slider(
                label="Capital Loss", minimum=0, maximum=10000, step=500, randomize=True
            )
            hours_per_week = gr.Slider(
                label="Hours Per Week Worked", minimum=1, maximum=99, step=1
            )
            country = gr.Dropdown(
                label="Native Country",
                choices=unique_country,
                value=lambda: random.choice(unique_country),
            )
        with gr.Column():
            # Output components: predicted class probabilities and the SHAP bar chart.
            label = gr.Label()
            plot = gr.Plot()
            with gr.Row():
                predict_btn = gr.Button(value="Predict")
                interpret_btn = gr.Button(value="Explain")
    # Wire the buttons to the callbacks defined above.
    predict_btn.click(
        predict,
        inputs=[
            age,
            work_class,
            education,
            years,
            marital_status,
            occupation,
            relationship,
            sex,
            capital_gain,
            capital_loss,
            hours_per_week,
            country,
        ],
        outputs=[label],
    )
    interpret_btn.click(
        interpret,
        inputs=[
            age,
            work_class,
            education,
            years,
            marital_status,
            occupation,
            relationship,
            sex,
            capital_gain,
            capital_loss,
            hours_per_week,
            country,
        ],
        outputs=[plot],
    )

demo.launch()
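
For a quick local sanity check (not part of this commit), the two callbacks can be invoked directly with one row of the training data. The snippet below is a hypothetical sketch that assumes the code in app.py has already been executed in the same Python session; the printed probabilities and the output filename are illustrative only:

# Hypothetical smoke test: feed the first training row through the demo's callbacks.
sample = X_train.iloc[0].tolist()
print(predict(*sample))    # e.g. {'>50K': 0.12, '<=50K': 0.88} (values are illustrative)
fig = interpret(*sample)   # returns the matplotlib Figure rendered in the Plot component
fig.savefig("shap_example.png")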