File size: 1,625 Bytes
60ac467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
from simpletransformers.t5 import T5Model , T5Args


# Inference configuration for the simpletransformers T5 model.
# NOTE(review): several training-time options (epochs, learning rate, batch
# sizes, adafactor_beta1) are set even though this script only calls
# model.predict(); presumably leftovers from the training script — harmless
# but unused at inference time.
model_args = T5Args()
model_args.num_train_epochs = 3
#model_args.no_save = True
#model_args.evaluate_generated_text = True
#model_args.evaluate_during_training = True
#model_args.evaluate_during_training_verbose = True
model_args.overwrite_output_dir = True
model_args.fp16 = False  # full precision — fp16 generally requires CUDA, disabled below
model_args.use_cuda = False  # force CPU inference (e.g. a free hosting tier)
# Disable all multiprocessing paths; avoids worker issues in hosted/demo environments.
model_args.use_multiprocessing = False
model_args.use_multiprocessing_for_evaluation = False
model_args.use_multiprocessed_decoding = False
model_args.learning_rate=0.001
#model_args.num_beams = 3
model_args.train_batch_size = 4
model_args.eval_batch_size = 4
model_args.adafactor_beta1 = 0
model_args.length_penalty=1.5  # generation length penalty — presumably favors longer answers; verify against transformers docs
model_args.max_length=100      # cap on generated sequence length
model_args.max_seq_length = 100  # cap on tokenized input length


# Load the fine-tuned mT5 model from the Hugging Face Hub; downloads weights
# on first run and keeps everything on CPU.
model = T5Model("mt5", "hackathon-pln-es/itama", args=model_args , use_cuda=False)


def predict(input_text):
  """Run the mT5 model on a single question string and return the generated answer.

  The model's predict() takes a batch (list) of inputs and returns a list of
  generated texts; we wrap the single input and unwrap the single output.
  """
  generated = model.predict([input_text])
  return generated[0]


# Footer text shown below the interface.
# BUG FIX: the original code passed `article=article` but `article` was never
# defined anywhere in the file, so the script crashed with NameError before
# the app could launch. Define it here.
article = (
    "Modelo T5 Transformer (mt5-base) entrenado con preguntas y respuestas "
    "de AMA (Ask Me Anything) de Reddit."
)

# Build and launch the Gradio demo: one free-text question in, one generated
# answer out, with manual flagging logged to ./logs.
gr.Interface(
   fn=predict,
   inputs=gr.inputs.Textbox(lines=1, label="Pregunta por profesión - {profesión}: {pregunta}"),
   outputs=[
     gr.outputs.Textbox(label="Respuesta"),
     ],
   theme="peach",
   # Typo fix: "predicctivo" -> "predictivo" (user-visible title).
   title='Modelo predictivo AMA Reddit',
   description='Modelo T5 Transformer (mt5-base), utilizando dataset de preguntas y respuestas de AMA Reddit',
   examples=[
     'psicologo: cuanto trabajas al año?',
     'jefe: cuanto trabajas al año?',
     # NOTE(review): "cuando dinero" is likely a typo for "cuanto dinero", but
     # examples are fed verbatim to the model, so left unchanged to preserve behavior.
     'profesor: cuando dinero ganas al año?',
     ],
   article=article,
   allow_flagging="manual",
   #flagging_options=["right translation", "wrong translation", "error", "other"],
   flagging_dir="logs"
   ).launch(enable_queue=True)