import gradio as gr from catboost import CatBoostClassifier import pickle import random municipios = pickle.load(open('municipios.pkl', 'rb')) unique_values = pickle.load(open('unique_values.pkl', 'rb')) unique_sexo = unique_values["sexo"].tolist() unique_sexo.remove('No especificado') unique_entidad = unique_values['ent_ocurr'].tolist() unique_entidad.remove('Entidad no especificada') unique_entidad.sort() unique_escolaridad = unique_values['escolarida'].tolist() unique_escolaridad.remove('No especificado') unique_escolaridad.sort() unique_estadocivil = unique_values['edo_civil'].tolist() unique_estadocivil.sort() unique_sitio = unique_values['sitio_ocur'].tolist() unique_sitio.sort() unique_derechohabiente = unique_values['derechohab'].tolist() unique_derechohabiente.remove('Otra / No especificada') unique_derechohabiente.sort() unique_actividad = unique_values['cond_act'].tolist() unique_actividad.sort(reverse=True) unique_mes = ['Enero', 'Febrero', 'Marzo', 'Abril', 'Mayo', 'Junio', 'Julio', 'Agosto', 'Septiembre', 'Octubre', 'Noviembre','Diciembre'] model = CatBoostClassifier() model.load_model("model") def predict(*args): preds_proba = model.predict_proba([args]) return {"Covid-19": float(preds_proba[0][0]), "Otra causa de muerte": 1 - float(preds_proba[0][0])} with gr.Blocks() as demo: with gr.Row(): with gr.Column(visible=True) as details_col: entidad = gr.Dropdown(label="Entidad de deceso", choices=unique_entidad, value=lambda: random.choice(unique_entidad)) municipio = gr.Dropdown(label="Municipio de deceso", choices=[], interactive=True) sexo= gr.Radio(label='Género', choices=unique_sexo, value=lambda: random.choice(unique_sexo)) edad = gr.Slider(1, 120, label="Edad", randomize=True, interactive=True) escolaridad = gr.Dropdown(label="Escolaridad", choices=unique_escolaridad, value=lambda: random.choice(unique_escolaridad)) estadocivil = gr.Dropdown(label="Estado civil", choices=unique_estadocivil, value=lambda: random.choice(unique_estadocivil)) sitio = gr.Dropdown(label="Sitio donde ocurrió el deceso", choices=unique_sitio, value=lambda: random.choice(unique_sitio)) derechohabiente = gr.Dropdown(label="¿Era derechohabiente?", choices=unique_derechohabiente, value=lambda: random.choice(unique_derechohabiente)) actividad = gr.Dropdown(label="¿Era económicamente activo?", choices=unique_actividad, value=lambda: random.choice(unique_actividad)) mes = gr.Dropdown(label="Mes del deceso", choices=unique_mes, value=lambda: random.choice(unique_mes)) dia = gr.Slider(1, 31, step=1, label="Día del deceso", randomize=True, interactive=True) gr.Examples([['Ciudad de México','Iztapalapa','Mujer',68,22,'Junio','Primaria incompleta','Casado (a)','IMSS','No','Sí'], ['Sinaloa','Ahome','Hombre',77,27,'Agosto','Primaria incompleta','Casado (a)','IMSS','Si','No']], inputs=[entidad,municipio,sexo,edad,dia,mes,escolaridad,estadocivil,sitio,derechohabiente,actividad]) with gr.Column(): label = gr.Label() predecir_btn = gr.Button("Predicción") predecir_btn.click( predict, inputs=[ entidad, municipio, sexo, edad, dia, mes, escolaridad, estadocivil, sitio, derechohabiente, actividad, ], outputs=[label], ) def filter_entidad(entidad): return gr.Dropdown.update( choices=municipios[entidad], value=municipios[entidad][0] ), gr.update(visible=True) entidad.change(filter_entidad, entidad, [municipio, details_col]) if __name__ == "__main__": demo.launch()