J-Antoine ZAGATO committed
Commit a0c663d
1 parent: 40d38f3

Completed model comparison + added private model support + custom params support
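In short: the Multi-Model comparison mode is now implemented, models can be loaded from private Hugging Face repositories by threading a user-supplied auth token through to `from_pretrained`, and the generation parameters (output length, top_k, top_p, temperature) are exposed as sliders. A minimal sketch of the private-model loading pattern at the heart of this commit, trimmed from the committed `load_model` (the pad-token setup and the full lookup tables are elided here):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-ins for the app's lookup tables; the real file maps several named checkpoints.
MODEL_CLASSES = {"Custom Model": (AutoModelForCausalLM, AutoTokenizer)}
CHECKPOINTS = {}

def load_model(model_name, custom_model_path=None, token=None):
    try:
        model_class, tokenizer_class = MODEL_CLASSES[model_name]
        model_path = CHECKPOINTS[model_name]
    except KeyError:
        # Any unknown name is treated as a custom (possibly private) hub repo.
        model_class, tokenizer_class = MODEL_CLASSES["Custom Model"]
        model_path = custom_model_path or model_name
    # use_auth_token forwards the token so private repos resolve
    # (transformers 4.x; later releases rename this parameter to `token`).
    model = model_class.from_pretrained(model_path, use_auth_token=token)
    tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)
    return model, tokenizer
```

`use_auth_token` also accepts `True` to reuse a token cached by `huggingface-cli login`; passing `None`, as happens when the token textbox is left hidden, simply skips authentication, so public models keep working.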

Files changed (1): app.py (+274 -22)
app.py CHANGED
@@ -30,17 +30,19 @@ MODEL_CLASSES = {
     "Custom Model" : (AutoModelForCausalLM, AutoTokenizer),
 }
 
-def load_model(model_name, custom_model_path):
+CHOICES = sorted(list(CHECKPOINTS.keys())[:3])
+
+def load_model(model_name, custom_model_path, token):
     try:
         model_class, tokenizer_class = MODEL_CLASSES[model_name]
         model_path = CHECKPOINTS[model_name]
 
     except KeyError:
         model_class, tokenizer_class = MODEL_CLASSES['Custom Model']
-        model_path = custom_model_path
+        model_path = custom_model_path or model_name
 
-    model = model_class.from_pretrained(model_path)
-    tokenizer = tokenizer_class.from_pretrained(model_path)
+    model = model_class.from_pretrained(model_path, use_auth_token=token)
+    tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)
 
     tokenizer.pad_token = tokenizer.eos_token
     model.config.pad_token_id = model.config.eos_token_id
@@ -67,6 +69,7 @@ def adjust_length_to_model(length, max_sequence_length):
     return length
 
 def generate(model_name,
+             token,
              custom_model_path,
              input_sentence,
              length = 75,
@@ -88,7 +91,7 @@ def generate(model_name,
     set_seed(seed, n_gpu)
 
     # Load model
-    model, tokenizer = load_model(model_name, custom_model_path)
+    model, tokenizer = load_model(model_name, custom_model_path, token)
     model.to(device)
 
     #length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)
@@ -126,6 +129,7 @@ def generate(model_name,
 
     return generated_sequences[0]
 
+
 def show_mode(mode):
     if mode == 'Single Model':
         return (
@@ -174,7 +178,7 @@ def show_search_bar(value):
         gr.update(visible=False)
     )
 
-def search_model(model_name):
+def search_model(model_name, token):
     api = HfApi()
 
     model_args = ModelSearchArguments()
@@ -182,7 +186,7 @@ def search_model(model_name):
                       task=model_args.pipeline_tag.TextGeneration,
                       library=model_args.library.PyTorch)
 
-    results = api.list_models(filter=filt, search=model_name)
+    results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
     model_list = [model.modelId for model in results]
 
     return gr.update(visible=True,
@@ -190,6 +194,12 @@ def search_model(model_name):
                      label='Choose the model',
                      )
 
+def show_api_key_textbox(checkbox):
+    if checkbox:
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
 def forward_model_choice(model_choice_path):
     return (model_choice_path,
             model_choice_path)
@@ -200,16 +210,30 @@ def auto_complete(input, generated):
     completed_prompt = {"text": output, "entities": output_spans}
     return completed_prompt
 
-def process_user_input(model, custom_model_path, input):
+def process_user_input(model,
+                       token,
+                       custom_model_path,
+                       input,
+                       length,
+                       temperature,
+                       top_p,
+                       top_k):
     warning = 'Please enter a valid prompt.'
     if input == None:
         generated = warning
     else:
-        generated = generate(model, custom_model_path, input)
-        generated_with_spans = auto_complete(input, generated)
+        generated = generate(model_name=model,
+                             token=token,
+                             custom_model_path=custom_model_path,
+                             input_sentence=input,
+                             length=length,
+                             temperature=temperature,
+                             top_p=top_p,
+                             top_k=top_k)
+        generated_with_spans = auto_complete(input=input, generated=generated)
 
     return (
-        generated_with_spans,
+        gr.update(value=generated_with_spans),
         gr.update(visible=True),
         gr.update(visible=True),
         input,
@@ -264,11 +288,55 @@ def upload_flag(*args):
     if flagging_callback.flag(list(args), flag_option = None):
         return gr.update(visible=True)
 
+def forward_model_choice_multi(model_choice_path):
+    CHOICES.append(model_choice_path)
+    return gr.update(choices = CHOICES)
+
+def process_user_input_multi(models,
+                             input,
+                             token,
+                             length,
+                             temperature,
+                             top_p,
+                             top_k):
+    warning = 'Please enter a valid prompt.'
+    if input == None:
+        generated = warning
+    else:
+        generated_dict = {model:generate(model_name=model,
+                                         token=token,
+                                         custom_model_path=None,
+                                         input_sentence=input,
+                                         length=length,
+                                         temperature=temperature,
+                                         top_p=top_p,
+                                         top_k=top_k) for model in sorted(models)}
+    generated_with_spans_dict = {model:auto_complete(input, generated) for model,generated in generated_dict.items()}
+
+    update_outputs = [gr.HighlightedText.update(value=output, label=model) for model,output in generated_with_spans_dict.items()]
+    update_hide = [gr.HighlightedText.update(visible=False) for i in range(10-len(models))]
+    return update_outputs + update_hide
+
+def show_choices_multi(models):
+    update_show = [gr.HighlightedText.update(visible=True) for model in sorted(models)]
+    update_hide = [gr.HighlightedText.update(visible=False,value=None, label=None) for i in range(10-len(models))]
+
+    return update_show + update_hide
+
+def show_params(checkbox):
+    if checkbox == True:
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
 CSS = """
 #inside_group {
     padding-top: 0.6em;
     padding-bottom: 0.6em;
 }
+#pw textarea {
+    -webkit-text-security: disc;
+}
 """
 
 with gr.Blocks(css=CSS) as demo:
@@ -286,9 +354,12 @@ with gr.Blocks(css=CSS) as demo:
                                       organization = "fsdlredteam",
                                       private = True )
 
-    gr.Markdown("# Project Interface proposal")
+    gr.Markdown("# FSDL 2022 Red-Teaming Open-Source Models Interface")
+    gr.Markdown("<img src=https://i.imgur.com/ZxbbLUQ.png>")
+
     gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
-    gr.Markdown("### Or compare multiple models")
+    gr.Markdown("### Or compare the output of multiple models at the same time")
+
 
     choose_mode = gr.Radio(choices=['Single Model', "Multi-Model"],
                            value='Single Model',
@@ -297,6 +368,12 @@ with gr.Blocks(css=CSS) as demo:
                            show_label=False)
 
     with gr.Group() as single_model:
+
+        gr.Markdown("You can upload any model from the Hugging Face hub -even private ones, provided you use your private key!")
+        gr.Markdown("Write your prompt or alternatively use one from the [RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset")
+        gr.Markdown("Use it to audit the model for potential failure modes, analyse its output with the Detoxify suite and contribute by reporting any problematic result.")
+        gr.Markdown("Beware ! Generation can take up to a few minutes with very large models.")
+
         with gr.Row():
 
             with gr.Column(scale=1): # input & prompts dataset exploration
@@ -315,11 +392,44 @@ with gr.Blocks(css=CSS) as demo:
 
                 randomize_button = gr.Button('Show another subset', visible=False, elem_id="inside_group")
 
+                show_params_checkbox_single = gr.Checkbox(label='Set custom params',
+                                                          interactive=True,
+                                                          value=False)
+
+                with gr.Box(visible=False) as params_box_single:
+
+                    length_single = gr.Slider(label='Output length',
+                                              visible=True,
+                                              interactive=True,
+                                              minimum=50,
+                                              maximum=200,
+                                              value=75)
+
+                    top_k_single = gr.Slider(label='top_k',
+                                             visible=True,
+                                             interactive=True,
+                                             minimum=1,
+                                             maximum=100,
+                                             value=50)
+
+                    top_p_single = gr.Slider(label='top_p',
+                                             visible=True,
+                                             interactive=True,
+                                             minimum=0.1,
+                                             maximum=1,
+                                             value=0.95)
+
+                    temperature_single = gr.Slider(label='temperature',
+                                                   visible=True,
+                                                   interactive=True,
+                                                   minimum=0.1,
+                                                   maximum=1,
+                                                   value=0.7)
+
 
             with gr.Column(scale=1): # Model choice & output
                 gr.Markdown("### 2. Evaluate output")
 
-
                 model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
                                        label='Model',
                                        interactive=True,
@@ -331,11 +441,19 @@ with gr.Blocks(css=CSS) as demo:
                                        elem_id="inside_group")
                 model_drop = gr.Dropdown(visible=False)
 
-                generate_button = gr.Button('Submit your prompt')
+                private_checkbox = gr.Checkbox(visible=True,label="Private Model ?", elem_id="inside_group")
+
+                api_key_textbox = gr.Textbox(label="Enter your AUTH TOKEN below",
+                                             value=None,
+                                             interactive=True,
+                                             visible=False,
+                                             elem_id="pw")
 
-                output_spans = gr.HighlightedText(visible=True, label="Generated text", elem_id="inside_group")
+                generate_button = gr.Button('Submit your prompt', elem_id="inside_group")
 
-                flag_button = gr.Button("Report output here", visible=False)
+                output_spans = gr.HighlightedText(visible=True, label="Generated text")
+
+                flag_button = gr.Button("Report output here", visible=False, elem_id="inside_group")
 
         with gr.Row(): # Flagging
 
@@ -373,9 +491,94 @@ with gr.Blocks(css=CSS) as demo:
                                 visible=False,
                                 elem_id="inside_group")
 
-    with gr.Group() as multi_model:
-        gr.Markdown("Model comparison will be here")
+    with gr.Group(visible=False) as multi_model:
+        model_list = list()
+
+        gr.Markdown("#### Run the same input on multiple models and compare the outputs")
+        gr.Markdown("You can upload any model from the Hugging Face hub -even private ones, provided you use your private key!")
+        gr.Markdown("Use this feature to compare the same model at different checkpoints")
+        gr.Markdown('Or to benchmark your model against another one as a reference.')
+        gr.Markdown("Beware ! Generation can take up to a few minutes with very large models.")
+
+        with gr.Row(elem_id="inside_group"):
+            with gr.Column():
+                models_multi = gr.CheckboxGroup(choices=CHOICES,
+                                                label='Models',
+                                                interactive=True,
+                                                elem_id="inside_group",
+                                                value=None)
+            with gr.Column():
+                generate_button_multi = gr.Button('Submit your prompt',elem_id="inside_group")
+
+                show_params_checkbox_multi = gr.Checkbox(label='Set custom params',
+                                                         interactive=True,
+                                                         value=False)
+
+                with gr.Box(visible=False) as params_box_multi:
+
+                    length_multi = gr.Slider(label='Output length',
+                                             visible=True,
+                                             interactive=True,
+                                             minimum=50,
+                                             maximum=200,
+                                             value=75)
+
+                    top_k_multi = gr.Slider(label='top_k',
+                                            visible=True,
+                                            interactive=True,
+                                            minimum=1,
+                                            maximum=100,
+                                            value=50)
+
+                    top_p_multi = gr.Slider(label='top_p',
+                                            visible=True,
+                                            interactive=True,
+                                            minimum=0.1,
+                                            maximum=1,
+                                            value=0.95)
+
+                    temperature_multi = gr.Slider(label='temperature',
+                                                  visible=True,
+                                                  interactive=True,
+                                                  minimum=0.1,
+                                                  maximum=1,
+                                                  value=0.7)
+
+        with gr.Row(elem_id="inside_group"):
+
+            with gr.Column(elem_id="inside_group", scale=1):
+                input_text_multi = gr.Textbox(label="Write your prompt below.",
+                                              interactive=True,
+                                              lines=4,
+                                              elem_id="inside_group")
+
+            with gr.Column(elem_id="inside_group", scale=1):
+                search_bar_multi = gr.Textbox(label="Search another model",
+                                              interactive=True,
+                                              visible=True,
+                                              elem_id="inside_group")
+
+                model_drop_multi = gr.Dropdown(visible=False,
+                                               show_progress=True,
+                                               elem_id="inside_group")
+
+                private_checkbox_multi = gr.Checkbox(visible=True,label="Private Model ?")
+
+                api_key_textbox_multi = gr.Textbox(label="Enter your AUTH TOKEN below",
+                                                   value=None,
+                                                   interactive=True,
+                                                   visible=False,
+                                                   elem_id="pw")
+
+        with gr.Row() as outputs_row:
+            for i in range(10):
+                output_spans_multi = gr.HighlightedText(visible=False, elem_id="inside_group")
+                model_list.append(output_spans_multi)
+
+
+    gr.Markdown('App made during the FSDL course by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa')
 
+    # Single Model
 
     choose_mode.change(fn=show_mode,
                        inputs=choose_mode,
@@ -398,16 +601,27 @@ with gr.Blocks(css=CSS) as demo:
                        outputs=[model_choice,search_bar])
 
     search_bar.submit(fn=search_model,
-                      inputs=search_bar,
+                      inputs=[search_bar,api_key_textbox],
                       outputs=model_drop,
                       show_progress=True)
 
+    private_checkbox.change(fn=show_api_key_textbox,
+                            inputs=private_checkbox,
+                            outputs=api_key_textbox)
+
     model_drop.change(fn=forward_model_choice,
                       inputs=model_drop,
                       outputs=[model_choice,custom_model_path])
 
     generate_button.click(fn=process_user_input,
-                          inputs=[model_choice, custom_model_path, input_text],
+                          inputs=[model_choice,
+                                  api_key_textbox,
+                                  custom_model_path,
+                                  input_text,
+                                  length_single,
+                                  temperature_single,
+                                  top_p_single,
+                                  top_k_single],
                           outputs=[output_spans,
                                    toxi_button,
                                    flag_button,
@@ -442,7 +656,45 @@ with gr.Blocks(css=CSS) as demo:
                                user_comment,
                                flag_choice],
                        outputs=success_message)
+
+    show_params_checkbox_single.change(fn=show_params,
                                       inputs=show_params_checkbox_single,
+                                       outputs=params_box_single)
+
+    # Model comparison
+
+    search_bar_multi.submit(fn=search_model,
+                            inputs=[search_bar_multi, api_key_textbox_multi],
+                            outputs=model_drop_multi,
+                            show_progress=True)
+
+    show_params_checkbox_multi.change(fn=show_params,
+                                      inputs=show_params_checkbox_multi,
+                                      outputs=params_box_multi)
+
+    private_checkbox_multi.change(fn=show_api_key_textbox,
+                                  inputs=private_checkbox_multi,
+                                  outputs=api_key_textbox_multi)
+
+    model_drop_multi.change(fn=forward_model_choice_multi,
+                            inputs=model_drop_multi,
+                            outputs=[models_multi])
+
+    models_multi.change(fn=show_choices_multi,
+                        inputs=models_multi,
+                        outputs=model_list)
+
+    generate_button_multi.click(fn=process_user_input_multi,
+                                inputs=[models_multi,
+                                        input_text_multi,
+                                        api_key_textbox_multi,
+                                        length_multi,
+                                        temperature_multi,
+                                        top_p_multi,
+                                        top_k_multi],
+                                outputs=model_list,
+                                show_progress=True)
 
 #demo.launch(debug=True)
 if __name__ == "__main__":
-    demo.launch(enable_queue=False)
+    demo.launch(enable_queue=False, debug=True)
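
A note on the comparison wiring above: a Gradio Blocks graph is static, so the commit pre-allocates ten `HighlightedText` slots and has `show_choices_multi` / `process_user_input_multi` toggle their visibility as models are selected or deselected. A standalone sketch of that pattern under Gradio 3.x (the checkbox choices here are illustrative, not from the app):

```python
import gradio as gr

MAX_SLOTS = 10  # matches the app's fixed pool of output components

def render_outputs(selected):
    # One visible, labelled slot per selected model; hide and clear the rest.
    shown = [gr.HighlightedText.update(visible=True, label=m) for m in sorted(selected)]
    hidden = [gr.HighlightedText.update(visible=False, value=None, label=None)
              for _ in range(MAX_SLOTS - len(selected))]
    return shown + hidden

with gr.Blocks() as demo:
    models = gr.CheckboxGroup(choices=["gpt2", "distilgpt2"], label="Models")
    slots = [gr.HighlightedText(visible=False) for _ in range(MAX_SLOTS)]
    models.change(fn=render_outputs, inputs=models, outputs=slots)

if __name__ == "__main__":
    demo.launch()
```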