Luster commited on
Commit
60ac467
1 Parent(s): 2d56377

codigo base app pregunta y respuesta

Browse files
Files changed (2) hide show
  1. app.py +53 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from simpletransformers.t5 import T5Model , T5Args
3
+
4
+
5
+ model_args = T5Args()
6
+ model_args.num_train_epochs = 3
7
+ #model_args.no_save = True
8
+ #model_args.evaluate_generated_text = True
9
+ #model_args.evaluate_during_training = True
10
+ #model_args.evaluate_during_training_verbose = True
11
+ model_args.overwrite_output_dir = True
12
+ model_args.fp16 = False
13
+ model_args.use_cuda = False
14
+ model_args.use_multiprocessing = False
15
+ model_args.use_multiprocessing_for_evaluation = False
16
+ model_args.use_multiprocessed_decoding = False
17
+ model_args.learning_rate=0.001
18
+ #model_args.num_beams = 3
19
+ model_args.train_batch_size = 4
20
+ model_args.eval_batch_size = 4
21
+ model_args.adafactor_beta1 = 0
22
+ model_args.length_penalty=1.5
23
+ model_args.max_length=100
24
+ model_args.max_seq_length = 100
25
+
26
+
27
+ model = T5Model("mt5", "hackathon-pln-es/itama", args=model_args , use_cuda=False)
28
+
29
+
30
+ def predict(input_text):
31
+ p = model.predict([input_text])[0]
32
+ return p
33
+
34
+
35
+ gr.Interface(
36
+ fn=predict,
37
+ inputs=gr.inputs.Textbox(lines=1, label="Pregunta por profesión - {profesión}: {pregunta}"),
38
+ outputs=[
39
+ gr.outputs.Textbox(label="Respuesta"),
40
+ ],
41
+ theme="peach",
42
+ title='Modelo predicctivo AMA Reddit',
43
+ description='Modelo T5 Transformer (mt5-base), utilizando dataset de preguntas y respuestas de AMA Reddit',
44
+ examples=[
45
+ 'psicologo: cuanto trabajas al año?',
46
+ 'jefe: cuanto trabajas al año?',
47
+ 'profesor: cuando dinero ganas al año?',
48
+ ],
49
+ article=article,
50
+ allow_flagging="manual",
51
+ #flagging_options=["right translation", "wrong translation", "error", "other"],
52
+ flagging_dir="logs"
53
+ ).launch(enable_queue=True)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ simpletransformers==0.63.6
3
+ torch