versae commited on
Commit
38c5edb
1 Parent(s): 08a4871

Add log for preiously generated texts

Browse files
Files changed (1) hide show
  1. gradio_app.py +18 -11
gradio_app.py CHANGED
@@ -159,24 +159,31 @@ class TextGeneration:
159
  # return generated
160
 
161
 
162
- def generate(self, text, generation_kwargs):
163
- max_length = len(self.tokenizer(text)["input_ids"]) + generation_kwargs["max_length"]
 
164
  generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
165
  generated_text = None
166
- if text:
167
  for _ in range(10):
168
  generated_text = self.generator(
169
- text,
170
  **generation_kwargs,
171
  )[0]["generated_text"]
172
  if generation_kwargs["do_clean"]:
173
  generated_text = cleaner.clean_txt(generated_text)
174
- if generated_text.strip().startswith(text):
175
- generated_text = generated_text.replace(text, "", 1).strip()
176
  if generated_text:
 
 
 
 
 
 
177
  return (
178
- text + " " + generated_text,
179
- [(text, None), (generated_text, "BERTIN")]
180
  )
181
  if not generated_text:
182
  return (
@@ -221,7 +228,7 @@ def expand_with_gpt(hidden, text, max_length, top_k, top_p, temperature, do_samp
221
  "do_sample": do_sample,
222
  "do_clean": do_clean,
223
  }
224
- return generator.generate(hidden or text, generation_kwargs)
225
 
226
  def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k, top_p, temperature, do_sample, do_clean):
227
  # agent = AGENT
@@ -339,7 +346,7 @@ with gr.Blocks() as demo:
339
  hidden = gr.Textbox(visible=False, show_label=False)
340
  with gr.Box():
341
  # output = gr.Markdown()
342
- output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={"BERTIN": "green", "ERROR": "red"})
343
  with gr.Row():
344
  generate_btn = gr.Button("Generar")
345
  generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output])
@@ -358,7 +365,7 @@ with gr.Blocks() as demo:
358
  with gr.Row():
359
  agent = gr.Textbox(label="Agente", value=AGENT)
360
  user = gr.Textbox(label="Usuario", value=USER)
361
- history = gr.Variable(default_value=[])
362
  chatbot = gr.Chatbot(color_map=("green", "gray"))
363
  with gr.Row():
364
  message = gr.Textbox(placeholder="Escriba aquí su mensaje y pulse 'Enviar'", show_label=False)
 
159
  # return generated
160
 
161
 
162
+ def generate(self, text, generation_kwargs, previous_text=None):
163
+ input_text = previous_text or text
164
+ max_length = len(self.tokenizer(input_text)["input_ids"]) + generation_kwargs["max_length"]
165
  generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
166
  generated_text = None
167
+ if input_text:
168
  for _ in range(10):
169
  generated_text = self.generator(
170
+ input_text,
171
  **generation_kwargs,
172
  )[0]["generated_text"]
173
  if generation_kwargs["do_clean"]:
174
  generated_text = cleaner.clean_txt(generated_text)
175
+ if generated_text.strip().startswith(input_text):
176
+ generated_text = generated_text.replace(input_text, "", 1).strip()
177
  if generated_text:
178
+ if previous_text and previous_text != text:
179
+ diff = [
180
+ (text, None), (previous_text.replace(text, " ", 1).strip(), " "), (generated_text, AGENT)
181
+ ]
182
+ else:
183
+ diff = [(text, None), (generated_text, AGENT)]
184
  return (
185
+ input_text + " " + generated_text,
186
+ diff
187
  )
188
  if not generated_text:
189
  return (
 
228
  "do_sample": do_sample,
229
  "do_clean": do_clean,
230
  }
231
+ return generator.generate(text, generation_kwargs, previous_text=hidden)
232
 
233
  def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k, top_p, temperature, do_sample, do_clean):
234
  # agent = AGENT
 
346
  hidden = gr.Textbox(visible=False, show_label=False)
347
  with gr.Box():
348
  # output = gr.Markdown()
349
+ output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={AGENT: "green", "ERROR": "red", " ": "blue"})
350
  with gr.Row():
351
  generate_btn = gr.Button("Generar")
352
  generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output])
 
365
  with gr.Row():
366
  agent = gr.Textbox(label="Agente", value=AGENT)
367
  user = gr.Textbox(label="Usuario", value=USER)
368
+ history = gr.Variable(value=[])
369
  chatbot = gr.Chatbot(color_map=("green", "gray"))
370
  with gr.Row():
371
  message = gr.Textbox(placeholder="Escriba aquí su mensaje y pulse 'Enviar'", show_label=False)