fix num beams
Browse files
app.py
CHANGED
@@ -16,8 +16,10 @@ def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperatur
|
|
16 |
elif decoding_methods == "Sampling":
|
17 |
do_sample = True
|
18 |
penalty_alpha = 0
|
|
|
19 |
else:
|
20 |
do_sample = False
|
|
|
21 |
print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
|
22 |
prompt = f"User: {user_input}\nAssistant: "
|
23 |
generated_text = text_generation(f"{prompt}", min_length=50, max_length=200, num_return_sequences=1,
|
@@ -46,7 +48,8 @@ with gr.Blocks() as demo:
|
|
46 |
top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
|
47 |
temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
|
48 |
repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
|
49 |
-
penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
|
|
|
50 |
with gr.Row():
|
51 |
button_generate_story = gr.Button("Submit")
|
52 |
with gr.Column():
|
|
|
16 |
elif decoding_methods == "Sampling":
|
17 |
do_sample = True
|
18 |
penalty_alpha = 0
|
19 |
+
num_beams = 1
|
20 |
else:
|
21 |
do_sample = False
|
22 |
+
num_beams = 1
|
23 |
print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
|
24 |
prompt = f"User: {user_input}\nAssistant: "
|
25 |
generated_text = text_generation(f"{prompt}", min_length=50, max_length=200, num_return_sequences=1,
|
|
|
48 |
top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
|
49 |
temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
|
50 |
repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
|
51 |
+
penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
|
52 |
+
default=0.5, step=0.05, minimum=0.05, maximum=1.0)
|
53 |
with gr.Row():
|
54 |
button_generate_story = gr.Button("Submit")
|
55 |
with gr.Column():
|