Spaces:
Runtime error
Runtime error
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import tensorflow as tf
|
5 |
+
from tensorflow import keras
|
6 |
+
from tensorflow.keras import layers
|
7 |
+
import tensorflow_addons as tfa
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
from huggingface_hub import from_pretrained_keras
|
12 |
+
|
13 |
+
model = from_pretrained_keras('keras-io/tab_transformer')
|
14 |
+
|
15 |
+
CSV_HEADER = [
|
16 |
+
"age",
|
17 |
+
"workclass",
|
18 |
+
"fnlwgt",
|
19 |
+
"education",
|
20 |
+
"education_num",
|
21 |
+
"marital_status",
|
22 |
+
"occupation",
|
23 |
+
"relationship",
|
24 |
+
"race",
|
25 |
+
"gender",
|
26 |
+
"capital_gain",
|
27 |
+
"capital_loss",
|
28 |
+
"hours_per_week",
|
29 |
+
"native_country",
|
30 |
+
"income_bracket",
|
31 |
+
]
|
32 |
+
|
33 |
+
def get_dataset_from_pandas(data):
|
34 |
+
for col in data.columns:
|
35 |
+
if data[col].dtype == 'float64':
|
36 |
+
data[col] = data[col].astype('float32')
|
37 |
+
elif col == 'age':
|
38 |
+
data[col] = data[col].astype('float32')
|
39 |
+
ds = tf.data.Dataset.from_tensors(dict(data.drop(columns = [i for i in ['income_bracket','fnlwgt'] if i in data.columns])))
|
40 |
+
return ds
|
41 |
+
|
42 |
+
|
43 |
+
def infer(age, workclass, education, education_num, marital_status, occupation, relationship, race, gender, capital_gain, capital_loss, hours_per_week, native_country):
|
44 |
+
|
45 |
+
data = pd.DataFrame({
|
46 |
+
'age': age,
|
47 |
+
'workclass': workclass,
|
48 |
+
'education': education,
|
49 |
+
'education_num': education_num,
|
50 |
+
'marital_status': marital_status,
|
51 |
+
'occupation': occupation,
|
52 |
+
'relationship':relationship,
|
53 |
+
'race': race,
|
54 |
+
'gender': gender,
|
55 |
+
'capital_gain': capital_gain,
|
56 |
+
'capital_loss': capital_loss,
|
57 |
+
'hours_per_week':hours_per_week,
|
58 |
+
'native_country': native_country,
|
59 |
+
}, index=[0])
|
60 |
+
validation_dataset = get_dataset_from_pandas(data)
|
61 |
+
# validation_dataset = get_dataset_from_csv(test_data_file, 1)
|
62 |
+
pred = model.predict(validation_dataset)
|
63 |
+
|
64 |
+
return f"{round(pred.flatten()[0]*100, 2)}%"
|
65 |
+
|
66 |
+
# get the inputs
|
67 |
+
inputs = [
|
68 |
+
gr.Slider(minimum=16, maximum=120, step=1, label='age', value=30),
|
69 |
+
gr.Radio(choices=[' Private', ' Local-gov', ' ?', ' Self-emp-not-inc',' Federal-gov', ' State-gov', ' Self-emp-inc', ' Without-pay', ' Never-worked'],
|
70 |
+
label='workclass', type='value',value=' Private'),
|
71 |
+
gr.Radio(choices=[' 11th', ' HS-grad', ' Assoc-acdm', ' Some-college', ' 10th', ' Prof-school', ' 7th-8th', ' Bachelors', ' Masters', ' Doctorate',
|
72 |
+
' 5th-6th', ' Assoc-voc', ' 9th', ' 12th', ' 1st-4th', ' Preschool'],
|
73 |
+
type='value', label='education', value=' Bachelors'),
|
74 |
+
gr.Slider(minimum=1, maximum=16, step=1, label='education_num', value=10),
|
75 |
+
gr.Radio(choices=[' Never-married', ' Married-civ-spouse', ' Widowed', ' Divorced', ' Separated', ' Married-spouse-absent', ' Married-AF-spouse'],
|
76 |
+
type='value', label='marital_status', value=' Married-civ-spouse'),
|
77 |
+
gr.Radio(choices=[' Machine-op-inspct', ' Farming-fishing', ' Protective-serv', ' ?', ' Other-service', ' Prof-specialty', ' Craft-repair',
|
78 |
+
' Adm-clerical', ' Exec-managerial', ' Tech-support', ' Sales', ' Priv-house-serv', ' Transport-moving', ' Handlers-cleaners', ' Armed-Forces'],
|
79 |
+
type='value', label='occupation', value=' Tech-support'),
|
80 |
+
gr.Radio(choices=[' Own-child', ' Husband', ' Not-in-family', ' Unmarried', ' Wife', ' Other-relative'],
|
81 |
+
type='value', label='relationship', value=' Wife'),
|
82 |
+
gr.Radio(choices=[' Black', ' White', ' Asian-Pac-Islander', ' Other', ' Amer-Indian-Eskimo'],
|
83 |
+
type='value', label='race', value=' Other'),
|
84 |
+
gr.Radio(choices=[' Male', ' Female'], type='value', label='gender', value=' Female'),
|
85 |
+
gr.Slider(minimum=0, maximum=500000, step=1, label='capital_gain', value=80000),
|
86 |
+
gr.Slider(minimum=0, maximum=50000, step=1, label='capital_loss', value=1000),
|
87 |
+
gr.Slider(minimum=1, maximum=168, step=1, label='hours_per_week', value=40),
|
88 |
+
gr.Radio(choices=[' United-States', ' ?', ' Peru', ' Guatemala', ' Mexico', ' Dominican-Republic', ' Ireland', ' Germany', ' Philippines', ' Thailand', ' Haiti',
|
89 |
+
' El-Salvador', ' Puerto-Rico', ' Vietnam', ' South', ' Columbia', ' Japan', ' India', ' Cambodia', ' Poland', ' Laos', ' England', ' Cuba', ' Taiwan',
|
90 |
+
' Italy', ' Canada', ' Portugal', ' China', ' Nicaragua', ' Honduras', ' Iran', ' Scotland', ' Jamaica', ' Ecuador', ' Yugoslavia', ' Hungary',
|
91 |
+
' Hong', ' Greece', ' Trinadad&Tobago', ' Outlying-US(Guam-USVI-etc)', ' France'],
|
92 |
+
type='value', label='native_country', value=' Vietnam'),
|
93 |
+
]
|
94 |
+
|
95 |
+
# the app outputs two segmented images
|
96 |
+
output = gr.Textbox(label='Probability of income larger than 50,000 USD per year:')
|
97 |
+
# it's good practice to pass examples, description and a title to guide users
|
98 |
+
title = 'Tab Transformer for Structured data'
|
99 |
+
description = 'Using Transformer to predict whether the income will be larger than 50,000 USD given the input features.'
|
100 |
+
|
101 |
+
article = "Author: <a href=\"https://huggingface.co/geninhu\">Nhu Hoang</a>. Based on this <a href=\"https://keras.io/examples/structured_data/tabtransformer/\">keras example</a> by <a href=\"https://www.linkedin.com/in/khalid-salama-24403144\">Khalid Salama.</a> HuggingFace Model <a href=\"https://huggingface.co/keras-io/tab_transformer\">here</a> "
|
102 |
+
|
103 |
+
examples = [[39.0, ' Private', ' Assoc-voc', 11.0, ' Divorced', ' Tech-support', ' Not-in-family', ' White', ' Female', 0.0, 0.0, 40.0, ' United-States'],
|
104 |
+
[60.0, ' Private', ' 12th', 8.0, ' Married-civ-spouse', ' Handlers-cleaners', ' Husband', ' White', ' Male', 0.0, 0.0, 40.0, ' United-States'],
|
105 |
+
[42.0, ' Private',' Masters', 14.0, ' Married-civ-spouse', ' Prof-specialty', ' Husband', ' Asian-Pac-Islander', ' Male', 0.0, 0.0, 40.0, ' Taiwan',],
|
106 |
+
[31.0, ' Local-gov',' Bachelors', 13.0, ' Married-civ-spouse', ' Craft-repair', ' Husband', ' White', ' Male', 0.0, 0.0, 51.0, ' United-States'],
|
107 |
+
[30.0, ' Private', ' Masters', 14.0, ' Never-married', ' Prof-specialty', ' Not-in-family', ' Asian-Pac-Islander', ' Male', 0.0, 0.0, 45.0, ' Iran']]
|
108 |
+
|
109 |
+
gr.Interface(infer, inputs, output, examples= examples, allow_flagging='never',
|
110 |
+
title=title, description=description, article=article, live=False).launch(enable_queue=True, debug=True, inbrowser=True)
|