"""Gradio demo that answers profession-specific questions with an mT5 model.

The checkpoint ``hackathon-pln-es/itama`` is loaded through simpletransformers'
``T5Model`` (with CUDA disabled) and exposed through a single-textbox Gradio
interface. Prompts follow the ``"{profesión}: {pregunta}"`` format shown in the
input label and the examples.
"""
import gradio as gr
from simpletransformers.t5 import T5Model, T5Args

# Model configuration. Only inference is performed in this script; the
# training-related options (epochs, learning rate, batch sizes, Adafactor beta)
# are kept as in the original setup but are not exercised by ``predict``.
model_args = T5Args()
model_args.num_train_epochs = 3
model_args.overwrite_output_dir = True
model_args.fp16 = False
model_args.use_cuda = False
# Multiprocessing is disabled for training, evaluation and decoding.
# NOTE(review): presumably to keep the demo stable on a single-process host — confirm.
model_args.use_multiprocessing = False
model_args.use_multiprocessing_for_evaluation = False
model_args.use_multiprocessed_decoding = False
model_args.learning_rate = 0.001
model_args.train_batch_size = 4
model_args.eval_batch_size = 4
model_args.adafactor_beta1 = 0
model_args.length_penalty = 1.5
model_args.max_length = 100
model_args.max_seq_length = 100

model = T5Model("mt5", "hackathon-pln-es/itama", args=model_args, use_cuda=False)

# Optional long-form text rendered below the interface. The original script
# passed an undefined name ``article`` to gr.Interface, which raised NameError
# at startup; None is Gradio's default (no article).
article = None


def predict(input_text):
    """Return the model's answer for a single prompt.

    Args:
        input_text: Prompt text, expected in the ``"{profesión}: {pregunta}"``
            form requested by the input label.

    Returns:
        The first element of ``model.predict`` for a one-item batch, i.e. the
        prediction for ``input_text``.
    """
    return model.predict([input_text])[0]


gr.Interface(
    fn=predict,
    inputs=gr.inputs.Textbox(lines=1, label="Pregunta por profesión - {profesión}: {pregunta}"),
    outputs=[
        gr.outputs.Textbox(label="Respuesta"),
    ],
    theme="peach",
    title='Modelo predictivo AMA Reddit',
    description='Modelo T5 Transformer (mt5-base), utilizando dataset de preguntas y respuestas de AMA Reddit',
    examples=[
        'psicologo: cuanto trabajas al año?',
        'jefe: cuanto trabajas al año?',
        'profesor: cuanto dinero ganas al año?',
    ],
    article=article,
    allow_flagging="manual",
    flagging_dir="logs",
).launch(enable_queue=True)