acumplid commited on
Commit
1823861
1 Parent(s): c990537

Included params on ui

Browse files
Files changed (1) hide show
  1. app.py +89 -8
app.py CHANGED
@@ -1,12 +1,15 @@
1
  import os
2
  import gradio as gr
3
- from gradio.components import Textbox, Button
4
  from AinaTheme import theme
5
  from urllib.error import HTTPError
6
 
7
  from rag import RAG
8
  from utils import setup
9
 
 
 
 
10
  setup()
11
 
12
 
@@ -19,9 +22,9 @@ rag = RAG(
19
  )
20
 
21
 
22
- def generate(prompt):
23
  try:
24
- output = rag.get_response(prompt)
25
  return output
26
  except HTTPError as err:
27
  if err.code == 400:
@@ -34,12 +37,23 @@ def generate(prompt):
34
  )
35
 
36
 
37
- def submit_input(input_):
38
  if input_.strip() == "":
39
  gr.Warning("Not possible to inference an empty input")
40
  return None
41
 
42
- output = generate(input_)
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  return output
45
 
@@ -52,8 +66,15 @@ def change_interactive(text):
52
 
53
  def clear():
54
  return (
 
55
  None,
56
- None,
 
 
 
 
 
 
57
  )
58
 
59
 
@@ -86,6 +107,55 @@ def gradio_app():
86
  # value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
87
  )
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  with gr.Column(variant="panel"):
90
  output = Textbox(
91
  lines=11, label="Output", interactive=False, show_copy_button=True
@@ -96,6 +166,9 @@ def gradio_app():
96
  )
97
  submit_btn = Button("Submit", variant="primary", interactive=False)
98
 
 
 
 
99
  input_.change(
100
  fn=change_interactive,
101
  inputs=[input_],
@@ -114,10 +187,18 @@ def gradio_app():
114
  )
115
 
116
  clear_btn.click(
117
- fn=clear, inputs=[], outputs=[input_, output], queue=False, api_name=False
 
 
 
 
118
  )
 
119
  submit_btn.click(
120
- fn=submit_input, inputs=[input_], outputs=[output], api_name="get-results"
 
 
 
121
  )
122
 
123
  with gr.Row():
 
1
  import os
2
  import gradio as gr
3
+ from gradio.components import Textbox, Button, Slider, Checkbox
4
  from AinaTheme import theme
5
  from urllib.error import HTTPError
6
 
7
  from rag import RAG
8
  from utils import setup
9
 
10
+ MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", default=100))
11
+ SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
12
+
13
  setup()
14
 
15
 
 
22
  )
23
 
24
 
25
+ def generate(prompt, model_parameters):
26
  try:
27
+ output = rag.get_response(prompt, model_parameters)
28
  return output
29
  except HTTPError as err:
30
  if err.code == 400:
 
37
  )
38
 
39
 
40
+ def submit_input(input_, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature):
41
  if input_.strip() == "":
42
  gr.Warning("Not possible to inference an empty input")
43
  return None
44
 
45
+
46
+ model_parameters = {
47
+ "max_new_tokens": max_new_tokens,
48
+ "repetition_penalty": repetition_penalty,
49
+ "top_k": top_k,
50
+ "top_p": top_p,
51
+ "do_sample": do_sample,
52
+ "num_beams": num_beams,
53
+ "temperature": temperature
54
+ }
55
+
56
+ output = generate(input_, model_parameters)
57
 
58
  return output
59
 
 
66
 
67
  def clear():
68
  return (
69
+ None,
70
  None,
71
+ gr.Slider(value=100),
72
+ gr.Slider(value=1.2),
73
+ gr.Slider(value=50),
74
+ gr.Slider(value=0.95),
75
+ gr.Checkbox(value=True),
76
+ gr.Slider(value=4),
77
+ gr.Slider(value=0.5),
78
  )
79
 
80
 
 
107
  # value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
108
  )
109
 
110
+ with gr.Row(variant="panel"):
111
+ with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
112
+ max_new_tokens = Slider(
113
+ minimum=1,
114
+ maximum=200,
115
+ step=1,
116
+ value=MAX_NEW_TOKENS,
117
+ label="Max tokens"
118
+ )
119
+ repetition_penalty = Slider(
120
+ minimum=0.1,
121
+ maximum=10,
122
+ step=0.1,
123
+ value=1.2,
124
+ label="Repetition penalty"
125
+ )
126
+ top_k = Slider(
127
+ minimum=1,
128
+ maximum=100,
129
+ step=1,
130
+ value=50,
131
+ label="Top k"
132
+ )
133
+ top_p = Slider(
134
+ minimum=0.01,
135
+ maximum=0.99,
136
+ value=0.95,
137
+ label="Top p"
138
+ )
139
+ do_sample = Checkbox(
140
+ value=True,
141
+ label="Do sample"
142
+ )
143
+ num_beams = Slider(
144
+ minimum=1,
145
+ maximum=8,
146
+ step=1,
147
+ value=4,
148
+ label="Beams"
149
+ )
150
+ temperature = Slider(
151
+ minimum=0,
152
+ maximum=1,
153
+ value=0.5,
154
+ label="Temperature"
155
+ )
156
+
157
+ parameters_compontents = [max_new_tokens, repetition_penalty, top_k, top_p, do_sample, num_beams, temperature]
158
+
159
  with gr.Column(variant="panel"):
160
  output = Textbox(
161
  lines=11, label="Output", interactive=False, show_copy_button=True
 
166
  )
167
  submit_btn = Button("Submit", variant="primary", interactive=False)
168
 
169
+
170
+
171
+
172
  input_.change(
173
  fn=change_interactive,
174
  inputs=[input_],
 
187
  )
188
 
189
  clear_btn.click(
190
+ fn=clear,
191
+ inputs=[],
192
+ outputs=[input_, output] + parameters_compontents,
193
+ queue=False,
194
+ api_name=False
195
  )
196
+
197
  submit_btn.click(
198
+ fn=submit_input,
199
+ inputs=[input_]+ parameters_compontents,
200
+ outputs=[output],
201
+ api_name="get-results"
202
  )
203
 
204
  with gr.Row():