Leyo committed on
Commit
cf12ee0
1 Parent(s): 4932b87

add sampling decoding strategies

Browse files
Files changed (1) hide show
  1. app_dialogue.py +12 -1
app_dialogue.py CHANGED
@@ -283,6 +283,14 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
283
  )
284
 
285
  with gr.Accordion("Parameters", open=False, visible=True) as parameter_row:
 
 
 
 
 
 
 
 
286
  temperature = gr.Slider(
287
  minimum=0.0,
288
  maximum=1.0,
@@ -426,6 +434,7 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
426
  def model_inference(
427
  user_prompt,
428
  chat_history,
 
429
  temperature=1.0,
430
  no_repeat_ngram_size=0,
431
  max_new_tokens=512,
@@ -444,7 +453,7 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
444
  force_words = ""
445
  # repetition_penalty = 1.0
446
  hide_special_tokens = False
447
- decoding_strategy = "greedy"
448
  num_beams = 3
449
  # length_penalty = 1.0
450
  # top_k = 50
@@ -486,6 +495,7 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
486
  inputs=[
487
  textbox,
488
  chatbot,
 
489
  temperature,
490
  no_repeat_ngram_size,
491
  max_new_tokens,
@@ -503,6 +513,7 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
503
  inputs=[
504
  textbox,
505
  chatbot,
 
506
  temperature,
507
  no_repeat_ngram_size,
508
  max_new_tokens,
 
283
  )
284
 
285
  with gr.Accordion("Parameters", open=False, visible=True) as parameter_row:
286
+ decoding_strategy = gr.Radio(
287
+ [
288
+ "greedy",
289
+ "sampling_top_k",
290
+ "sampling_top_p",
291
+ ],
292
+ label="Decoding strategy",
293
+ )
294
  temperature = gr.Slider(
295
  minimum=0.0,
296
  maximum=1.0,
 
434
  def model_inference(
435
  user_prompt,
436
  chat_history,
437
+ decoding_strategy="greedy",
438
  temperature=1.0,
439
  no_repeat_ngram_size=0,
440
  max_new_tokens=512,
 
453
  force_words = ""
454
  # repetition_penalty = 1.0
455
  hide_special_tokens = False
456
+ # decoding_strategy = "greedy"
457
  num_beams = 3
458
  # length_penalty = 1.0
459
  # top_k = 50
 
495
  inputs=[
496
  textbox,
497
  chatbot,
498
+ decoding_strategy,
499
  temperature,
500
  no_repeat_ngram_size,
501
  max_new_tokens,
 
513
  inputs=[
514
  textbox,
515
  chatbot,
516
+ decoding_strategy,
517
  temperature,
518
  no_repeat_ngram_size,
519
  max_new_tokens,