# File size: 6,452 Bytes
# db827e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import math
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt

import gradio as gr
from huggingface_hub import from_pretrained_keras

# Load the pretrained TabTransformer Keras model from the Hugging Face Hub repo
# 'keras-io/tab_transformer' once, at import time; infer() reuses this global
# for every prediction instead of reloading it per request.
# NOTE(review): presumably this downloads from the Hub on first run (or reads the
# local HF cache), so startup needs network access — confirm for the deployment.
model = from_pretrained_keras('keras-io/tab_transformer')

# Full column layout of the census-income CSV, in file order. The features that
# infer() sends to the model are these columns minus 'fnlwgt' and
# 'income_bracket'. Nothing else in this script reads this constant; it is kept
# for reference.
CSV_HEADER = (
    "age workclass fnlwgt education education_num marital_status occupation "
    "relationship race gender capital_gain capital_loss hours_per_week "
    "native_country income_bracket"
).split()

def get_dataset_from_pandas(data):
    """Turn a DataFrame of raw features into a ``tf.data.Dataset`` for ``model.predict``.

    The caller's DataFrame is not modified; a trimmed frame is prepared instead:

    * ``income_bracket`` and ``fnlwgt`` are dropped when present, so they are
      never passed to the model.
    * Every integer or floating-point column is cast to ``float32``, and so is
      ``age`` unconditionally. Integer-valued inputs (for example from step-1
      sliders) therefore get the same float32 treatment as float ones.

    Args:
        data: ``pandas.DataFrame`` with one column per feature and one row per
            record.

    Returns:
        A dataset holding a single element: a dict mapping each remaining column
        name to that column's values. All rows are therefore predicted together
        as one batch.
    """
    # drop() returns a new frame, so the reassignments below never rebind
    # columns of the caller's `data` (the original version mutated it).
    features = data.drop(columns=['income_bracket', 'fnlwgt'], errors='ignore')
    for col in features.columns:
        column = features[col]
        is_number = (pd.api.types.is_float_dtype(column)
                     or pd.api.types.is_integer_dtype(column))
        # 'age' is always converted, matching the original special case.
        if col == 'age' or is_number:
            features[col] = column.astype('float32')
    return tf.data.Dataset.from_tensors(dict(features))


def infer(age, workclass, education, education_num, marital_status, occupation, relationship, race, gender, capital_gain, capital_loss, hours_per_week, native_country):
    """Score one person with the loaded model and return the result as a percentage.

    The thirteen arguments arrive positionally from the Gradio input components
    (``gr.Interface(infer, inputs, ...)``). They are packed into a one-row
    DataFrame and converted with ``get_dataset_from_pandas``.

    Returns:
        The model's first output value times 100, rounded to two decimals and
        suffixed with ``%``. The UI labels it as the probability of income above
        50,000 USD per year.
    """
    # Keyword order below fixes the DataFrame's column order.
    record = dict(
        age=age,
        workclass=workclass,
        education=education,
        education_num=education_num,
        marital_status=marital_status,
        occupation=occupation,
        relationship=relationship,
        race=race,
        gender=gender,
        capital_gain=capital_gain,
        capital_loss=capital_loss,
        hours_per_week=hours_per_week,
        native_country=native_country,
    )
    frame = pd.DataFrame(record, index=[0])

    scores = model.predict(get_dataset_from_pandas(frame))
    percentage = round(scores.flatten()[0] * 100, 2)
    return f"{percentage}%"

# Gradio input components: one per infer() parameter, in the same order, because
# gr.Interface hands their values to infer() positionally.
# NOTE(review): the categorical choices carry a leading space (' Private'),
# presumably to match the raw category strings the model was trained on — keep them.
inputs = [
    gr.Slider(label='age', minimum=16, maximum=120, step=1, value=30),
    gr.Radio(
        label='workclass',
        type='value',
        value=' Private',
        choices=[' Private', ' Local-gov', ' ?', ' Self-emp-not-inc', ' Federal-gov',
                 ' State-gov', ' Self-emp-inc', ' Without-pay', ' Never-worked'],
    ),
    gr.Radio(
        label='education',
        type='value',
        value=' Bachelors',
        choices=[' 11th', ' HS-grad', ' Assoc-acdm', ' Some-college', ' 10th', ' Prof-school',
                 ' 7th-8th', ' Bachelors', ' Masters', ' Doctorate', ' 5th-6th', ' Assoc-voc',
                 ' 9th', ' 12th', ' 1st-4th', ' Preschool'],
    ),
    gr.Slider(label='education_num', minimum=1, maximum=16, step=1, value=10),
    gr.Radio(
        label='marital_status',
        type='value',
        value=' Married-civ-spouse',
        choices=[' Never-married', ' Married-civ-spouse', ' Widowed', ' Divorced', ' Separated',
                 ' Married-spouse-absent', ' Married-AF-spouse'],
    ),
    gr.Radio(
        label='occupation',
        type='value',
        value=' Tech-support',
        choices=[' Machine-op-inspct', ' Farming-fishing', ' Protective-serv', ' ?',
                 ' Other-service', ' Prof-specialty', ' Craft-repair', ' Adm-clerical',
                 ' Exec-managerial', ' Tech-support', ' Sales', ' Priv-house-serv',
                 ' Transport-moving', ' Handlers-cleaners', ' Armed-Forces'],
    ),
    gr.Radio(
        label='relationship',
        type='value',
        value=' Wife',
        choices=[' Own-child', ' Husband', ' Not-in-family', ' Unmarried', ' Wife',
                 ' Other-relative'],
    ),
    gr.Radio(
        label='race',
        type='value',
        value=' Other',
        choices=[' Black', ' White', ' Asian-Pac-Islander', ' Other', ' Amer-Indian-Eskimo'],
    ),
    gr.Radio(label='gender', type='value', value=' Female', choices=[' Male', ' Female']),
    gr.Slider(label='capital_gain', minimum=0, maximum=500000, step=1, value=80000),
    gr.Slider(label='capital_loss', minimum=0, maximum=50000, step=1, value=1000),
    gr.Slider(label='hours_per_week', minimum=1, maximum=168, step=1, value=40),
    gr.Radio(
        label='native_country',
        type='value',
        value=' Vietnam',
        choices=[' United-States', ' ?', ' Peru', ' Guatemala', ' Mexico', ' Dominican-Republic',
                 ' Ireland', ' Germany', ' Philippines', ' Thailand', ' Haiti', ' El-Salvador',
                 ' Puerto-Rico', ' Vietnam', ' South', ' Columbia', ' Japan', ' India',
                 ' Cambodia', ' Poland', ' Laos', ' England', ' Cuba', ' Taiwan', ' Italy',
                 ' Canada', ' Portugal', ' China', ' Nicaragua', ' Honduras', ' Iran',
                 ' Scotland', ' Jamaica', ' Ecuador', ' Yugoslavia', ' Hungary', ' Hong',
                 ' Greece', ' Trinadad&Tobago', ' Outlying-US(Guam-USVI-etc)', ' France'],
    ),
]

# The app has a single text output: infer() returns the prediction as a
# percentage string such as "12.34%".
output = gr.Textbox(label='Probability of income larger than 50,000 USD per year:')

# Title, description, article and examples are shown on the app page to guide users.
title = 'Tab Transformer for Structured data'
description = 'Using Transformer to predict whether the income will be larger than 50,000 USD given the input features.'

article = "Author: <a href=\"https://huggingface.co/geninhu\">Nhu Hoang</a>. Based on this <a href=\"https://keras.io/examples/structured_data/tabtransformer/\">keras example</a> by <a href=\"https://www.linkedin.com/in/khalid-salama-24403144\">Khalid Salama.</a> HuggingFace Model <a href=\"https://huggingface.co/keras-io/tab_transformer\">here</a> "

# Each example row supplies one value per input component, in the order of `inputs`.
examples = [[39.0, ' Private', ' Assoc-voc', 11.0, ' Divorced', ' Tech-support', ' Not-in-family', ' White', ' Female', 0.0, 0.0, 40.0, ' United-States'],
            [60.0, ' Private', ' 12th', 8.0, ' Married-civ-spouse', ' Handlers-cleaners', ' Husband', ' White', ' Male', 0.0, 0.0, 40.0, ' United-States'],
            [42.0, ' Private', ' Masters', 14.0, ' Married-civ-spouse', ' Prof-specialty', ' Husband', ' Asian-Pac-Islander', ' Male', 0.0, 0.0, 40.0, ' Taiwan'],
            [31.0, ' Local-gov', ' Bachelors', 13.0, ' Married-civ-spouse', ' Craft-repair', ' Husband', ' White', ' Male', 0.0, 0.0, 51.0, ' United-States'],
            [30.0, ' Private', ' Masters', 14.0, ' Never-married', ' Prof-specialty', ' Not-in-family', ' Asian-Pac-Islander', ' Male', 0.0, 0.0, 45.0, ' Iran']]

# Keep a module-level handle on the interface so other tooling can find it
# (Gradio's reload mode, `gradio app.py`, looks for a variable named `demo`).
# The app is still launched immediately on import, as before.
# NOTE(review): `enable_queue` (launch) and `allow_flagging` are deprecated in
# newer Gradio releases in favour of `.queue()` and `flagging_mode`; migrate them
# together with the pinned Gradio version.
demo = gr.Interface(infer, inputs, output, examples=examples, allow_flagging='never',
                    title=title, description=description, article=article, live=False)
demo.launch(enable_queue=True, debug=True, inbrowser=True)