import gradio as gr from simpletransformers.t5 import T5Model, T5Args model_args = T5Args() model_args.num_train_epochs = 3 #model_args.no_save = True #model_args.evaluate_generated_text = True #model_args.evaluate_during_training = True #model_args.evaluate_during_training_verbose = True model_args.overwrite_output_dir = True model_args.fp16 = False model_args.use_cuda = False model_args.use_multiprocessing = False model_args.use_multiprocessing_for_evaluation = False model_args.use_multiprocessed_decoding = False model_args.learning_rate = 0.001 #model_args.num_beams = 3 model_args.train_batch_size = 4 model_args.eval_batch_size = 4 model_args.adafactor_beta1 = 0 model_args.length_penalty = 1.5 model_args.max_length = 100 model_args.max_seq_length = 100 model = T5Model("mt5", "hackathon-pln-es/itama", args=model_args, use_cuda=False) article = ''' # ITAMA Reddit (y sus subreddits en español), proveen una gran cantidad de hilos en donde expertos se ofrecen a contestar voluntariamente preguntas y los usuarios realizan preguntas que en un contexto normal podrían ser considerado tabú. Esperamos poder generar un modelo con estas preguntas-respuestas que pueda consolidar este conocimiento y responder a preguntas frecuentes en topicos de interés común y/o bienestar personal. ''' def predict(profession, question): prediction = model.predict([f'{profession}: {question}'])[0] return prediction gr.Interface( fn=predict, #inputs=gr.inputs.Textbox(lines=1, label="Pregunta por profesión - {profesión}: {pregunta}"), inputs=[ #gr.inputs.Textbox(label="Pregunta por profesión - {profesión}: {pregunta}"), gr.inputs.Dropdown(choices=["medico", "psicologo", "ciencias", "ingeniero", "profesor", "jefe", "abogado"], type="value", default='medico', label="Profesión"), gr.inputs.Textbox(label="Pregunta"), ], outputs=[ gr.outputs.Textbox(label="Respuesta"), ], theme="peach", title='Modelo predicctivo AMA Reddit', description='Modelo T5 Transformer (mt5-base), utilizando dataset de preguntas y respuestas de AMA Reddit', examples=[ ['psicologo','cuanto trabajas al año?'], ['jefe','cuanto trabajas al año?'], ['profesor','cuando dinero ganas al año?'], ], article=article, # allow_flagging="manual", #flagging_options=["right translation", "wrong translation", "error", "other"], # flagging_dir="logs" ).launch(enable_queue=True)