Spaces:
Runtime error
Runtime error
Added all app files for gradio demo for Binary classification on structured data using GRN-VSN model
Browse files- app.py +59 -0
- constants.py +88 -0
- predict.py +107 -0
- preprocess.py +65 -0
- requirements.txt +2 -0
app.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from .constants import CSV_HEADER, NUMERIC_FEATURE_NAMES, CATEGORICAL_FEATURES_WITH_VOCABULARY, NUMBER_INPUT_COLS
|
3 |
+
from .preprocess import create_max_values_map, create_dropdown_default_values_map, create_sample_test_data
|
4 |
+
from .predict import batch_predict, user_input_predict
|
5 |
+
|
6 |
+
inputs_list = []
|
7 |
+
max_values_map = create_max_values_map()
|
8 |
+
dropdown_default_values_map = create_dropdown_default_values_map()
|
9 |
+
sample_input_df_val = create_sample_test_data()
|
10 |
+
|
11 |
+
demo = gr.Blocks()
|
12 |
+
|
13 |
+
with demo:
|
14 |
+
|
15 |
+
gr.Markdown("# **Binary Classification using Gated Residual and Variable Selection Networks** \n")
|
16 |
+
gr.Markdown("This demo demonstrates the use of Gated Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by Bryan Lim et al. in <a href=\"https://arxiv.org/abs/1912.09363/\">Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting</a> for structured data classification")
|
17 |
+
gr.Markdown("Play around and see yourself 🤗 ")
|
18 |
+
|
19 |
+
with gr.Tabs():
|
20 |
+
|
21 |
+
with gr.TabItem("Predict using batch of inputs"):
|
22 |
+
gr.Markdown("**Input DataFrame** \n")
|
23 |
+
input_df = gr.Dataframe(headers=CSV_HEADER,value=samp,)
|
24 |
+
gr.Markdown("**Output DataFrame** \n")
|
25 |
+
output_df = gr.Dataframe()
|
26 |
+
gr.Markdown("**Make Predictions**")
|
27 |
+
with gr.Row():
|
28 |
+
compute_button = gr.Button("Predict")
|
29 |
+
|
30 |
+
with gr.TabItem("Tweak inputs Yourself"):
|
31 |
+
with gr.Tabs():
|
32 |
+
|
33 |
+
with gr.TabItem("Numerical Inputs"):
|
34 |
+
gr.Markdown("Set values for numerical inputs here.")
|
35 |
+
for num_variable in NUMERIC_FEATURE_NAMES:
|
36 |
+
with gr.Column():
|
37 |
+
if num_variable in NUMBER_INPUT_COLS:
|
38 |
+
numeric_input = gr.Number(label=num_variable)
|
39 |
+
else:
|
40 |
+
curr_max_val = max_values_map["max_"+num_variable]
|
41 |
+
numeric_input = gr.Slider(0,curr_max_val, label=num_variable,step=1)
|
42 |
+
inputs_list.append(numeric_input)
|
43 |
+
|
44 |
+
with gr.TabItem("Categorical Inputs"):
|
45 |
+
gr.Markdown("Choose values for categorical inputs here.")
|
46 |
+
for cat_variable in CATEGORICAL_FEATURES_WITH_VOCABULARY.keys():
|
47 |
+
with gr.Column():
|
48 |
+
categorical_input = gr.Dropdown(CATEGORICAL_FEATURES_WITH_VOCABULARY[cat_variable], label=cat_variable, value=str(dropdown_default_values_map["max_"+cat_variable]))
|
49 |
+
inputs_list.append(categorical_input)
|
50 |
+
|
51 |
+
predict_button = gr.Button("Predict")
|
52 |
+
final_output = gr.Label()
|
53 |
+
|
54 |
+
predict_button.click(user_input_predict, inputs=inputs_list, outputs=final_output)
|
55 |
+
compute_button.click(batch_predict, inputs=input_df, outputs=output_df)
|
56 |
+
gr.Markdown('\n Author: <a href=\"https://huggingface.co/shivi\">Shivalika Singh</a> <br> Based on this <a href=\"https://keras.io/examples/structured_data/classification_with_grn_and_vsn/\">Keras example</a> by <a href=\"https://www.linkedin.com/in/khalid-salama-24403144/\">Khalid Salama</a> <br> Demo Powered by this <a href=\"https://huggingface.co/shivi/classification-grn-vsn\">GRN-VSN model</a>')
|
57 |
+
|
58 |
+
|
59 |
+
demo.launch()
|
constants.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from .preprocess import load_test_data
|
3 |
+
|
4 |
+
# Column names.
|
5 |
+
CSV_HEADER = [
|
6 |
+
"age",
|
7 |
+
"class_of_worker",
|
8 |
+
"detailed_industry_recode",
|
9 |
+
"detailed_occupation_recode",
|
10 |
+
"education",
|
11 |
+
"wage_per_hour",
|
12 |
+
"enroll_in_edu_inst_last_wk",
|
13 |
+
"marital_stat",
|
14 |
+
"major_industry_code",
|
15 |
+
"major_occupation_code",
|
16 |
+
"race",
|
17 |
+
"hispanic_origin",
|
18 |
+
"sex",
|
19 |
+
"member_of_a_labor_union",
|
20 |
+
"reason_for_unemployment",
|
21 |
+
"full_or_part_time_employment_stat",
|
22 |
+
"capital_gains",
|
23 |
+
"capital_losses",
|
24 |
+
"dividends_from_stocks",
|
25 |
+
"tax_filer_stat",
|
26 |
+
"region_of_previous_residence",
|
27 |
+
"state_of_previous_residence",
|
28 |
+
"detailed_household_and_family_stat",
|
29 |
+
"detailed_household_summary_in_household",
|
30 |
+
"instance_weight",
|
31 |
+
"migration_code-change_in_msa",
|
32 |
+
"migration_code-change_in_reg",
|
33 |
+
"migration_code-move_within_reg",
|
34 |
+
"live_in_this_house_1_year_ago",
|
35 |
+
"migration_prev_res_in_sunbelt",
|
36 |
+
"num_persons_worked_for_employer",
|
37 |
+
"family_members_under_18",
|
38 |
+
"country_of_birth_father",
|
39 |
+
"country_of_birth_mother",
|
40 |
+
"country_of_birth_self",
|
41 |
+
"citizenship",
|
42 |
+
"own_business_or_self_employed",
|
43 |
+
"fill_inc_questionnaire_for_veterans_admin",
|
44 |
+
"veterans_benefits",
|
45 |
+
"weeks_worked_in_year",
|
46 |
+
"year",
|
47 |
+
"income_level",
|
48 |
+
]
|
49 |
+
|
50 |
+
# Target feature name.
|
51 |
+
TARGET_FEATURE_NAME = "income_level"
|
52 |
+
|
53 |
+
# Weight column name.
|
54 |
+
WEIGHT_COLUMN_NAME = "instance_weight"
|
55 |
+
|
56 |
+
# Numeric feature names.
|
57 |
+
NUMERIC_FEATURE_NAMES = [
|
58 |
+
"age",
|
59 |
+
"wage_per_hour",
|
60 |
+
"capital_gains",
|
61 |
+
"capital_losses",
|
62 |
+
"dividends_from_stocks",
|
63 |
+
"num_persons_worked_for_employer",
|
64 |
+
"weeks_worked_in_year",
|
65 |
+
]
|
66 |
+
|
67 |
+
##Cols which will use "Number" component of gradio for taking user input
|
68 |
+
NUMBER_INPUT_COLS = ['age', 'num_persons_worked_for_employer','weeks_worked_in_year']
|
69 |
+
|
70 |
+
test_data = load_test_data()
|
71 |
+
|
72 |
+
CATEGORICAL_FEATURES_WITH_VOCABULARY = {
|
73 |
+
feature_name: sorted([str(value) for value in list(test_data[feature_name].unique())])
|
74 |
+
for feature_name in CSV_HEADER
|
75 |
+
if feature_name
|
76 |
+
not in list(NUMERIC_FEATURE_NAMES + [WEIGHT_COLUMN_NAME, TARGET_FEATURE_NAME])
|
77 |
+
}
|
78 |
+
# All features names.
|
79 |
+
FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
|
80 |
+
CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
|
81 |
+
)
|
82 |
+
# Feature default values.
|
83 |
+
COLUMN_DEFAULTS = [
|
84 |
+
[0.0]
|
85 |
+
if feature_name in NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME, WEIGHT_COLUMN_NAME]
|
86 |
+
else ["NA"]
|
87 |
+
for feature_name in CSV_HEADER
|
88 |
+
]
|
predict.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from .preprocess import get_dataset_from_csv
|
3 |
+
from huggingface_hub import from_pretrained_keras
|
4 |
+
|
5 |
+
##Load Model
|
6 |
+
model = from_pretrained_keras("shivi/classification-grn-vsn")
|
7 |
+
|
8 |
+
def batch_predict(input_data):
|
9 |
+
"""
|
10 |
+
This function is used for fetching predictions corresponding to input_dataframe.
|
11 |
+
It outputs another dataframe containing:
|
12 |
+
1. prediction probability for each class
|
13 |
+
2. actual expected outcome for each entry in the input dataframe
|
14 |
+
"""
|
15 |
+
input_data_file = "prod_data.csv"
|
16 |
+
labels = ['Probability of Income greater than 50000',"Probability of Income less than 50000","Actual Income"]
|
17 |
+
|
18 |
+
predictions_df = pd.DataFrame(columns=labels)
|
19 |
+
|
20 |
+
input_data.to_csv(input_data_file, index=None, header=None)
|
21 |
+
|
22 |
+
prod_dataset = get_dataset_from_csv(input_data_file, shuffle=True)
|
23 |
+
|
24 |
+
pred = model.predict(prod_dataset)
|
25 |
+
|
26 |
+
for prediction, actual_gt in zip(pred, input_data['income_level'].values.tolist()):
|
27 |
+
y_pred_prob = round(prediction.flatten()[0] * 100, 2)
|
28 |
+
y_not_prob = round((1-prediction.flatten()[0]) * 100, 2)
|
29 |
+
y_pred = ">50000" if prediction.flatten()[0] > 0.5 else "<50000"
|
30 |
+
prob_scores = {labels[0]: str(y_pred_prob)+"%" , labels[1]: str(y_not_prob)+"%", labels[2]: y_pred}
|
31 |
+
predictions_df = predictions_df.append(prob_scores,ignore_index=True)
|
32 |
+
|
33 |
+
return predictions_df
|
34 |
+
|
35 |
+
|
36 |
+
def user_input_predict(age, wage, cap_gains, cap_losses, dividends, num_persons, weeks_worked_in_year,
|
37 |
+
class_of_worker, detailed_industry_recode,detailed_occupation_recode,education,
|
38 |
+
enroll_in_edu_inst_last_wk, marital_stat, major_industry_code,major_occupation_code,
|
39 |
+
race, hispanic_origin, sex, member_of_a_labor_union,reason_for_unemployment,
|
40 |
+
full_or_part_time_employment_stat, tax_filer_stat,region_of_previous_residence,
|
41 |
+
state_of_previous_residence,detailed_household_and_family_stat,detailed_household_summary_in_household,
|
42 |
+
migration_codechange_in_msa,migration_codechange_in_reg, migration_codemove_within_reg,
|
43 |
+
live_in_this_house_1_year_ago,migration_prev_res_in_sunbelt,family_members_under_18,
|
44 |
+
country_of_birth_father,country_of_birth_mother,country_of_birth_self,
|
45 |
+
citizenship,own_business_or_self_employed,fill_inc_questionnaire_for_veterans_admin,
|
46 |
+
veterans_benefits, year):
|
47 |
+
|
48 |
+
"""
|
49 |
+
This function is used for fetching model predictions based on inputs given by user on demo app
|
50 |
+
"""
|
51 |
+
|
52 |
+
input_dict = {"age": [age],
|
53 |
+
"class_of_worker": [class_of_worker],
|
54 |
+
"detailed_industry_recode": [detailed_industry_recode],
|
55 |
+
"detailed_occupation_recode": [detailed_occupation_recode],
|
56 |
+
"education":[education],
|
57 |
+
"wage_per_hour": [wage],
|
58 |
+
"enroll_in_edu_inst_last_wk": [enroll_in_edu_inst_last_wk],
|
59 |
+
"marital_stat": [marital_stat],
|
60 |
+
"major_industry_code": [major_industry_code],
|
61 |
+
"major_occupation_code": [major_occupation_code],
|
62 |
+
"race": [race],
|
63 |
+
"hispanic_origin": [hispanic_origin],
|
64 |
+
"sex": [sex],
|
65 |
+
"member_of_a_labor_union": [member_of_a_labor_union],
|
66 |
+
"reason_for_unemployment": [reason_for_unemployment],
|
67 |
+
"full_or_part_time_employment_stat": [full_or_part_time_employment_stat],
|
68 |
+
"capital_gains": [cap_gains],
|
69 |
+
"capital_losses": [cap_losses],
|
70 |
+
"dividends_from_stocks": [dividends],
|
71 |
+
"tax_filer_stat": [tax_filer_stat],
|
72 |
+
"region_of_previous_residence": [region_of_previous_residence],
|
73 |
+
"state_of_previous_residence": [state_of_previous_residence],
|
74 |
+
"detailed_household_and_family_stat": [detailed_household_and_family_stat],
|
75 |
+
"detailed_household_summary_in_household": [detailed_household_summary_in_household],
|
76 |
+
"instance_weight": [0.0],
|
77 |
+
"migration_code-change_in_msa": [migration_codechange_in_msa],
|
78 |
+
"migration_code-change_in_reg": [migration_codechange_in_reg],
|
79 |
+
"migration_code-move_within_reg": [migration_codemove_within_reg],
|
80 |
+
"live_in_this_house_1_year_ago": [live_in_this_house_1_year_ago],
|
81 |
+
"migration_prev_res_in_sunbelt": [migration_prev_res_in_sunbelt],
|
82 |
+
"num_persons_worked_for_employer": [num_persons],
|
83 |
+
"family_members_under_18": [family_members_under_18],
|
84 |
+
"country_of_birth_father": [country_of_birth_father],
|
85 |
+
"country_of_birth_mother": [country_of_birth_mother],
|
86 |
+
"country_of_birth_self": [country_of_birth_self],
|
87 |
+
"citizenship": [citizenship],
|
88 |
+
"own_business_or_self_employed": [own_business_or_self_employed],
|
89 |
+
"fill_inc_questionnaire_for_veterans_admin": [fill_inc_questionnaire_for_veterans_admin],
|
90 |
+
"veterans_benefits": [veterans_benefits],
|
91 |
+
"weeks_worked_in_year": [weeks_worked_in_year],
|
92 |
+
"year": [year],
|
93 |
+
"income_level": [0],
|
94 |
+
}
|
95 |
+
input_df = pd.DataFrame.from_dict(input_dict)
|
96 |
+
input_data_file = "input_data.csv"
|
97 |
+
|
98 |
+
input_df.to_csv(input_data_file, index=None, header=None)
|
99 |
+
prod_dataset = get_dataset_from_csv(input_data_file, shuffle=True)
|
100 |
+
|
101 |
+
labels = ['Income greater than 50000',"Income less than 50000"]
|
102 |
+
prediction = model.predict(prod_dataset)
|
103 |
+
y_pred_prob = round(prediction[0].flatten()[0],5)
|
104 |
+
y_not_prob = round(1-prediction[0].flatten()[0],3)
|
105 |
+
|
106 |
+
confidences = {labels[0]: float(y_pred_prob), labels[1]: float(y_not_prob)}
|
107 |
+
return confidences
|
preprocess.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import pandas as pd
|
3 |
+
from .constants import CSV_HEADER, TARGET_FEATURE_NAME, WEIGHT_COLUMN_NAME, NUMERIC_FEATURE_NAMES, COLUMN_DEFAULTS, CATEGORICAL_FEATURES_WITH_VOCABULARY
|
4 |
+
|
5 |
+
|
6 |
+
##Helper functions for preprocessing of data:
|
7 |
+
|
8 |
+
def process(features, target):
|
9 |
+
for feature_name in features:
|
10 |
+
if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
|
11 |
+
# Cast categorical feature values to string.
|
12 |
+
features[feature_name] = tf.cast(features[feature_name], tf.dtypes.string)
|
13 |
+
# Get the instance weight.
|
14 |
+
weight = features.pop(WEIGHT_COLUMN_NAME)
|
15 |
+
return features, target, weight
|
16 |
+
|
17 |
+
|
18 |
+
def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
|
19 |
+
|
20 |
+
dataset = tf.data.experimental.make_csv_dataset(
|
21 |
+
csv_file_path,
|
22 |
+
batch_size=batch_size,
|
23 |
+
column_names=CSV_HEADER,
|
24 |
+
column_defaults=COLUMN_DEFAULTS,
|
25 |
+
label_name=TARGET_FEATURE_NAME,
|
26 |
+
num_epochs=1,
|
27 |
+
header=False,
|
28 |
+
shuffle=shuffle,
|
29 |
+
).map(process)
|
30 |
+
|
31 |
+
return dataset
|
32 |
+
|
33 |
+
def create_max_values_map():
|
34 |
+
max_values_map = {}
|
35 |
+
for col in NUMERIC_FEATURE_NAMES:
|
36 |
+
max_val = max(test_data[col])
|
37 |
+
max_values_map["max_"+col] = max_val
|
38 |
+
return max_values_map
|
39 |
+
|
40 |
+
def create_dropdown_default_values_map():
|
41 |
+
dropdown_default_values_map = {}
|
42 |
+
for col in CATEGORICAL_FEATURES_WITH_VOCABULARY.keys():
|
43 |
+
max_val = test_data[col].max()
|
44 |
+
dropdown_default_values_map["max_"+col] = max_val
|
45 |
+
return dropdown_default_values_map
|
46 |
+
|
47 |
+
def load_test_data():
|
48 |
+
|
49 |
+
test_data_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/census-income-mld/census-income.test.gz"
|
50 |
+
test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)
|
51 |
+
|
52 |
+
return test_data
|
53 |
+
|
54 |
+
def create_sample_test_data():
|
55 |
+
|
56 |
+
test_data = load_test_data()
|
57 |
+
|
58 |
+
test_data["income_level"] = test_data["income_level"].apply(
|
59 |
+
lambda x: 0 if x == " - 50000." else 1)
|
60 |
+
|
61 |
+
sample_df = test_data.loc[:20,:]
|
62 |
+
sample_df_values = samp.values.tolist()
|
63 |
+
|
64 |
+
return sample_df_values
|
65 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
tensorflow
|
2 |
+
gradio
|