lvwerra HF staff commited on
Commit
ddfd0c4
1 Parent(s): 9330390

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -45,7 +45,7 @@ def save_inputs_and_outputs(inputs, outputs, generate_kwargs):
45
  commit_url = repo.push_to_hub()
46
 
47
 
48
- def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k=100, do_save=True):
49
  formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
50
 
51
  temperature = float(temperature)
@@ -57,7 +57,7 @@ def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, top_k
57
  temperature=temperature,
58
  max_new_tokens=max_new_tokens,
59
  top_p=top_p,
60
- top_k=top_k,
61
  do_sample=True,
62
  truncate=999,
63
  seed=42,
@@ -162,17 +162,17 @@ with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
162
  interactive=True,
163
  info="Higher values sample more low-probability tokens",
164
  )
165
- top_k = gr.Slider(
166
- label="Top-k",
167
- value=50,
168
- minimum=0,
169
- maximum=100,
170
- step=2,
171
  interactive=True,
172
- info="Sample from top-k tokens",
173
  )
174
 
175
- submit.click(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k, do_save], outputs=[output])
176
  instruction.submit(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
177
  share_button.click(None, [], [], _js=share_js)
178
 
 
45
  commit_url = repo.push_to_hub()
46
 
47
 
48
+ def generate(instruction, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, do_save=True):
49
  formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
50
 
51
  temperature = float(temperature)
 
57
  temperature=temperature,
58
  max_new_tokens=max_new_tokens,
59
  top_p=top_p,
60
+ repetition_penalty=repetition_penalty,
61
  do_sample=True,
62
  truncate=999,
63
  seed=42,
 
162
  interactive=True,
163
  info="Higher values sample more low-probability tokens",
164
  )
165
+ repetition_penalty = gr.Slider(
166
+ label="Repetition penalty",
167
+ value=1.0,
168
+ minimum=1.0,
169
+ maximum=2.0,
170
+ step=0.05,
171
  interactive=True,
172
+ info="Penalize repeated tokens",
173
  )
174
 
175
+ submit.click(generate, inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty, do_save], outputs=[output])
176
  instruction.submit(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
177
  share_button.click(None, [], [], _js=share_js)
178