geninhu commited on
Commit
db827e6
β€’
1 Parent(s): 33367ee

Add application file

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import pandas as pd
4
+ import tensorflow as tf
5
+ from tensorflow import keras
6
+ from tensorflow.keras import layers
7
+ import tensorflow_addons as tfa
8
+ import matplotlib.pyplot as plt
9
+
10
+ import gradio as gr
11
+ from huggingface_hub import from_pretrained_keras
12
+
13
+ model = from_pretrained_keras('keras-io/tab_transformer')
14
+
15
+ CSV_HEADER = [
16
+ "age",
17
+ "workclass",
18
+ "fnlwgt",
19
+ "education",
20
+ "education_num",
21
+ "marital_status",
22
+ "occupation",
23
+ "relationship",
24
+ "race",
25
+ "gender",
26
+ "capital_gain",
27
+ "capital_loss",
28
+ "hours_per_week",
29
+ "native_country",
30
+ "income_bracket",
31
+ ]
32
+
33
+ def get_dataset_from_pandas(data):
34
+ for col in data.columns:
35
+ if data[col].dtype == 'float64':
36
+ data[col] = data[col].astype('float32')
37
+ elif col == 'age':
38
+ data[col] = data[col].astype('float32')
39
+ ds = tf.data.Dataset.from_tensors(dict(data.drop(columns = [i for i in ['income_bracket','fnlwgt'] if i in data.columns])))
40
+ return ds
41
+
42
+
43
+ def infer(age, workclass, education, education_num, marital_status, occupation, relationship, race, gender, capital_gain, capital_loss, hours_per_week, native_country):
44
+
45
+ data = pd.DataFrame({
46
+ 'age': age,
47
+ 'workclass': workclass,
48
+ 'education': education,
49
+ 'education_num': education_num,
50
+ 'marital_status': marital_status,
51
+ 'occupation': occupation,
52
+ 'relationship':relationship,
53
+ 'race': race,
54
+ 'gender': gender,
55
+ 'capital_gain': capital_gain,
56
+ 'capital_loss': capital_loss,
57
+ 'hours_per_week':hours_per_week,
58
+ 'native_country': native_country,
59
+ }, index=[0])
60
+ validation_dataset = get_dataset_from_pandas(data)
61
+ # validation_dataset = get_dataset_from_csv(test_data_file, 1)
62
+ pred = model.predict(validation_dataset)
63
+
64
+ return f"{round(pred.flatten()[0]*100, 2)}%"
65
+
66
+ # get the inputs
67
+ inputs = [
68
+ gr.Slider(minimum=16, maximum=120, step=1, label='age', value=30),
69
+ gr.Radio(choices=[' Private', ' Local-gov', ' ?', ' Self-emp-not-inc',' Federal-gov', ' State-gov', ' Self-emp-inc', ' Without-pay', ' Never-worked'],
70
+ label='workclass', type='value',value=' Private'),
71
+ gr.Radio(choices=[' 11th', ' HS-grad', ' Assoc-acdm', ' Some-college', ' 10th', ' Prof-school', ' 7th-8th', ' Bachelors', ' Masters', ' Doctorate',
72
+ ' 5th-6th', ' Assoc-voc', ' 9th', ' 12th', ' 1st-4th', ' Preschool'],
73
+ type='value', label='education', value=' Bachelors'),
74
+ gr.Slider(minimum=1, maximum=16, step=1, label='education_num', value=10),
75
+ gr.Radio(choices=[' Never-married', ' Married-civ-spouse', ' Widowed', ' Divorced', ' Separated', ' Married-spouse-absent', ' Married-AF-spouse'],
76
+ type='value', label='marital_status', value=' Married-civ-spouse'),
77
+ gr.Radio(choices=[' Machine-op-inspct', ' Farming-fishing', ' Protective-serv', ' ?', ' Other-service', ' Prof-specialty', ' Craft-repair',
78
+ ' Adm-clerical', ' Exec-managerial', ' Tech-support', ' Sales', ' Priv-house-serv', ' Transport-moving', ' Handlers-cleaners', ' Armed-Forces'],
79
+ type='value', label='occupation', value=' Tech-support'),
80
+ gr.Radio(choices=[' Own-child', ' Husband', ' Not-in-family', ' Unmarried', ' Wife', ' Other-relative'],
81
+ type='value', label='relationship', value=' Wife'),
82
+ gr.Radio(choices=[' Black', ' White', ' Asian-Pac-Islander', ' Other', ' Amer-Indian-Eskimo'],
83
+ type='value', label='race', value=' Other'),
84
+ gr.Radio(choices=[' Male', ' Female'], type='value', label='gender', value=' Female'),
85
+ gr.Slider(minimum=0, maximum=500000, step=1, label='capital_gain', value=80000),
86
+ gr.Slider(minimum=0, maximum=50000, step=1, label='capital_loss', value=1000),
87
+ gr.Slider(minimum=1, maximum=168, step=1, label='hours_per_week', value=40),
88
+ gr.Radio(choices=[' United-States', ' ?', ' Peru', ' Guatemala', ' Mexico', ' Dominican-Republic', ' Ireland', ' Germany', ' Philippines', ' Thailand', ' Haiti',
89
+ ' El-Salvador', ' Puerto-Rico', ' Vietnam', ' South', ' Columbia', ' Japan', ' India', ' Cambodia', ' Poland', ' Laos', ' England', ' Cuba', ' Taiwan',
90
+ ' Italy', ' Canada', ' Portugal', ' China', ' Nicaragua', ' Honduras', ' Iran', ' Scotland', ' Jamaica', ' Ecuador', ' Yugoslavia', ' Hungary',
91
+ ' Hong', ' Greece', ' Trinadad&Tobago', ' Outlying-US(Guam-USVI-etc)', ' France'],
92
+ type='value', label='native_country', value=' Vietnam'),
93
+ ]
94
+
95
+ # the app outputs two segmented images
96
+ output = gr.Textbox(label='Probability of income larger than 50,000 USD per year:')
97
+ # it's good practice to pass examples, description and a title to guide users
98
+ title = 'Tab Transformer for Structured data'
99
+ description = 'Using Transformer to predict whether the income will be larger than 50,000 USD given the input features.'
100
+
101
+ article = "Author: <a href=\"https://huggingface.co/geninhu\">Nhu Hoang</a>. Based on this <a href=\"https://keras.io/examples/structured_data/tabtransformer/\">keras example</a> by <a href=\"https://www.linkedin.com/in/khalid-salama-24403144\">Khalid Salama.</a> HuggingFace Model <a href=\"https://huggingface.co/keras-io/tab_transformer\">here</a> "
102
+
103
+ examples = [[39.0, ' Private', ' Assoc-voc', 11.0, ' Divorced', ' Tech-support', ' Not-in-family', ' White', ' Female', 0.0, 0.0, 40.0, ' United-States'],
104
+ [60.0, ' Private', ' 12th', 8.0, ' Married-civ-spouse', ' Handlers-cleaners', ' Husband', ' White', ' Male', 0.0, 0.0, 40.0, ' United-States'],
105
+ [42.0, ' Private',' Masters', 14.0, ' Married-civ-spouse', ' Prof-specialty', ' Husband', ' Asian-Pac-Islander', ' Male', 0.0, 0.0, 40.0, ' Taiwan',],
106
+ [31.0, ' Local-gov',' Bachelors', 13.0, ' Married-civ-spouse', ' Craft-repair', ' Husband', ' White', ' Male', 0.0, 0.0, 51.0, ' United-States'],
107
+ [30.0, ' Private', ' Masters', 14.0, ' Never-married', ' Prof-specialty', ' Not-in-family', ' Asian-Pac-Islander', ' Male', 0.0, 0.0, 45.0, ' Iran']]
108
+
109
+ gr.Interface(infer, inputs, output, examples= examples, allow_flagging='never',
110
+ title=title, description=description, article=article, live=False).launch(enable_queue=True, debug=True, inbrowser=True)