lewtun (HF staff) committed
Commit b0042a5
1 Parent(s): 97f74a7

Remove length penalty

Files changed (1)
  1. app.py +12 -15
app.py CHANGED
@@ -32,7 +32,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=HF_TOKEN)
 PROMPT_TEMPLATE = """Question: {prompt}\n\nAnswer: """
 
 
-def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
+def generate(instruction, temperature=1, max_new_tokens=256, top_p=1, top_k=50):
     formatted_instruction = PROMPT_TEMPLATE.format(prompt=instruction)
     # COMMENT IN FOR NON STREAMING
     # generation_config = GenerationConfig(
@@ -66,9 +66,8 @@ def generate(instruction, temperature, max_new_tokens, top_p, length_penalty):
         top_p=top_p,
         temperature=temperature,
         max_new_tokens=max_new_tokens,
-        # early_stopping=True,  # Not sure if we want this
-        top_k=0,  # Maybe set top_k=40 if results are bad
-        length_penalty=length_penalty,
+        do_sample=True,
+        top_k=top_k,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.eos_token_id,
     )
@@ -163,20 +162,18 @@ with gr.Blocks(theme=theme) as demo:
             interactive=True,
             info="Higher values sample fewer low-probability tokens",
         )
-        length_penalty = gr.Slider(
-            label="Length penalty",
-            value=1.0,
-            minimum=-10.0,
-            maximum=10.0,
-            step=0.1,
+        top_k = gr.Slider(
+            label="Top-k",
+            value=50,
+            minimum=0,
+            maximum=100,
+            step=2,
             interactive=True,
-            info="> 0 longer, < 0 shorter",
+            info="Sample from top-k tokens",
         )
 
-    submit.click(generate, inputs=[instruction, temperature, max_new_tokens, top_p, length_penalty], outputs=[output])
-    instruction.submit(
-        generate, inputs=[instruction, temperature, max_new_tokens, top_p, length_penalty], outputs=[output]
-    )
+    submit.click(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
+    instruction.submit(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])
 
 demo.queue()
 demo.launch()
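
For reference, a minimal sketch of the sampling setup the new code lands on, using a generic transformers causal LM (the gpt2 model below is a stand-in, not the model this Space serves). top_k only takes effect during sampling, which is why the commit adds do_sample=True alongside it:

    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    model_id = "gpt2"  # placeholder model, not the one this Space actually serves
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    generation_config = GenerationConfig(
        do_sample=True,   # required: top_k/top_p are ignored under greedy decoding
        top_k=50,         # sample only from the 50 most likely next tokens
        top_p=1.0,        # nucleus filtering effectively disabled at 1.0
        temperature=1.0,
        max_new_tokens=256,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    inputs = tokenizer("Question: What does top-k sampling do?\n\nAnswer: ", return_tensors="pt")
    output = model.generate(**inputs, generation_config=generation_config)
    print(tokenizer.decode(output[0], skip_special_tokens=True))

At each step the logits are truncated to the k most likely tokens and renormalized before sampling, which caps how far into the low-probability tail generation can wander.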
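And a stripped-down sketch of how the new Top-k slider reaches generate: Gradio passes the components in inputs positionally, so the slider order must match the function signature. Layout and the dummy generate body here are assumptions, mirroring the diff rather than the full app:

    import gradio as gr

    def generate(instruction, temperature=1, max_new_tokens=256, top_p=1, top_k=50):
        # placeholder body; the real app streams tokens from the model
        return f"(would sample with top_k={top_k}, top_p={top_p}, T={temperature})"

    with gr.Blocks() as demo:
        instruction = gr.Textbox(label="Instruction")
        temperature = gr.Slider(label="Temperature", value=1.0, minimum=0.0, maximum=2.0, step=0.1)
        max_new_tokens = gr.Slider(label="Max new tokens", value=256, minimum=1, maximum=1024, step=1)
        top_p = gr.Slider(label="Top-p", value=1.0, minimum=0.0, maximum=1.0, step=0.05)
        top_k = gr.Slider(label="Top-k", value=50, minimum=0, maximum=100, step=2)
        submit = gr.Button("Submit")
        output = gr.Textbox(label="Output")
        # components in inputs are mapped positionally onto generate's parameters
        submit.click(generate, inputs=[instruction, temperature, max_new_tokens, top_p, top_k], outputs=[output])

    demo.queue()
    demo.launch()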