cahya commited on
Commit
8454eb5
1 Parent(s): a4b0cb5

fix num beams

Browse files
Files changed (1) hide show
  1. app.py +4 -1
app.py CHANGED
@@ -16,8 +16,10 @@ def get_answer(user_input, decoding_methods, num_beams, top_k, top_p, temperatur
16
  elif decoding_methods == "Sampling":
17
  do_sample = True
18
  penalty_alpha = 0
 
19
  else:
20
  do_sample = False
 
21
  print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
22
  prompt = f"User: {user_input}\nAssistant: "
23
  generated_text = text_generation(f"{prompt}", min_length=50, max_length=200, num_return_sequences=1,
@@ -46,7 +48,8 @@ with gr.Blocks() as demo:
46
  top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
47
  temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
48
  repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
49
- penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
 
50
  with gr.Row():
51
  button_generate_story = gr.Button("Submit")
52
  with gr.Column():
 
16
  elif decoding_methods == "Sampling":
17
  do_sample = True
18
  penalty_alpha = 0
19
+ num_beams = 1
20
  else:
21
  do_sample = False
22
+ num_beams = 1
23
  print(user_input, decoding_methods, do_sample, top_k, top_p, temperature, repetition_penalty, penalty_alpha)
24
  prompt = f"User: {user_input}\nAssistant: "
25
  generated_text = text_generation(f"{prompt}", min_length=50, max_length=200, num_return_sequences=1,
 
48
  top_p = gr.inputs.Slider(label="Top P", default=0.9, step=0.05, minimum=0.1, maximum=1.0)
49
  temperature = gr.inputs.Slider(label="Temperature", default=0.5, step=0.05, minimum=0.1, maximum=1.0)
50
  repetition_penalty = gr.inputs.Slider(label="Repetition Penalty", default=1.1, step=0.05, minimum=1.0, maximum=2.0)
51
+ penalty_alpha = gr.inputs.Slider(label="The penalty alpha for contrastive search",
52
+ default=0.5, step=0.05, minimum=0.05, maximum=1.0)
53
  with gr.Row():
54
  button_generate_story = gr.Button("Submit")
55
  with gr.Column():