Luster commited on
Commit
1860488
1 Parent(s): 8171c1b

cambio de la pregunta con dropdow

Browse files
Files changed (1) hide show
  1. app.py +37 -26
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from simpletransformers.t5 import T5Model , T5Args
3
 
4
 
5
  model_args = T5Args()
@@ -14,17 +14,18 @@ model_args.use_cuda = False
14
  model_args.use_multiprocessing = False
15
  model_args.use_multiprocessing_for_evaluation = False
16
  model_args.use_multiprocessed_decoding = False
17
- model_args.learning_rate=0.001
18
  #model_args.num_beams = 3
19
  model_args.train_batch_size = 4
20
  model_args.eval_batch_size = 4
21
  model_args.adafactor_beta1 = 0
22
- model_args.length_penalty=1.5
23
- model_args.max_length=100
24
  model_args.max_seq_length = 100
25
 
26
 
27
- model = T5Model("mt5", "hackathon-pln-es/itama", args=model_args , use_cuda=False)
 
28
 
29
  article = '''
30
  # ITAMA
@@ -34,27 +35,37 @@ ser considerado tabú. Esperamos poder generar un modelo con estas preguntas-res
34
  este conocimiento y responder a preguntas frecuentes en topicos de interés común y/o bienestar personal. '''
35
 
36
 
37
- def predict(input_text):
38
- p = model.predict([input_text])[0]
39
- return p
 
40
 
41
 
42
  gr.Interface(
43
- fn=predict,
44
- inputs=gr.inputs.Textbox(lines=1, label="Pregunta por profesión - {profesión}: {pregunta}"),
45
- outputs=[
46
- gr.outputs.Textbox(label="Respuesta"),
47
- ],
48
- theme="peach",
49
- title='Modelo predicctivo AMA Reddit',
50
- description='Modelo T5 Transformer (mt5-base), utilizando dataset de preguntas y respuestas de AMA Reddit',
51
- examples=[
52
- 'psicologo: cuanto trabajas al año?',
53
- 'jefe: cuanto trabajas al año?',
54
- 'profesor: cuando dinero ganas al año?',
55
- ],
56
- article=article,
57
- allow_flagging="manual",
58
- #flagging_options=["right translation", "wrong translation", "error", "other"],
59
- flagging_dir="logs"
60
- ).launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from simpletransformers.t5 import T5Model, T5Args
3
 
4
 
5
  model_args = T5Args()
 
14
  model_args.use_multiprocessing = False
15
  model_args.use_multiprocessing_for_evaluation = False
16
  model_args.use_multiprocessed_decoding = False
17
+ model_args.learning_rate = 0.001
18
  #model_args.num_beams = 3
19
  model_args.train_batch_size = 4
20
  model_args.eval_batch_size = 4
21
  model_args.adafactor_beta1 = 0
22
+ model_args.length_penalty = 1.5
23
+ model_args.max_length = 100
24
  model_args.max_seq_length = 100
25
 
26
 
27
+ model = T5Model("mt5", "hackathon-pln-es/itama",
28
+ args=model_args, use_cuda=False)
29
 
30
  article = '''
31
  # ITAMA
 
35
  este conocimiento y responder a preguntas frecuentes en topicos de interés común y/o bienestar personal. '''
36
 
37
 
38
+ def predict(profession, question):
39
+
40
+ prediction = model.predict([f'{profession}: {question}'])[0]
41
+ return prediction
42
 
43
 
44
  gr.Interface(
45
+ fn=predict,
46
+ #inputs=gr.inputs.Textbox(lines=1, label="Pregunta por profesión - {profesión}: {pregunta}"),
47
+ inputs=[
48
+ #gr.inputs.Textbox(label="Pregunta por profesión - {profesión}: {pregunta}"),
49
+ gr.inputs.Dropdown(choices=["medico", "psicologo", "ciencias", "ingeniero",
50
+ "profesor", "jefe", "abogado"], type="value", default='medico', label="Profesión"),
51
+ gr.inputs.Textbox(label="Pregunta"),
52
+
53
+ ],
54
+
55
+
56
+ outputs=[
57
+ gr.outputs.Textbox(label="Respuesta"),
58
+ ],
59
+ theme="peach",
60
+ title='Modelo predicctivo AMA Reddit',
61
+ description='Modelo T5 Transformer (mt5-base), utilizando dataset de preguntas y respuestas de AMA Reddit',
62
+ examples=[
63
+ ['psicologo','cuanto trabajas al año?'],
64
+ ['jefe','cuanto trabajas al año?'],
65
+ ['profesor','cuando dinero ganas al año?'],
66
+ ],
67
+ article=article,
68
+ # allow_flagging="manual",
69
+ #flagging_options=["right translation", "wrong translation", "error", "other"],
70
+ # flagging_dir="logs"
71
+ ).launch(enable_queue=True)