retarfi committed
Commit: 160c75c
1 Parent(s): b541a54
add repetition penalty
app.py
CHANGED
@@ -160,6 +160,7 @@ def evaluate(
     input=None,
     temperature=0.7,
     max_tokens=384,
+    repetition_penalty=1.0,
 ):
     num_beams: int = 1
     top_p: float = 1.0
@@ -186,13 +187,17 @@ def evaluate(
         )
     except Exception as e:
         print(e)
-        return
+        return (
+            f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} tokens are used.",
+            gr.update(interactive=True),
+            gr.update(interactive=True),
+        )
     input_ids = inputs["input_ids"].to(device)
     generation_config = GenerationConfig(
         temperature=temperature,
         top_p=top_p,
         top_k=top_k,
-        repetition_penalty=
+        repetition_penalty=repetition_penalty,
         num_beams=num_beams,
         pad_token_id=tokenizer.pad_token_id,
         eos_token=tokenizer.eos_token_id,
@@ -203,7 +208,7 @@ def evaluate(
         generation_config=generation_config,
         return_dict_in_generate=True,
         output_scores=True,
-        max_new_tokens=max_tokens-len(input_ids),
+        max_new_tokens=max_tokens - len(input_ids),
     )
     s = generation_output.sequences[0]
     output = tokenizer.decode(s, skip_special_tokens=True)
@@ -292,6 +297,14 @@ with gr.Blocks(
            interactive=True,
            label="Max length (Pre-prompt + instruction + input + output))",
        )
+        repetition_penalty = gr.Slider(
+            minimum=1.0,
+            maximum=5.0,
+            value=1.2,
+            step=0.05,
+            interactive=True,
+            label="Repetition penalty",
+        )
 
     with gr.Column(elem_id="user_consent_container") as user_consent_block:
         # Get user consent
@@ -334,14 +347,14 @@ with gr.Blocks(
     inputs.submit(no_interactive, [], [submit_button, clear_button])
     inputs.submit(
         evaluate,
-        [instruction, inputs, temperature, max_tokens],
+        [instruction, inputs, temperature, max_tokens, repetition_penalty],
         [outputs, submit_button, clear_button],
     )
     submit_button.click(no_interactive, [], [submit_button, clear_button])
     submit_button.click(
         evaluate,
         [instruction, inputs, temperature, max_tokens],
-        [outputs, submit_button, clear_button],
+        [outputs, submit_button, clear_button, repetition_penalty],
     )
     clear_button.click(reset_textbox, [], [instruction, inputs, outputs], queue=False)
 
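For reference, the pattern this commit follows can be sketched as a small self-contained Gradio app: a "Repetition penalty" slider is declared next to the other sampling controls, passed as an extra input component to the event handler, and forwarded to transformers' GenerationConfig. This is only an illustrative sketch, not the Space's actual app.py; the model name ("gpt2"), the component set, and the slider ranges are placeholder assumptions.

# Minimal sketch (placeholder model and components, not the Space's real code):
# a repetition-penalty slider wired through a Gradio handler into GenerationConfig.
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

MODEL_NAME = "gpt2"  # placeholder; the Space loads its own model
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)


def evaluate(instruction, temperature=0.7, max_tokens=384, repetition_penalty=1.0):
    # Tokenize the prompt and move it to the model's device.
    inputs = tokenizer(instruction, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=1.0,
        top_k=50,
        repetition_penalty=repetition_penalty,  # values > 1.0 discourage repeats
        num_beams=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            generation_config=generation_config,
            max_new_tokens=max(1, max_tokens - input_ids.shape[1]),
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


with gr.Blocks() as demo:
    instruction = gr.Textbox(label="Instruction")
    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
    max_tokens = gr.Slider(64, 1024, value=384, step=16, label="Max length")
    repetition_penalty = gr.Slider(
        minimum=1.0, maximum=5.0, value=1.2, step=0.05, label="Repetition penalty"
    )
    outputs = gr.Textbox(label="Output")
    submit = gr.Button("Submit")
    # The slider is listed among the *inputs* so its value reaches evaluate().
    submit.click(
        evaluate,
        [instruction, temperature, max_tokens, repetition_penalty],
        [outputs],
    )

if __name__ == "__main__":
    demo.launch()

In transformers, repetition_penalty > 1.0 down-weights tokens that already appear in the sequence, while 1.0 leaves the logits untouched, which is why a slider for it is typically pinned at a minimum of 1.0.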