Leyo commited on
Commit
5167a8a
·
1 Parent(s): e2307a6

add functionnal sliders for hyperparameters

Browse files
Files changed (1) hide show
  1. app_dialogue.py +70 -14
app_dialogue.py CHANGED
@@ -282,7 +282,15 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
282
  interactive=True,
283
  label="Top P",
284
  )
285
- max_output_tokens = gr.Slider(
 
 
 
 
 
 
 
 
286
  minimum=0,
287
  maximum=1024,
288
  value=512,
@@ -290,6 +298,46 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
290
  interactive=True,
291
  label="Max output tokens",
292
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  with gr.Column(scale=6):
295
  chatbot = gr.Chatbot(
@@ -357,22 +405,30 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
357
  def model_inference(
358
  user_prompt,
359
  chat_history,
 
 
 
 
 
 
 
 
 
360
  ):
361
  global processor, model, tokenizer
362
-
363
- temperature = 1.0
364
- no_repeat_ngram_size = 0
365
- max_new_tokens = 512
366
- min_length = 16
367
  force_words = ""
368
- repetition_penalty = 1.0
369
  hide_special_tokens = False
370
  decoding_strategy = "greedy"
371
  num_beams = 3
372
- length_penalty = 1.0
373
- top_k = 50
374
- top_p = 0.95
375
- penalty_alpha = 0.95
376
 
377
  formated_prompt = format_prompt_with_history_and_system_conditioning(
378
  current_user_prompt=user_prompt.strip(),
@@ -406,13 +462,13 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
406
 
407
  textbox.submit(
408
  fn=model_inference,
409
- inputs=[textbox, chatbot],
410
  outputs=[textbox, chatbot],
411
  )
412
  submit_btn.click(
413
  fn=model_inference,
414
- inputs=[textbox, chatbot],
415
- outputs=[textbox, chatbot],
416
  )
417
 
418
  demo.queue()
 
282
  interactive=True,
283
  label="Top P",
284
  )
285
+ top_k = gr.Slider(
286
+ minimum=0.0,
287
+ maximum=100.0,
288
+ value=50.0,
289
+ step=1.0,
290
+ interactive=True,
291
+ label="Top K",
292
+ )
293
+ max_new_tokens = gr.Slider(
294
  minimum=0,
295
  maximum=1024,
296
  value=512,
 
298
  interactive=True,
299
  label="Max output tokens",
300
  )
301
+ repetition_penalty = gr.Slider(
302
+ minimum=0.0,
303
+ maximum=10.0,
304
+ value=1.0,
305
+ step=0.1,
306
+ interactive=True,
307
+ label="Repetition penalty",
308
+ )
309
+ min_length = gr.Slider(
310
+ minimum=0.0,
311
+ maximum=50.0,
312
+ value=0.0,
313
+ step=1.0,
314
+ interactive=True,
315
+ label="No repeat ngram size",
316
+ )
317
+ length_penalty = gr.Slider(
318
+ minimum=0.0,
319
+ maximum=10.0,
320
+ value=1.0,
321
+ step=0.1,
322
+ interactive=True,
323
+ label="Length penalty",
324
+ )
325
+ no_repeat_ngram_size = gr.Slider(
326
+ minimum=0.0,
327
+ maximum=10.0,
328
+ value=0.0,
329
+ step=1.0,
330
+ interactive=True,
331
+ label="No repeat ngram size",
332
+ )
333
+ penalty_alpha = gr.Slider(
334
+ minimum=0.0,
335
+ maximum=10.0,
336
+ value=0.95,
337
+ step=1.0,
338
+ interactive=True,
339
+ label="Penalty alpha",
340
+ )
341
 
342
  with gr.Column(scale=6):
343
  chatbot = gr.Chatbot(
 
405
  def model_inference(
406
  user_prompt,
407
  chat_history,
408
+ temperature = 1.0,
409
+ no_repeat_ngram_size = 0,
410
+ max_new_tokens = 512,
411
+ min_length = 16,
412
+ repetition_penalty = 1.0,
413
+ length_penalty = 1.0,
414
+ top_k = 50,
415
+ top_p = 0.95,
416
+ penalty_alpha = 0.95,
417
  ):
418
  global processor, model, tokenizer
419
+ # temperature = 1.0
420
+ # no_repeat_ngram_size = 0
421
+ # max_new_tokens = 512
422
+ # min_length = 16
 
423
  force_words = ""
424
+ # repetition_penalty = 1.0
425
  hide_special_tokens = False
426
  decoding_strategy = "greedy"
427
  num_beams = 3
428
+ # length_penalty = 1.0
429
+ # top_k = 50
430
+ # top_p = 0.95
431
+ # penalty_alpha = 0.95
432
 
433
  formated_prompt = format_prompt_with_history_and_system_conditioning(
434
  current_user_prompt=user_prompt.strip(),
 
462
 
463
  textbox.submit(
464
  fn=model_inference,
465
+ inputs=[textbox, chatbot, temperature, ],
466
  outputs=[textbox, chatbot],
467
  )
468
  submit_btn.click(
469
  fn=model_inference,
470
+ inputs=[textbox, chatbot, temperature, no_repeat_ngram_size, max_new_tokens, min_length, repetition_penalty, length_penalty, top_k, top_p, penalty_alpha],
471
+ outputs=[textbox, chatbot, temperature, no_repeat_ngram_size, max_new_tokens, min_length, repetition_penalty, length_penalty, top_k, top_p, penalty_alpha],
472
  )
473
 
474
  demo.queue()