gilmar's picture
feat: excel file upload
024e8dd
import joblib
import gradio as gr
import pandas as pd
from models import HealthInsurance
def load_data():
    """Load the serialized model artifacts into module-level globals.

    Populates ``_model``, ``_column_transformer`` and
    ``_bins_annual_premium_type`` from files under the relative
    ``parameters/`` directory (resolved against the current working
    directory). ``predict`` reads these three globals, so this must run
    before any prediction is requested.

    NOTE(review): ``joblib.load`` unpickles its input and can execute
    arbitrary code -- only ever point it at trusted artifact files.
    """
    global _model, _column_transformer, _bins_annual_premium_type
    _model = joblib.load('parameters/random_forrest.gz')
    _column_transformer = joblib.load('parameters/column_transformer.joblib')
    _bins_annual_premium_type = joblib.load(
        'parameters/bins_annual_premium_type.joblib'
    )
def predict(df):
    """Score the rows of *df* and return them with the score column first.

    Builds a ``HealthInsurance`` pipeline from the globals populated by
    ``load_data`` and delegates to its ``predict`` method.

    Args:
        df: the input table's contents (a pandas DataFrame, since the input
            component is created with ``type='pandas'``).

    Returns:
        A DataFrame restricted to ``score`` followed by the ten feature
        columns, in the order shown in the output table.
    """
    output_columns = [
        'score',
        'previously_insured',
        'annual_premium',
        'vintage',
        'gender',
        'age',
        'region_code',
        'policy_sales_channel',
        'driving_license',
        'vehicle_age',
        'vehicle_damage',
    ]
    pipeline = HealthInsurance(
        _model, _column_transformer, _bins_annual_premium_type
    )
    return pipeline.predict(df)[output_columns]
def create_input_table():
    """Build the editable input grid: one initial row, ten fixed columns.

    Each column name is declared next to its Gradio datatype so the two
    cannot drift out of alignment, and the fixed column count is derived
    from the same schema.
    """
    schema = [
        ('previously_insured', 'number'),
        ('annual_premium', 'number'),
        ('vintage', 'number'),
        ('gender', 'str'),
        ('age', 'number'),
        ('region_code', 'number'),
        ('policy_sales_channel', 'number'),
        ('driving_license', 'number'),
        ('vehicle_age', 'str'),
        ('vehicle_damage', 'str'),
    ]
    column_names = [name for name, _ in schema]
    column_types = [kind for _, kind in schema]
    return gr.Dataframe(
        headers=column_names,
        datatype=column_types,
        row_count=1,
        col_count=(len(schema), 'fixed'),
        type='pandas',
        label='Input',
    )
def create_output_table():
    """Build the read-only results grid: ``score`` plus the ten input features."""
    feature_columns = [
        'previously_insured',
        'annual_premium',
        'vintage',
        'gender',
        'age',
        'region_code',
        'policy_sales_channel',
        'driving_license',
        'vehicle_age',
        'vehicle_damage',
    ]
    return gr.Dataframe(
        headers=['score'] + feature_columns,
        type='pandas',
        label='Output',
        wrap=True,
        interactive=False,
    )
def create_file_object():
    """Build the spreadsheet upload widget.

    NOTE(review): ``type='bytes'`` presumably hands the upload to event
    handlers as raw bytes (Gradio 3.x semantics; later Gradio versions
    renamed this option) -- confirm against the installed Gradio version.
    """
    upload_options = {'label': 'File upload', 'type': 'bytes'}
    return gr.File(**upload_options)
def convert_file_to_pandas(file):
    """Parse an uploaded Excel workbook into a DataFrame for the input table.

    Args:
        file: the upload component's value. Raw ``bytes``/``bytearray``
            content is wrapped in an in-memory buffer before parsing; any
            other value (a path string or file-like object) is handed to
            ``pandas.read_excel`` unchanged. ``None`` is tolerated.

    Returns:
        The first sheet as a ``pandas.DataFrame``, or ``None`` when there is
        no file to parse.

    Raises:
        ValueError: typically raised by pandas when the content is not a
            recognisable Excel file.
    """
    from io import BytesIO

    if file is None:
        # Presumably the upload's `.change` event also fires when the file
        # is cleared; there is nothing to parse, so leave the table empty.
        return None
    if isinstance(file, (bytes, bytearray)):
        # pandas >= 2.1 deprecates passing raw bytes to read_excel and asks
        # callers to wrap them in a BytesIO buffer instead.
        file = BytesIO(file)
    return pd.read_excel(file)
def build_interface():
    """Assemble the Gradio Blocks UI, wire its events, then launch the server.

    Layout, top to bottom: upload widget, input grid, output grid, Submit button.
    """
    with gr.Blocks() as interface:
        upload = create_file_object()
        input_table = create_input_table()
        output_table = create_output_table()
        submit_button = gr.Button("Submit")

        # Submit scores whatever rows are currently in the input grid.
        submit_button.click(fn=predict, inputs=input_table, outputs=output_table)
        # A new upload replaces the input grid with the workbook's rows.
        upload.change(fn=convert_file_to_pandas, inputs=upload, outputs=input_table)
    interface.launch()
if __name__ == "__main__":
    # The model artifacts must be in memory before the UI can serve predictions.
    load_data()
    build_interface()