retarfi committed on
Commit 160c75c
1 Parent(s): b541a54

add repetition penalty
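For context, repetition_penalty in transformers' generation follows the CTRL-style rule: at each step, the score of every token that has already appeared is rescaled, so values above 1.0 make repeats less likely and 1.0 is a no-op. The toy function below only illustrates that arithmetic with made-up token ids and scores; it is not the logits processor transformers actually ships.

def apply_repetition_penalty(
    scores: dict[int, float], seen_tokens: set[int], penalty: float
) -> dict[int, float]:
    # penalty > 1.0 pushes already-generated tokens down; penalty == 1.0 changes nothing.
    out = dict(scores)
    for tok in seen_tokens:
        if tok in out:
            s = out[tok]
            # CTRL-style rescaling: divide positive scores, multiply negative ones.
            out[tok] = s * penalty if s < 0 else s / penalty
    return out

# Toy example: tokens 7 and 9 were already generated, token 11 was not.
print(apply_repetition_penalty({7: 2.0, 9: -1.0, 11: 0.5}, seen_tokens={7, 9}, penalty=1.5))
# -> {7: 1.33..., 9: -1.5, 11: 0.5}: both seen tokens become less likely, the unseen one is untouched.

The diff below replaces the previously hard-coded penalty of 1.5 with an evaluate() parameter (default 1.0) driven by a new slider that defaults to 1.2.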

Files changed (1)
  1. app.py +18 -5
app.py CHANGED
@@ -160,6 +160,7 @@ def evaluate(
     input=None,
     temperature=0.7,
     max_tokens=384,
+    repetition_penalty=1.0,
 ):
     num_beams: int = 1
     top_p: float = 1.0
@@ -186,13 +187,17 @@ def evaluate(
         )
     except Exception as e:
         print(e)
-        return f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} tokens are used.", gr.update(interactive=True), gr.update(interactive=True)
+        return (
+            f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} tokens are used.",
+            gr.update(interactive=True),
+            gr.update(interactive=True),
+        )
     input_ids = inputs["input_ids"].to(device)
     generation_config = GenerationConfig(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
-        repetition_penalty=1.5,
+        repetition_penalty=repetition_penalty,
         num_beams=num_beams,
         pad_token_id=tokenizer.pad_token_id,
         eos_token=tokenizer.eos_token_id,
@@ -203,7 +208,7 @@ def evaluate(
         generation_config=generation_config,
         return_dict_in_generate=True,
         output_scores=True,
-        max_new_tokens=max_tokens-len(input_ids),
+        max_new_tokens=max_tokens - len(input_ids),
     )
     s = generation_output.sequences[0]
     output = tokenizer.decode(s, skip_special_tokens=True)
@@ -292,6 +297,14 @@ with gr.Blocks(
                interactive=True,
                label="Max length (Pre-prompt + instruction + input + output))",
            )
+            repetition_penalty = gr.Slider(
+                minimum=1.0,
+                maximum=5.0,
+                value=1.2,
+                step=0.05,
+                interactive=True,
+                label="Repetition penalty",
+            )
 
     with gr.Column(elem_id="user_consent_container") as user_consent_block:
         # Get user consent
@@ -334,14 +347,14 @@ with gr.Blocks(
     inputs.submit(no_interactive, [], [submit_button, clear_button])
     inputs.submit(
         evaluate,
-        [instruction, inputs, temperature, max_tokens],
+        [instruction, inputs, temperature, max_tokens, repetition_penalty],
         [outputs, submit_button, clear_button],
     )
     submit_button.click(no_interactive, [], [submit_button, clear_button])
     submit_button.click(
         evaluate,
         [instruction, inputs, temperature, max_tokens],
-        [outputs, submit_button, clear_button],
+        [outputs, submit_button, clear_button, repetition_penalty],
     )
     clear_button.click(reset_textbox, [], [instruction, inputs, outputs], queue=False)
 
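Taken together, the change threads a user-controlled repetition penalty from a Gradio slider into GenerationConfig. The sketch below shows that wiring pattern in a self-contained form; it is not the Space's actual app.py, and the model name ("gpt2"), the fixed temperature, the max_new_tokens value, and the single-textbox UI are placeholder assumptions.

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


def generate(prompt: str, repetition_penalty: float = 1.0) -> str:
    # The slider value arrives here and is forwarded into GenerationConfig.
    inputs = tokenizer(prompt, return_tensors="pt")
    generation_config = GenerationConfig(
        do_sample=True,
        temperature=0.7,
        repetition_penalty=repetition_penalty,  # 1.0 = no penalty
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    output_ids = model.generate(
        **inputs,
        generation_config=generation_config,
        max_new_tokens=128,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    repetition_penalty = gr.Slider(
        minimum=1.0, maximum=5.0, value=1.2, step=0.05, label="Repetition penalty"
    )
    output = gr.Textbox(label="Output")
    submit = gr.Button("Generate")
    # The slider belongs in the inputs list so its current value reaches generate().
    submit.click(generate, [prompt, repetition_penalty], [output])

if __name__ == "__main__":
    demo.launch()

Exposing the penalty as a slider, instead of the old fixed 1.5, lets users trade repetition against fluency per request without redeploying the app.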