Leyo commited on
Commit
cf142d2
1 Parent(s): c586e09

add beam search

Browse files
Files changed (1) hide show
  1. app_dialogue.py +14 -1
app_dialogue.py CHANGED
@@ -286,12 +286,22 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
286
  decoding_strategy = gr.Radio(
287
  [
288
  "greedy",
 
 
289
  "sampling_top_k",
290
  "sampling_top_p",
291
  ],
292
  value="greedy",
293
  label="Decoding strategy",
294
  )
 
 
 
 
 
 
 
 
295
  temperature = gr.Slider(
296
  minimum=0.0,
297
  maximum=1.0,
@@ -436,6 +446,7 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
436
  user_prompt,
437
  chat_history,
438
  decoding_strategy,
 
439
  temperature,
440
  no_repeat_ngram_size,
441
  max_new_tokens,
@@ -455,7 +466,7 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
455
  # repetition_penalty = 1.0
456
  hide_special_tokens = False
457
  # decoding_strategy = "greedy"
458
- num_beams = 3
459
  # length_penalty = 1.0
460
  # top_k = 50
461
  # top_p = 0.95
@@ -497,6 +508,7 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
497
  textbox,
498
  chatbot,
499
  decoding_strategy,
 
500
  temperature,
501
  no_repeat_ngram_size,
502
  max_new_tokens,
@@ -515,6 +527,7 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
515
  textbox,
516
  chatbot,
517
  decoding_strategy,
 
518
  temperature,
519
  no_repeat_ngram_size,
520
  max_new_tokens,
 
286
  decoding_strategy = gr.Radio(
287
  [
288
  "greedy",
289
+ "beam_search",
290
+ "beam_sampling",
291
  "sampling_top_k",
292
  "sampling_top_p",
293
  ],
294
  value="greedy",
295
  label="Decoding strategy",
296
  )
297
+ num_beams = top_k = gr.Slider(
298
+ minimum=0,
299
+ maximum=20,
300
+ value=3.0,
301
+ step=1.0,
302
+ interactive=True,
303
+ label="Number of beams",
304
+ )
305
  temperature = gr.Slider(
306
  minimum=0.0,
307
  maximum=1.0,
 
446
  user_prompt,
447
  chat_history,
448
  decoding_strategy,
449
+ num_beams,
450
  temperature,
451
  no_repeat_ngram_size,
452
  max_new_tokens,
 
466
  # repetition_penalty = 1.0
467
  hide_special_tokens = False
468
  # decoding_strategy = "greedy"
469
+ # num_beams = 3
470
  # length_penalty = 1.0
471
  # top_k = 50
472
  # top_p = 0.95
 
508
  textbox,
509
  chatbot,
510
  decoding_strategy,
511
+ num_beams,
512
  temperature,
513
  no_repeat_ngram_size,
514
  max_new_tokens,
 
527
  textbox,
528
  chatbot,
529
  decoding_strategy,
530
+ num_beams,
531
  temperature,
532
  no_repeat_ngram_size,
533
  max_new_tokens,