# Page header captured with this file (Hugging Face file page): author geninhu,
# commit db827e6 "Add application file", 6.45 kB, virus scan: no virus found.
import math
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
import gradio as gr
from huggingface_hub import from_pretrained_keras
# Hugging Face Hub repository that holds the trained TabTransformer model.
_MODEL_REPO_ID = 'keras-io/tab_transformer'
# Loaded once, when this module is imported, and reused by infer() for every
# prediction request.
model = from_pretrained_keras(_MODEL_REPO_ID)
# Column names of the raw tabular data (presumably the UCI Adult / census-income
# CSV used by the Keras TabTransformer example), including the `fnlwgt` and
# `income_bracket` columns that are not fed to the model. Not referenced by the
# code below.
CSV_HEADER = (
    "age workclass fnlwgt education education_num marital_status occupation "
    "relationship race gender capital_gain capital_loss hours_per_week "
    "native_country income_bracket"
).split()
def get_dataset_from_pandas(data):
    """Convert a feature DataFrame into a ``tf.data.Dataset`` for ``model.predict``.

    Every integer or floating-point column, and the ``age`` column
    unconditionally, is cast to float32. The ``income_bracket`` and ``fnlwgt``
    columns are dropped when present. The whole frame becomes a single dataset
    element, i.e. one batch whose leading dimension is the number of rows.

    Args:
        data: pandas DataFrame with one column per feature. It is not modified;
            the conversion works on a copy.

    Returns:
        A ``tf.data.Dataset`` with one element: a dict mapping each remaining
        column name to its values.
    """
    # Work on a copy so the caller's frame is left untouched.
    data = data.copy()
    for col in data.columns:
        dtype = data[col].dtype
        is_number = (
            pd.api.types.is_integer_dtype(dtype)
            or pd.api.types.is_float_dtype(dtype)
        )
        # Integer columns (e.g. whole-number slider values) need the cast too,
        # not only float64 ones; `age` is cast regardless of its dtype.
        # NOTE(review): assumes every numeric model input expects float32 —
        # confirm against the saved model's input signature.
        if is_number or col == 'age':
            data[col] = data[col].astype('float32')
    # These columns are not model inputs; ignore them if they are absent.
    features = data.drop(columns=['income_bracket', 'fnlwgt'], errors='ignore')
    return tf.data.Dataset.from_tensors(dict(features))
def infer(age, workclass, education, education_num, marital_status,
          occupation, relationship, race, gender, capital_gain,
          capital_loss, hours_per_week, native_country):
    """Score one person with the TabTransformer and return a percentage string.

    The arguments are the raw widget values, received in the same order as the
    ``inputs`` list. They are packed into a one-row DataFrame, converted with
    ``get_dataset_from_pandas`` and passed to the module-level ``model``.

    Returns:
        The first model output multiplied by 100, rounded to two decimals and
        suffixed with ``%`` (e.g. ``"12.34%"``). The output textbox presents
        this value as the probability of earning more than 50,000 USD per year.
        NOTE(review): this assumes the model emits a probability (e.g. a
        sigmoid output) — confirm.
    """
    row = {
        "age": age,
        "workclass": workclass,
        "education": education,
        "education_num": education_num,
        "marital_status": marital_status,
        "occupation": occupation,
        "relationship": relationship,
        "race": race,
        "gender": gender,
        "capital_gain": capital_gain,
        "capital_loss": capital_loss,
        "hours_per_week": hours_per_week,
        "native_country": native_country,
    }
    # Scalar values plus an explicit one-entry index give a single-row frame.
    frame = pd.DataFrame(row, index=[0])
    scores = model.predict(get_dataset_from_pandas(frame))
    percent = round(scores.flatten()[0] * 100, 2)
    return f"{percent}%"
def _radio(label, choices, default):
    """Build a single-choice Gradio radio group that yields the selected value."""
    return gr.Radio(choices=choices, type='value', label=label, value=default)


# Input widgets, listed in the same order as the parameters of infer().
# NOTE(review): the categorical choices keep their leading space
# (e.g. ' Private'); presumably they must match the model's category
# vocabulary byte-for-byte — do not strip them.
inputs = [
    gr.Slider(minimum=16, maximum=120, step=1, label='age', value=30),
    _radio(
        'workclass',
        [' Private', ' Local-gov', ' ?', ' Self-emp-not-inc', ' Federal-gov',
         ' State-gov', ' Self-emp-inc', ' Without-pay', ' Never-worked'],
        ' Private',
    ),
    _radio(
        'education',
        [' 11th', ' HS-grad', ' Assoc-acdm', ' Some-college', ' 10th',
         ' Prof-school', ' 7th-8th', ' Bachelors', ' Masters', ' Doctorate',
         ' 5th-6th', ' Assoc-voc', ' 9th', ' 12th', ' 1st-4th', ' Preschool'],
        ' Bachelors',
    ),
    gr.Slider(minimum=1, maximum=16, step=1, label='education_num', value=10),
    _radio(
        'marital_status',
        [' Never-married', ' Married-civ-spouse', ' Widowed', ' Divorced',
         ' Separated', ' Married-spouse-absent', ' Married-AF-spouse'],
        ' Married-civ-spouse',
    ),
    _radio(
        'occupation',
        [' Machine-op-inspct', ' Farming-fishing', ' Protective-serv', ' ?',
         ' Other-service', ' Prof-specialty', ' Craft-repair', ' Adm-clerical',
         ' Exec-managerial', ' Tech-support', ' Sales', ' Priv-house-serv',
         ' Transport-moving', ' Handlers-cleaners', ' Armed-Forces'],
        ' Tech-support',
    ),
    _radio(
        'relationship',
        [' Own-child', ' Husband', ' Not-in-family', ' Unmarried', ' Wife',
         ' Other-relative'],
        ' Wife',
    ),
    _radio(
        'race',
        [' Black', ' White', ' Asian-Pac-Islander', ' Other',
         ' Amer-Indian-Eskimo'],
        ' Other',
    ),
    _radio('gender', [' Male', ' Female'], ' Female'),
    gr.Slider(minimum=0, maximum=500000, step=1, label='capital_gain', value=80000),
    gr.Slider(minimum=0, maximum=50000, step=1, label='capital_loss', value=1000),
    gr.Slider(minimum=1, maximum=168, step=1, label='hours_per_week', value=40),
    _radio(
        'native_country',
        [' United-States', ' ?', ' Peru', ' Guatemala', ' Mexico',
         ' Dominican-Republic', ' Ireland', ' Germany', ' Philippines',
         ' Thailand', ' Haiti', ' El-Salvador', ' Puerto-Rico', ' Vietnam',
         ' South', ' Columbia', ' Japan', ' India', ' Cambodia', ' Poland',
         ' Laos', ' England', ' Cuba', ' Taiwan', ' Italy', ' Canada',
         ' Portugal', ' China', ' Nicaragua', ' Honduras', ' Iran',
         ' Scotland', ' Jamaica', ' Ecuador', ' Yugoslavia', ' Hungary',
         ' Hong', ' Greece', ' Trinadad&Tobago', ' Outlying-US(Guam-USVI-etc)',
         ' France'],
        ' Vietnam',
    ),
]
# The interface has a single output: a textbox that displays the percentage
# string returned by infer().
output = gr.Textbox(label='Probability of income larger than 50,000 USD per year:')
# Text shown around the interface, plus example rows that pre-fill the inputs.
# Each example row holds one value per infer() parameter, in the same order.
title = "Tab Transformer for Structured data"
description = (
    "Using Transformer to predict whether the income will be larger than "
    "50,000 USD given the input features."
)
article = (
    'Author: <a href="https://huggingface.co/geninhu">Nhu Hoang</a>. '
    'Based on this <a href="https://keras.io/examples/structured_data/tabtransformer/">keras example</a> '
    'by <a href="https://www.linkedin.com/in/khalid-salama-24403144">Khalid Salama.</a> '
    'HuggingFace Model <a href="https://huggingface.co/keras-io/tab_transformer">here</a> '
)
examples = [
    [39.0, ' Private', ' Assoc-voc', 11.0, ' Divorced', ' Tech-support',
     ' Not-in-family', ' White', ' Female', 0.0, 0.0, 40.0, ' United-States'],
    [60.0, ' Private', ' 12th', 8.0, ' Married-civ-spouse', ' Handlers-cleaners',
     ' Husband', ' White', ' Male', 0.0, 0.0, 40.0, ' United-States'],
    [42.0, ' Private', ' Masters', 14.0, ' Married-civ-spouse', ' Prof-specialty',
     ' Husband', ' Asian-Pac-Islander', ' Male', 0.0, 0.0, 40.0, ' Taiwan'],
    [31.0, ' Local-gov', ' Bachelors', 13.0, ' Married-civ-spouse', ' Craft-repair',
     ' Husband', ' White', ' Male', 0.0, 0.0, 51.0, ' United-States'],
    [30.0, ' Private', ' Masters', 14.0, ' Never-married', ' Prof-specialty',
     ' Not-in-family', ' Asian-Pac-Islander', ' Male', 0.0, 0.0, 45.0, ' Iran'],
]
# Assemble the Gradio interface from the pieces defined above and start it.
demo = gr.Interface(
    infer,
    inputs,
    output,
    examples=examples,
    allow_flagging='never',
    title=title,
    description=description,
    article=article,
    live=False,
)
# NOTE(review): newer Gradio releases replace launch(enable_queue=...) with
# demo.queue(); confirm against the gradio version this app is pinned to.
demo.launch(enable_queue=True, debug=True, inbrowser=True)