Update chatbot.py

chatbot.py (+12 -13)
@@ -315,7 +315,7 @@ def model_inference(
     temperature,
     max_new_tokens,
     repetition_penalty,
-
+    top_p,
     web_search,
 ):
     # Define generation_args at the beginning of the function
@@ -332,7 +332,6 @@ def model_inference(
         generate_kwargs = dict(
             max_new_tokens=4000,
             do_sample=True,
-            min_p=0.08,
         )
         # Format the prompt for the language model
         formatted_prompt = format_prompt(
@@ -352,7 +351,6 @@ def model_inference(
         generate_kwargs = dict(
             max_new_tokens=5000,
             do_sample=True,
-            min_p=0.08,
         )
         # Format the prompt for the language model
         formatted_prompt = format_prompt(
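Both generate_kwargs hunks drop min_p=0.08 in favor of the top_p value threaded through the new function parameter. The two strategies filter the next-token distribution differently: top-p (nucleus) sampling keeps the smallest set of most-likely tokens whose probabilities sum to at least top_p, while min-p keeps every token whose probability is at least min_p times that of the single most likely token. A toy sketch of the two filters (illustrative only, not code from this Space):

# Toy next-token distribution illustrating how the added top-p filter
# differs from the removed min-p filter. Illustrative values only.
probs = {"the": 0.45, "a": 0.25, "cat": 0.15, "dog": 0.10, "xylo": 0.05}

def top_p_filter(probs, top_p=0.9):
    # Keep the smallest set of highest-probability tokens whose
    # cumulative probability reaches top_p.
    kept, cum = {}, 0.0
    for tok, p in sorted(probs.items(), key=lambda kv: -kv[1]):
        kept[tok] = p
        cum += p
        if cum >= top_p:
            break
    return kept

def min_p_filter(probs, min_p=0.08):
    # Keep tokens whose probability is at least min_p times the
    # probability of the most likely token.
    ceiling = max(probs.values())
    return {tok: p for tok, p in probs.items() if p >= min_p * ceiling}

print(top_p_filter(probs))  # keeps the, a, cat, dog (cum. 0.95 >= 0.9)
print(min_p_filter(probs))  # keeps all five: 0.08 * 0.45 = 0.036 threshold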
@@ -391,15 +389,16 @@ def model_inference(
     }
     assert decoding_strategy in [
         "Greedy",
-        "
+        "Top P Sampling",
     ]
 
     if decoding_strategy == "Greedy":
         generation_args["do_sample"] = False
-    elif decoding_strategy == "
+    elif decoding_strategy == "Top P Sampling":
         generation_args["temperature"] = temperature
         generation_args["do_sample"] = True
-        generation_args["
+        generation_args["top_p"] = top_p
+    # Creating model inputs
     (
         resulting_text,
         resulting_images,
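This hunk only builds generation_args; the client that consumes it sits outside the diff. Assuming a transformers-style backend (with "gpt2" here purely as a stand-in model), the dispatched arguments would be applied roughly like this:

# Hypothetical consumer of generation_args; the Space's real inference
# backend is not shown in this diff, and "gpt2" is only a stand-in.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Mirrors what the "Top P Sampling" branch above assembles.
generation_args = {
    "max_new_tokens": 40,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
}

inputs = tokenizer("The quick brown fox", return_tensors="pt")
output = model.generate(**inputs, **generation_args)
print(tokenizer.decode(output[0], skip_special_tokens=True))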
@@ -441,7 +440,7 @@ FEATURES = datasets.Features(
         "temperature": datasets.Value("float32"),
         "max_new_tokens": datasets.Value("int32"),
         "repetition_penalty": datasets.Value("float32"),
-        "
+        "top_p": datasets.Value("int32"),
     }
 )
 
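One caveat with the new feature: top_p is a float in [0.01, 0.99], so logging it as datasets.Value("int32") will truncate values like 0.9 to 0 (or fail the cast, depending on the backend). A float32 feature, matching temperature and repetition_penalty, is presumably what was intended:

import datasets

# Suggested correction (not what the commit contains): store top_p as
# float32, like the other fractional sampling parameters logged beside it.
FEATURES = datasets.Features(
    {
        "temperature": datasets.Value("float32"),
        "max_new_tokens": datasets.Value("int32"),
        "repetition_penalty": datasets.Value("float32"),
        "top_p": datasets.Value("float32"),
    }
)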
@@ -466,9 +465,9 @@ repetition_penalty = gr.Slider(
 decoding_strategy = gr.Radio(
     [
         "Greedy",
-        "
+        "Top P Sampling",
     ],
-    value="
+    value="Top P Sampling",
     label="Decoding strategy",
     interactive=True,
     info="Higher values are equivalent to sampling more low-probability tokens.",
@@ -483,14 +482,14 @@ temperature = gr.Slider(
     label="Sampling temperature",
     info="Higher values will produce more diverse outputs.",
 )
-
+top_p = gr.Slider(
     minimum=0.01,
-    maximum=0.
-    value=0.
+    maximum=0.99,
+    value=0.9,
     step=0.01,
     visible=True,
     interactive=True,
-    label="
+    label="Top P",
     info="Higher values are equivalent to sampling more low-probability tokens.",
 )
 
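The new slider reuses the info string from the decoding-strategy radio. As a standalone sanity check, a minimal Blocks app (hypothetical, not the Space's actual UI) wiring the same slider to a handler:

import gradio as gr

# Minimal sketch, not the Space itself: the top_p slider from the diff
# wired into a trivial event handler.
def echo_settings(top_p):
    return f"top_p={top_p}"

with gr.Blocks() as demo:
    top_p = gr.Slider(
        minimum=0.01,
        maximum=0.99,
        value=0.9,
        step=0.01,
        visible=True,
        interactive=True,
        label="Top P",
        info="Higher values are equivalent to sampling more low-probability tokens.",
    )
    out = gr.Textbox(label="Current settings")
    top_p.change(echo_settings, inputs=top_p, outputs=out)

if __name__ == "__main__":
    demo.launch()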