shivi commited on
Commit
563baab
1 Parent(s): 51f2058

Added all app files for gradio demo for Binary classification on structured data using GRN-VSN model

Browse files
Files changed (5) hide show
  1. app.py +59 -0
  2. constants.py +88 -0
  3. predict.py +107 -0
  4. preprocess.py +65 -0
  5. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from .constants import CSV_HEADER, NUMERIC_FEATURE_NAMES, CATEGORICAL_FEATURES_WITH_VOCABULARY, NUMBER_INPUT_COLS
3
+ from .preprocess import create_max_values_map, create_dropdown_default_values_map, create_sample_test_data
4
+ from .predict import batch_predict, user_input_predict
5
+
6
+ inputs_list = []
7
+ max_values_map = create_max_values_map()
8
+ dropdown_default_values_map = create_dropdown_default_values_map()
9
+ sample_input_df_val = create_sample_test_data()
10
+
11
+ demo = gr.Blocks()
12
+
13
+ with demo:
14
+
15
+ gr.Markdown("# **Binary Classification using Gated Residual and Variable Selection Networks** \n")
16
+ gr.Markdown("This demo demonstrates the use of Gated Residual Networks (GRN) and Variable Selection Networks (VSN), proposed by Bryan Lim et al. in <a href=\"https://arxiv.org/abs/1912.09363/\">Temporal Fusion Transformers (TFT) for Interpretable Multi-horizon Time Series Forecasting</a> for structured data classification")
17
+ gr.Markdown("Play around and see yourself 🤗 ")
18
+
19
+ with gr.Tabs():
20
+
21
+ with gr.TabItem("Predict using batch of inputs"):
22
+ gr.Markdown("**Input DataFrame** \n")
23
+ input_df = gr.Dataframe(headers=CSV_HEADER,value=samp,)
24
+ gr.Markdown("**Output DataFrame** \n")
25
+ output_df = gr.Dataframe()
26
+ gr.Markdown("**Make Predictions**")
27
+ with gr.Row():
28
+ compute_button = gr.Button("Predict")
29
+
30
+ with gr.TabItem("Tweak inputs Yourself"):
31
+ with gr.Tabs():
32
+
33
+ with gr.TabItem("Numerical Inputs"):
34
+ gr.Markdown("Set values for numerical inputs here.")
35
+ for num_variable in NUMERIC_FEATURE_NAMES:
36
+ with gr.Column():
37
+ if num_variable in NUMBER_INPUT_COLS:
38
+ numeric_input = gr.Number(label=num_variable)
39
+ else:
40
+ curr_max_val = max_values_map["max_"+num_variable]
41
+ numeric_input = gr.Slider(0,curr_max_val, label=num_variable,step=1)
42
+ inputs_list.append(numeric_input)
43
+
44
+ with gr.TabItem("Categorical Inputs"):
45
+ gr.Markdown("Choose values for categorical inputs here.")
46
+ for cat_variable in CATEGORICAL_FEATURES_WITH_VOCABULARY.keys():
47
+ with gr.Column():
48
+ categorical_input = gr.Dropdown(CATEGORICAL_FEATURES_WITH_VOCABULARY[cat_variable], label=cat_variable, value=str(dropdown_default_values_map["max_"+cat_variable]))
49
+ inputs_list.append(categorical_input)
50
+
51
+ predict_button = gr.Button("Predict")
52
+ final_output = gr.Label()
53
+
54
+ predict_button.click(user_input_predict, inputs=inputs_list, outputs=final_output)
55
+ compute_button.click(batch_predict, inputs=input_df, outputs=output_df)
56
+ gr.Markdown('\n Author: <a href=\"https://huggingface.co/shivi\">Shivalika Singh</a> <br> Based on this <a href=\"https://keras.io/examples/structured_data/classification_with_grn_and_vsn/\">Keras example</a> by <a href=\"https://www.linkedin.com/in/khalid-salama-24403144/\">Khalid Salama</a> <br> Demo Powered by this <a href=\"https://huggingface.co/shivi/classification-grn-vsn\">GRN-VSN model</a>')
57
+
58
+
59
+ demo.launch()
constants.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from .preprocess import load_test_data
3
+
4
+ # Column names.
5
+ CSV_HEADER = [
6
+ "age",
7
+ "class_of_worker",
8
+ "detailed_industry_recode",
9
+ "detailed_occupation_recode",
10
+ "education",
11
+ "wage_per_hour",
12
+ "enroll_in_edu_inst_last_wk",
13
+ "marital_stat",
14
+ "major_industry_code",
15
+ "major_occupation_code",
16
+ "race",
17
+ "hispanic_origin",
18
+ "sex",
19
+ "member_of_a_labor_union",
20
+ "reason_for_unemployment",
21
+ "full_or_part_time_employment_stat",
22
+ "capital_gains",
23
+ "capital_losses",
24
+ "dividends_from_stocks",
25
+ "tax_filer_stat",
26
+ "region_of_previous_residence",
27
+ "state_of_previous_residence",
28
+ "detailed_household_and_family_stat",
29
+ "detailed_household_summary_in_household",
30
+ "instance_weight",
31
+ "migration_code-change_in_msa",
32
+ "migration_code-change_in_reg",
33
+ "migration_code-move_within_reg",
34
+ "live_in_this_house_1_year_ago",
35
+ "migration_prev_res_in_sunbelt",
36
+ "num_persons_worked_for_employer",
37
+ "family_members_under_18",
38
+ "country_of_birth_father",
39
+ "country_of_birth_mother",
40
+ "country_of_birth_self",
41
+ "citizenship",
42
+ "own_business_or_self_employed",
43
+ "fill_inc_questionnaire_for_veterans_admin",
44
+ "veterans_benefits",
45
+ "weeks_worked_in_year",
46
+ "year",
47
+ "income_level",
48
+ ]
49
+
50
+ # Target feature name.
51
+ TARGET_FEATURE_NAME = "income_level"
52
+
53
+ # Weight column name.
54
+ WEIGHT_COLUMN_NAME = "instance_weight"
55
+
56
+ # Numeric feature names.
57
+ NUMERIC_FEATURE_NAMES = [
58
+ "age",
59
+ "wage_per_hour",
60
+ "capital_gains",
61
+ "capital_losses",
62
+ "dividends_from_stocks",
63
+ "num_persons_worked_for_employer",
64
+ "weeks_worked_in_year",
65
+ ]
66
+
67
+ ##Cols which will use "Number" component of gradio for taking user input
68
+ NUMBER_INPUT_COLS = ['age', 'num_persons_worked_for_employer','weeks_worked_in_year']
69
+
70
+ test_data = load_test_data()
71
+
72
+ CATEGORICAL_FEATURES_WITH_VOCABULARY = {
73
+ feature_name: sorted([str(value) for value in list(test_data[feature_name].unique())])
74
+ for feature_name in CSV_HEADER
75
+ if feature_name
76
+ not in list(NUMERIC_FEATURE_NAMES + [WEIGHT_COLUMN_NAME, TARGET_FEATURE_NAME])
77
+ }
78
+ # All features names.
79
+ FEATURE_NAMES = NUMERIC_FEATURE_NAMES + list(
80
+ CATEGORICAL_FEATURES_WITH_VOCABULARY.keys()
81
+ )
82
+ # Feature default values.
83
+ COLUMN_DEFAULTS = [
84
+ [0.0]
85
+ if feature_name in NUMERIC_FEATURE_NAMES + [TARGET_FEATURE_NAME, WEIGHT_COLUMN_NAME]
86
+ else ["NA"]
87
+ for feature_name in CSV_HEADER
88
+ ]
predict.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from .preprocess import get_dataset_from_csv
3
+ from huggingface_hub import from_pretrained_keras
4
+
5
+ ##Load Model
6
+ model = from_pretrained_keras("shivi/classification-grn-vsn")
7
+
8
+ def batch_predict(input_data):
9
+ """
10
+ This function is used for fetching predictions corresponding to input_dataframe.
11
+ It outputs another dataframe containing:
12
+ 1. prediction probability for each class
13
+ 2. actual expected outcome for each entry in the input dataframe
14
+ """
15
+ input_data_file = "prod_data.csv"
16
+ labels = ['Probability of Income greater than 50000',"Probability of Income less than 50000","Actual Income"]
17
+
18
+ predictions_df = pd.DataFrame(columns=labels)
19
+
20
+ input_data.to_csv(input_data_file, index=None, header=None)
21
+
22
+ prod_dataset = get_dataset_from_csv(input_data_file, shuffle=True)
23
+
24
+ pred = model.predict(prod_dataset)
25
+
26
+ for prediction, actual_gt in zip(pred, input_data['income_level'].values.tolist()):
27
+ y_pred_prob = round(prediction.flatten()[0] * 100, 2)
28
+ y_not_prob = round((1-prediction.flatten()[0]) * 100, 2)
29
+ y_pred = ">50000" if prediction.flatten()[0] > 0.5 else "<50000"
30
+ prob_scores = {labels[0]: str(y_pred_prob)+"%" , labels[1]: str(y_not_prob)+"%", labels[2]: y_pred}
31
+ predictions_df = predictions_df.append(prob_scores,ignore_index=True)
32
+
33
+ return predictions_df
34
+
35
+
36
+ def user_input_predict(age, wage, cap_gains, cap_losses, dividends, num_persons, weeks_worked_in_year,
37
+ class_of_worker, detailed_industry_recode,detailed_occupation_recode,education,
38
+ enroll_in_edu_inst_last_wk, marital_stat, major_industry_code,major_occupation_code,
39
+ race, hispanic_origin, sex, member_of_a_labor_union,reason_for_unemployment,
40
+ full_or_part_time_employment_stat, tax_filer_stat,region_of_previous_residence,
41
+ state_of_previous_residence,detailed_household_and_family_stat,detailed_household_summary_in_household,
42
+ migration_codechange_in_msa,migration_codechange_in_reg, migration_codemove_within_reg,
43
+ live_in_this_house_1_year_ago,migration_prev_res_in_sunbelt,family_members_under_18,
44
+ country_of_birth_father,country_of_birth_mother,country_of_birth_self,
45
+ citizenship,own_business_or_self_employed,fill_inc_questionnaire_for_veterans_admin,
46
+ veterans_benefits, year):
47
+
48
+ """
49
+ This function is used for fetching model predictions based on inputs given by user on demo app
50
+ """
51
+
52
+ input_dict = {"age": [age],
53
+ "class_of_worker": [class_of_worker],
54
+ "detailed_industry_recode": [detailed_industry_recode],
55
+ "detailed_occupation_recode": [detailed_occupation_recode],
56
+ "education":[education],
57
+ "wage_per_hour": [wage],
58
+ "enroll_in_edu_inst_last_wk": [enroll_in_edu_inst_last_wk],
59
+ "marital_stat": [marital_stat],
60
+ "major_industry_code": [major_industry_code],
61
+ "major_occupation_code": [major_occupation_code],
62
+ "race": [race],
63
+ "hispanic_origin": [hispanic_origin],
64
+ "sex": [sex],
65
+ "member_of_a_labor_union": [member_of_a_labor_union],
66
+ "reason_for_unemployment": [reason_for_unemployment],
67
+ "full_or_part_time_employment_stat": [full_or_part_time_employment_stat],
68
+ "capital_gains": [cap_gains],
69
+ "capital_losses": [cap_losses],
70
+ "dividends_from_stocks": [dividends],
71
+ "tax_filer_stat": [tax_filer_stat],
72
+ "region_of_previous_residence": [region_of_previous_residence],
73
+ "state_of_previous_residence": [state_of_previous_residence],
74
+ "detailed_household_and_family_stat": [detailed_household_and_family_stat],
75
+ "detailed_household_summary_in_household": [detailed_household_summary_in_household],
76
+ "instance_weight": [0.0],
77
+ "migration_code-change_in_msa": [migration_codechange_in_msa],
78
+ "migration_code-change_in_reg": [migration_codechange_in_reg],
79
+ "migration_code-move_within_reg": [migration_codemove_within_reg],
80
+ "live_in_this_house_1_year_ago": [live_in_this_house_1_year_ago],
81
+ "migration_prev_res_in_sunbelt": [migration_prev_res_in_sunbelt],
82
+ "num_persons_worked_for_employer": [num_persons],
83
+ "family_members_under_18": [family_members_under_18],
84
+ "country_of_birth_father": [country_of_birth_father],
85
+ "country_of_birth_mother": [country_of_birth_mother],
86
+ "country_of_birth_self": [country_of_birth_self],
87
+ "citizenship": [citizenship],
88
+ "own_business_or_self_employed": [own_business_or_self_employed],
89
+ "fill_inc_questionnaire_for_veterans_admin": [fill_inc_questionnaire_for_veterans_admin],
90
+ "veterans_benefits": [veterans_benefits],
91
+ "weeks_worked_in_year": [weeks_worked_in_year],
92
+ "year": [year],
93
+ "income_level": [0],
94
+ }
95
+ input_df = pd.DataFrame.from_dict(input_dict)
96
+ input_data_file = "input_data.csv"
97
+
98
+ input_df.to_csv(input_data_file, index=None, header=None)
99
+ prod_dataset = get_dataset_from_csv(input_data_file, shuffle=True)
100
+
101
+ labels = ['Income greater than 50000',"Income less than 50000"]
102
+ prediction = model.predict(prod_dataset)
103
+ y_pred_prob = round(prediction[0].flatten()[0],5)
104
+ y_not_prob = round(1-prediction[0].flatten()[0],3)
105
+
106
+ confidences = {labels[0]: float(y_pred_prob), labels[1]: float(y_not_prob)}
107
+ return confidences
preprocess.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import pandas as pd
3
+ from .constants import CSV_HEADER, TARGET_FEATURE_NAME, WEIGHT_COLUMN_NAME, NUMERIC_FEATURE_NAMES, COLUMN_DEFAULTS, CATEGORICAL_FEATURES_WITH_VOCABULARY
4
+
5
+
6
+ ##Helper functions for preprocessing of data:
7
+
8
+ def process(features, target):
9
+ for feature_name in features:
10
+ if feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:
11
+ # Cast categorical feature values to string.
12
+ features[feature_name] = tf.cast(features[feature_name], tf.dtypes.string)
13
+ # Get the instance weight.
14
+ weight = features.pop(WEIGHT_COLUMN_NAME)
15
+ return features, target, weight
16
+
17
+
18
+ def get_dataset_from_csv(csv_file_path, shuffle=False, batch_size=128):
19
+
20
+ dataset = tf.data.experimental.make_csv_dataset(
21
+ csv_file_path,
22
+ batch_size=batch_size,
23
+ column_names=CSV_HEADER,
24
+ column_defaults=COLUMN_DEFAULTS,
25
+ label_name=TARGET_FEATURE_NAME,
26
+ num_epochs=1,
27
+ header=False,
28
+ shuffle=shuffle,
29
+ ).map(process)
30
+
31
+ return dataset
32
+
33
+ def create_max_values_map():
34
+ max_values_map = {}
35
+ for col in NUMERIC_FEATURE_NAMES:
36
+ max_val = max(test_data[col])
37
+ max_values_map["max_"+col] = max_val
38
+ return max_values_map
39
+
40
+ def create_dropdown_default_values_map():
41
+ dropdown_default_values_map = {}
42
+ for col in CATEGORICAL_FEATURES_WITH_VOCABULARY.keys():
43
+ max_val = test_data[col].max()
44
+ dropdown_default_values_map["max_"+col] = max_val
45
+ return dropdown_default_values_map
46
+
47
+ def load_test_data():
48
+
49
+ test_data_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/census-income-mld/census-income.test.gz"
50
+ test_data = pd.read_csv(test_data_url, header=None, names=CSV_HEADER)
51
+
52
+ return test_data
53
+
54
+ def create_sample_test_data():
55
+
56
+ test_data = load_test_data()
57
+
58
+ test_data["income_level"] = test_data["income_level"].apply(
59
+ lambda x: 0 if x == " - 50000." else 1)
60
+
61
+ sample_df = test_data.loc[:20,:]
62
+ sample_df_values = samp.values.tolist()
63
+
64
+ return sample_df_values
65
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ tensorflow
2
+ gradio