Jean-Antoine ZAGATO committed
Commit efbb6a7 • 1 Parent(s): 24ed1e4

Fixed 2 issues affecting flagging
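From the diff below, the two flagging fixes are: (1) upload_flag() now encodes the generated output to UTF-8 bytes before passing the record to the flagging callback, and always returns the visibility update instead of gating the confirmation message on the callback's return value; (2) the HuggingFaceDatasetSaver callback now writes to the "fsdlredteam/flagged_3" dataset and no longer passes the organization argument. A minimal sketch of the new flagging path, with names as they appear in app.py:

    def upload_flag(*args):
        # args = (input_var, output_var, model_choice, user_comment, flag_choice)
        flags = list(args)
        # Issue 1: encode the generated text as UTF-8 bytes before saving.
        flags[1] = bytes(flags[1], "utf-8")
        flagging_callback.flag(flags)
        # Reveal the success message unconditionally once the flag is uploaded.
        return gr.update(visible=True)

    flagging_callback = gr.HuggingFaceDatasetSaver(
        hf_token=HF_AUTH_TOKEN,
        dataset_name="fsdlredteam/flagged_3",  # Issue 2: new dataset, no organization kwarg
        private=True,
    )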

Files changed (1)
app.py +707 -568
app.py CHANGED
@@ -1,10 +1,10 @@
- import os
- import torch

- import numpy as np
  import gradio as gr

- from random import sample
  from detoxify import Detoxify
  from datasets import load_dataset
  from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
@@ -12,35 +12,36 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
  from transformers import BloomTokenizerFast, BloomForCausalLM

- HF_AUTH_TOKEN = os.environ.get('hf_token' or True)

  DATASET = "allenai/real-toxicity-prompts"

  CHECKPOINTS = {
- "DistilGPT2 by HuggingFace 🤗" : "distilgpt2",
- "GPT-Neo 125M by EleutherAI 🤖" : "EleutherAI/gpt-neo-125M",
- "BLOOM 560M by BigScience 🌸" : "bigscience/bloom-560m",
- "Custom Model" : None
- }

  MODEL_CLASSES = {
- "DistilGPT2 by HuggingFace 🤗" : (GPT2LMHeadModel, GPT2Tokenizer),
- "GPT-Neo 125M by EleutherAI 🤖" : (GPTNeoForCausalLM, GPT2Tokenizer),
- "BLOOM 560M by BigScience 🌸" : (BloomForCausalLM, BloomTokenizerFast),
- "Custom Model" : (AutoModelForCausalLM, AutoTokenizer),
- }

  CHOICES = sorted(list(CHECKPOINTS.keys())[:3])

- def load_model(model_name, custom_model_path, token):
  try:
- model_class, tokenizer_class = MODEL_CLASSES[model_name]
- model_path = CHECKPOINTS[model_name]
-
  except KeyError:
- model_class, tokenizer_class = MODEL_CLASSES['Custom Model']
- model_path = custom_model_path or model_name
-
  model = model_class.from_pretrained(model_path, use_auth_token=token)
  tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)

@@ -51,14 +52,17 @@ def load_model(model_name, custom_model_path, token):

  return model, tokenizer

  MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop

  def set_seed(seed, n_gpu):
  np.random.seed(seed)
  torch.manual_seed(seed)
  if n_gpu > 0:
  torch.cuda.manual_seed_all(seed)

  def adjust_length_to_model(length, max_sequence_length):
  if length < 0 and max_sequence_length > 0:
  length = max_sequence_length
@@ -68,23 +72,26 @@ def adjust_length_to_model(length, max_sequence_length):
  length = MAX_LENGTH # avoid infinite loop
  return length

- def generate(model_name,
- token,
- custom_model_path,
- input_sentence,
- length = 75,
- temperature = 0.7,
- top_k = 50,
- top_p = 0.95,
- seed = 42,
- no_cuda = False,
- num_return_sequences = 1,
- stop_token = '.'
- ):

  # load device
- #if not no_cuda:
- device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
  n_gpu = 0 if no_cuda else torch.cuda.device_count()

  # Set seed
@@ -94,36 +101,41 @@ def generate(model_name,
  model, tokenizer = load_model(model_name, custom_model_path, token)
  model.to(device)

- #length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)

  # Tokenize input
- encoded_prompt = tokenizer.encode(input_sentence,
- add_special_tokens=False,
- return_tensors='pt')

  encoded_prompt = encoded_prompt.to(device)

- input_ids = encoded_prompt
-
- # Generate output
- output_sequences = model.generate(input_ids=input_ids,
- max_length=length + len(encoded_prompt[0]),
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- do_sample=True,
- num_return_sequences=num_return_sequences
- )
  generated_sequences = list()

  for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
  generated_sequence = generated_sequence.tolist()
  text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
- #remove prompt
- text = text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
-
- #remove all text after last occurence of stop_token
- text = text[:text.rfind(stop_token)+1]

  generated_sequences.append(text)

@@ -131,203 +143,228 @@


  def show_mode(mode):
- if mode == 'Single Model':
- return (
- gr.update(visible=True),
- gr.update(visible=False)
- )
- if mode == 'Multi-Model':
- return (
- gr.update(visible=False),
- gr.update(visible=True)
- )

  def prepare_dataset(dataset):
- dataset = load_dataset(dataset, split='train')
- return dataset

  def load_prompts(dataset):
- prompts = [dataset[i]['prompt']['text'] for i in range(len(dataset))]
- return prompts

  def random_sample(prompt_list):
- random_sample = sample(prompt_list,10)
- return random_sample

  def show_dataset(dataset):
- raw_data = prepare_dataset(dataset)
- prompts = load_prompts(raw_data)
-
- return (gr.update(choices=random_sample(prompts),
- label='You can find below a random subset from the RealToxicityPrompts dataset',
- visible=True),
- gr.update(visible=True),
- prompts,
- )
-
  def update_dropdown(prompts):
- return gr.update(choices=random_sample(prompts))

  def show_search_bar(value):
- if value == 'Custom Model':
- return (value,
- gr.update(visible=True)
- )
- else:
- return (value,
- gr.update(visible=False)
- )

  def search_model(model_name, token):
- api = HfApi()

- model_args = ModelSearchArguments()
- filt = ModelFilter(
- task=model_args.pipeline_tag.TextGeneration,
- library=model_args.library.PyTorch)

- results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
- model_list = [model.modelId for model in results]

- return gr.update(visible=True,
- choices=model_list,
- label='Choose the model',
- )

  def show_api_key_textbox(checkbox):
- if checkbox:
- return gr.update(visible=True)
- else:
- return gr.update(visible=False)

  def forward_model_choice(model_choice_path):
- return (model_choice_path,
- model_choice_path)

  def auto_complete(input, generated):
- output = input + ' ' + generated
- output_spans = [{'entity': 'OUTPUT', 'start': len(input), 'end': len(output)}]
- completed_prompt = {"text": output, "entities": output_spans}
- return completed_prompt
-
- def process_user_input(model,
- token,
- custom_model_path,
- input,
- length,
- temperature,
- top_p,
- top_k):
- warning = 'Please enter a valid prompt.'
- if input == None:
- generated = warning
- else:
- generated = generate(model_name=model,
- token=token,
- custom_model_path=custom_model_path,
- input_sentence=input,
- length=length,
- temperature=temperature,
- top_p=top_p,
- top_k=top_k)
- generated_with_spans = auto_complete(input=input, generated=generated)
-
- return (
- gr.update(value=generated_with_spans),
- gr.update(visible=True),
- gr.update(visible=True),
- input,
- generated
- )

  def pass_to_textbox(input):
- return gr.update(value=input)

  def run_detoxify(text):
- results = Detoxify('original').predict(text)
- json_ready_results = {cat:float(score) for (cat,score) in results.items()}
- return json_ready_results

  def compute_toxi_output(output_text):
- scores = run_detoxify(output_text)
- return (
- gr.update(value=scores, visible=True),
- gr.update(visible=True)
- )

  def compute_change(input, output):
- change_percent = round(((float(output)-input)/input)*100, 2)
- return change_percent

  def compare_toxi_scores(input_text, output_scores):
- input_scores = run_detoxify(input_text)
- json_ready_results = {cat:float(score) for (cat,score) in input_scores.items()}

- compare_scores = {
- cat:compute_change(json_ready_results[cat], output_scores[cat])
- for cat in json_ready_results
- for cat in output_scores
- }

- return (
- gr.update(value=json_ready_results, visible=True),
- gr.update(value=compare_scores, visible=True)
- )

  def show_flag_choices():
- return gr.update(visible=True)
-
- def update_flag(flag_value):
- return (flag_value,
- gr.update(visible=True),
- gr.update(visible=True),
- gr.update(visible=False)
- )
-
  def upload_flag(*args):
- if flagging_callback.flag(list(args), flag_option = None):
- return gr.update(visible=True)

  def forward_model_choice_multi(model_choice_path):
- CHOICES.append(model_choice_path)
- return gr.update(choices = CHOICES)
-
- def process_user_input_multi(models,
- input,
- token,
- length,
- temperature,
- top_p,
- top_k):
- warning = 'Please enter a valid prompt.'
- if input == None:
- generated = warning
- else:
- generated_dict= {model:generate(model_name=model,
- token=token,
- custom_model_path=None,
- input_sentence=input,
- length=length,
- temperature=temperature,
- top_p=top_p,
- top_k=top_k) for model in sorted(models)}
- generated_with_spans_dict = {model:auto_complete(input, generated) for model,generated in generated_dict.items()}
-
- update_outputs = [gr.HighlightedText.update(value=output, label=model) for model,output in generated_with_spans_dict.items()]
- update_hide = [gr.HighlightedText.update(visible=False) for i in range(10-len(models))]
- return update_outputs + update_hide

  def show_choices_multi(models):
- update_show = [gr.HighlightedText.update(visible=True) for model in sorted(models)]
- update_hide = [gr.HighlightedText.update(visible=False,value=None, label=None) for i in range(10-len(models))]

- return update_show + update_hide

  def show_params(checkbox):
- if checkbox == True:
- return gr.update(visible=True)
- else:
- return gr.update(visible=False)

  CSS = """
  #inside_group {
@@ -340,366 +377,468 @@ CSS = """
  """

  with gr.Blocks(css=CSS) as demo:

- dataset = gr.Variable(value=DATASET)
- prompts_var = gr.Variable(value=None)
- input_var = gr.Variable(label="Input Prompt", value=None)
- output_var = gr.Variable(label="Output",value=None)
- model_choice = gr.Variable(label="Model", value=None)
- custom_model_path = gr.Variable(value=None)
- flag_choice = gr.Variable(label = "Flag", value=None)
-
- flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
- dataset_name = "fsdlredteam/flagged_2",
- organization = "fsdlredteam",
- private = True )
-
- gr.Markdown("<p align='center'><img src='https://i.imgur.com/ZxbbLUQ.png>'/></p>")
- gr.Markdown("<h1 align='center'>BuggingSpace</h1>")
- gr.Markdown("<h2 align='center'>FSDL 2022 Red-Teaming Open-Source Models Project</h2>")
- gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
- gr.Markdown("### Or compare the output of multiple models at the same time")
-
-
- choose_mode = gr.Radio(choices=['Single Model', "Multi-Model"],
- value='Single Model',
- interactive=True,
- visible=True,
- show_label=False)
-
- with gr.Group() as single_model:
-
- gr.Markdown("You can upload any model from the Hugging Face hub -even private ones, \
  provided you use your private key! "
- "Write your prompt or alternatively use one from the \
- [RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset.")
- gr.Markdown("Use it to audit the model for potential failure modes, \
- analyse its output with the Detoxify suite and contribute by reporting any problematic result.")
- gr.Markdown("Beware ! Generation can take up to a few minutes with very large models.")
-
  with gr.Row():

- with gr.Column(scale=1): # input & prompts dataset exploration
- gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
-
- input_text = gr.Textbox(label="Write your prompt below.",
- interactive=True,
- lines=4,
- elem_id="inside_group")
-
- gr.Markdown("— or —", elem_id="inside_group")
-
- inspo_button = gr.Button('Click here if you need some inspiration', elem_id="inside_group")
-
- prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
-
- randomize_button = gr.Button('Show another subset', visible=False, elem_id="inside_group")
-
- show_params_checkbox_single = gr.Checkbox(label='Set custom params',
- interactive=True,
- value=False)
-
- with gr.Box(visible=False) as params_box_single:
-
- length_single = gr.Slider(label='Output length',
- visible=True,
- interactive=True,
- minimum=50,
- maximum=200,
- value=75)
-
- top_k_single = gr.Slider(label='top_k',
- visible=True,
- interactive=True,
- minimum=1,
- maximum=100,
- value=50)
-
- top_p_single = gr.Slider(label='top_p',
- visible=True,
- interactive=True,
- minimum=0.1,
- maximum=1,
- value=0.95)
-
- temperature_single = gr.Slider(label='temperature',
- visible=True,
- interactive=True,
- minimum=0.1,
- maximum=1,
- value=0.7)
-
-
- with gr.Column(scale=1): # Model choice & output
- gr.Markdown("### 2. Evaluate output")
-
- model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
- label='Model',
- interactive=True,
- elem_id="inside_group")
-
- search_bar = gr.Textbox(label="Search model",
- interactive=True,
- visible=False,
- elem_id="inside_group")
- model_drop = gr.Dropdown(visible=False)
-
- private_checkbox = gr.Checkbox(visible=True,label="Private Model ?", elem_id="inside_group")
-
- api_key_textbox = gr.Textbox(label="Enter your AUTH TOKEN below",
- value=None,
- interactive=True,
- visible=False,
- elem_id="pw")
-
- generate_button = gr.Button('Submit your prompt', elem_id="inside_group")
-
- output_spans = gr.HighlightedText(visible=True, label="Generated text")
-
- flag_button = gr.Button("Report output here", visible=False, elem_id="inside_group")
-
- with gr.Row(): # Flagging
-
- with gr.Column(scale=1):
- flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other",],
- label="What's wrong with the output ?",
- interactive=True,
- visible=False,
- elem_id="inside_group")
-
- user_comment = gr.Textbox(label="(Optional) Briefly describe the issue",
- visible=False,
- interactive=True,
- elem_id="inside_group")
-
- confirm_flag_button = gr.Button("Confirm report", visible=False, elem_id="inside_group")
-
- with gr.Row(): # Flagging success
- success_message = gr.Markdown("Your report has been successfully registered. Thank you!",
- visible=False,
- elem_id="inside_group")
-
- with gr.Row(): # Toxicity buttons
- toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False, elem_id="inside_group")
- toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False, elem_id="inside_group")
-
- with gr.Row(): # Toxicity scores
- toxi_scores_input = gr.JSON(label = "Detoxify classification of your input",
- visible=False,
- elem_id="inside_group")
- toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output",
- visible=False,
- elem_id="inside_group")
- toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output",
- visible=False,
- elem_id="inside_group")
-
- with gr.Group(visible=False) as multi_model:
- model_list = list()
-
- gr.Markdown("#### Run the same input on multiple models and compare the outputs")
- gr.Markdown("You can upload any model from the Hugging Face hub -even private ones, provided you use your private key!")
- gr.Markdown("Use this feature to compare the same model at different checkpoints")
- gr.Markdown('Or to benchmark your model against another one as a reference.')
- gr.Markdown("Beware ! Generation can take up to a few minutes with very large models.")
-
- with gr.Row(elem_id="inside_group"):
- with gr.Column():
- models_multi = gr.CheckboxGroup(choices=CHOICES,
- label='Models',
- interactive=True,
- elem_id="inside_group",
- value=None)
- with gr.Column():
- generate_button_multi = gr.Button('Submit your prompt',elem_id="inside_group")
-
- show_params_checkbox_multi = gr.Checkbox(label='Set custom params',
- interactive=True,
- value=False)
-
- with gr.Box(visible=False) as params_box_multi:
-
- length_multi = gr.Slider(label='Output length',
- visible=True,
- interactive=True,
- minimum=50,
- maximum=200,
- value=75)
-
- top_k_multi = gr.Slider(label='top_k',
- visible=True,
- interactive=True,
- minimum=1,
- maximum=100,
- value=50)
-
- top_p_multi = gr.Slider(label='top_p',
- visible=True,
- interactive=True,
- minimum=0.1,
- maximum=1,
- value=0.95)
-
- temperature_multi = gr.Slider(label='temperature',
- visible=True,
- interactive=True,
- minimum=0.1,
- maximum=1,
- value=0.7)
-
- with gr.Row(elem_id="inside_group"):
-
- with gr.Column(elem_id="inside_group", scale=1):
- input_text_multi = gr.Textbox(label="Write your prompt below.",
- interactive=True,
- lines=4,
- elem_id="inside_group")
-
- with gr.Column(elem_id="inside_group", scale=1):
- search_bar_multi = gr.Textbox(label="Search another model",
- interactive=True,
- visible=True,
- elem_id="inside_group")
-
- model_drop_multi = gr.Dropdown(visible=False,
- show_progress=True,
- elem_id="inside_group")
-
- private_checkbox_multi = gr.Checkbox(visible=True,label="Private Model ?")
-
- api_key_textbox_multi = gr.Textbox(label="Enter your AUTH TOKEN below",
- value=None,
- interactive=True,
- visible=False,
- elem_id="pw")
-
- with gr.Row() as outputs_row:
- for i in range(10):
- output_spans_multi = gr.HighlightedText(visible=False, elem_id="inside_group")
- model_list.append(output_spans_multi)
-
-
- with gr.Row():
- gr.Markdown('App made during the [FSDL course](https://fullstackdeeplearning.com) \
- by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa')
-
- # Single Model
-
- choose_mode.change(fn=show_mode,
- inputs=choose_mode,
- outputs=[single_model, multi_model])
-
- inspo_button.click(fn=show_dataset,
- inputs=dataset,
- outputs=[prompts_drop, randomize_button, prompts_var])
-
- prompts_drop.change(fn=pass_to_textbox,
- inputs=prompts_drop,
- outputs=input_text)
-
- randomize_button.click(fn=update_dropdown,
- inputs=prompts_var,
- outputs=prompts_drop),
-
- model_radio.change(fn=show_search_bar,
- inputs=model_radio,
- outputs=[model_choice,search_bar])
-
- search_bar.submit(fn=search_model,
- inputs=[search_bar,api_key_textbox],
- outputs=model_drop,
- show_progress=True)
-
- private_checkbox.change(fn=show_api_key_textbox,
- inputs=private_checkbox,
- outputs=api_key_textbox)
-
- model_drop.change(fn=forward_model_choice,
- inputs=model_drop,
- outputs=[model_choice,custom_model_path])
-
- generate_button.click(fn=process_user_input,
- inputs=[model_choice,
- api_key_textbox,
- custom_model_path,
- input_text,
- length_single,
- temperature_single,
- top_p_single,
- top_k_single],
- outputs=[output_spans,
- toxi_button,
- flag_button,
- input_var,
- output_var],
- show_progress=True)
-
- toxi_button.click(fn=compute_toxi_output,
- inputs=output_var,
- outputs=[toxi_scores_output, toxi_button_compare],
- show_progress=True)
-
- toxi_button_compare.click(fn=compare_toxi_scores,
- inputs=[input_text, toxi_scores_output],
- outputs=[toxi_scores_input, toxi_scores_compare],
- show_progress=True)
-
- flag_button.click(fn=show_flag_choices,
- inputs=None,
- outputs=flag_radio)
-
- flag_radio.change(fn=update_flag,
- inputs=flag_radio,
- outputs=[flag_choice, confirm_flag_button, user_comment, flag_button])
-
- flagging_callback.setup([input_var, output_var, model_choice, user_comment, flag_choice], "flagged_data_points")
-
- confirm_flag_button.click(fn = upload_flag,
- inputs = [input_var,
- output_var,
- model_choice,
- user_comment,
- flag_choice],
- outputs=success_message)
-
- show_params_checkbox_single.change(fn=show_params,
- inputs=show_params_checkbox_single,
- outputs=params_box_single)
-
- # Model comparison
-
- search_bar_multi.submit(fn=search_model,
- inputs=[search_bar_multi, api_key_textbox_multi],
- outputs=model_drop_multi,
- show_progress=True)
-
- show_params_checkbox_multi.change(fn=show_params,
- inputs=show_params_checkbox_multi,
- outputs=params_box_multi)
-
- private_checkbox_multi.change(fn=show_api_key_textbox,
- inputs=private_checkbox_multi,
- outputs=api_key_textbox_multi)
-
- model_drop_multi.change(fn=forward_model_choice_multi,
- inputs=model_drop_multi,
- outputs=[models_multi])
-
- models_multi.change(fn=show_choices_multi,
- inputs=models_multi,
- outputs=model_list)
-
- generate_button_multi.click(fn=process_user_input_multi,
- inputs=[models_multi,
- input_text_multi,
- api_key_textbox_multi,
- length_multi,
- temperature_multi,
- top_p_multi,
- top_k_multi],
- outputs=model_list,
- show_progress=True)
-
- #demo.launch(debug=True)
  if __name__ == "__main__":
- demo.launch(enable_queue=False, debug=True)
+ import os
+ import torch

+ import numpy as np
  import gradio as gr

+ from random import sample
  from detoxify import Detoxify
  from datasets import load_dataset
  from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
  from transformers import BloomTokenizerFast, BloomForCausalLM

+ HF_AUTH_TOKEN = os.environ.get("hf_token" or True)

  DATASET = "allenai/real-toxicity-prompts"

  CHECKPOINTS = {
+ "DistilGPT2 by HuggingFace 🤗": "distilgpt2",
+ "GPT-Neo 125M by EleutherAI 🤖": "EleutherAI/gpt-neo-125M",
+ "BLOOM 560M by BigScience 🌸": "bigscience/bloom-560m",
+ "Custom Model": None,
+ }

  MODEL_CLASSES = {
+ "DistilGPT2 by HuggingFace 🤗": (GPT2LMHeadModel, GPT2Tokenizer),
+ "GPT-Neo 125M by EleutherAI 🤖": (GPTNeoForCausalLM, GPT2Tokenizer),
+ "BLOOM 560M by BigScience 🌸": (BloomForCausalLM, BloomTokenizerFast),
+ "Custom Model": (AutoModelForCausalLM, AutoTokenizer),
+ }

  CHOICES = sorted(list(CHECKPOINTS.keys())[:3])

+
+ def load_model(model_name, custom_model_path, token):
  try:
+ model_class, tokenizer_class = MODEL_CLASSES[model_name]
+ model_path = CHECKPOINTS[model_name]
+
  except KeyError:
+ model_class, tokenizer_class = MODEL_CLASSES["Custom Model"]
+ model_path = custom_model_path or model_name
+
  model = model_class.from_pretrained(model_path, use_auth_token=token)
  tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)


  return model, tokenizer

+
  MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop

+
  def set_seed(seed, n_gpu):
  np.random.seed(seed)
  torch.manual_seed(seed)
  if n_gpu > 0:
  torch.cuda.manual_seed_all(seed)

+
  def adjust_length_to_model(length, max_sequence_length):
  if length < 0 and max_sequence_length > 0:
  length = max_sequence_length
  length = MAX_LENGTH # avoid infinite loop
  return length


+ def generate(
+ model_name,
+ token,
+ custom_model_path,
+ input_sentence,
+ length=75,
+ temperature=0.7,
+ top_k=50,
+ top_p=0.95,
+ seed=42,
+ no_cuda=False,
+ num_return_sequences=1,
+ stop_token=".",
+ ):
  # load device
+ # if not no_cuda:
+ device = torch.device(
+ "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
+ )
  n_gpu = 0 if no_cuda else torch.cuda.device_count()

  # Set seed
  model, tokenizer = load_model(model_name, custom_model_path, token)
  model.to(device)

+ # length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)

  # Tokenize input
+ encoded_prompt = tokenizer.encode(
+ input_sentence, add_special_tokens=False, return_tensors="pt"
+ )

  encoded_prompt = encoded_prompt.to(device)

+ input_ids = encoded_prompt
+
+ # Generate output
+ output_sequences = model.generate(
+ input_ids=input_ids,
+ max_length=length + len(encoded_prompt[0]),
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ do_sample=True,
+ num_return_sequences=num_return_sequences,
+ )
  generated_sequences = list()

  for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
  generated_sequence = generated_sequence.tolist()
  text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+ # remove prompt
+ text = text[
+ len(
+ tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
+ ) :
+ ]
+
+ # remove all text after last occurence of stop_token
+ text = text[: text.rfind(stop_token) + 1]

  generated_sequences.append(text)



  def show_mode(mode):
+ if mode == "Single Model":
+ return (gr.update(visible=True), gr.update(visible=False))
+ if mode == "Multi-Model":
+ return (gr.update(visible=False), gr.update(visible=True))
+

  def prepare_dataset(dataset):
+ dataset = load_dataset(dataset, split="train")
+ return dataset
+

  def load_prompts(dataset):
+ prompts = [dataset[i]["prompt"]["text"] for i in range(len(dataset))]
+ return prompts
+

  def random_sample(prompt_list):
+ random_sample = sample(prompt_list, 10)
+ return random_sample
+

  def show_dataset(dataset):
+ raw_data = prepare_dataset(dataset)
+ prompts = load_prompts(raw_data)
+
+ return (
+ gr.update(
+ choices=random_sample(prompts),
+ label="You can find below a random subset from the RealToxicityPrompts dataset",
+ visible=True,
+ ),
+ gr.update(visible=True),
+ prompts,
+ )
+
+
  def update_dropdown(prompts):
+ return gr.update(choices=random_sample(prompts))
+

  def show_search_bar(value):
+ if value == "Custom Model":
+ return (value, gr.update(visible=True))
+ else:
+ return (value, gr.update(visible=False))
+

  def search_model(model_name, token):
+ api = HfApi()

+ model_args = ModelSearchArguments()
+ filt = ModelFilter(
+ task=model_args.pipeline_tag.TextGeneration, library=model_args.library.PyTorch
+ )

+ results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
+ model_list = [model.modelId for model in results]
+
+ return gr.update(
+ visible=True,
+ choices=model_list,
+ label="Choose the model",
+ )


  def show_api_key_textbox(checkbox):
+ if checkbox:
+ return gr.update(visible=True)
+ else:
+ return gr.update(visible=False)
+

  def forward_model_choice(model_choice_path):
+ return (model_choice_path, model_choice_path)
+

  def auto_complete(input, generated):
+ output = input + " " + generated
+ output_spans = [{"entity": "OUTPUT", "start": len(input), "end": len(output)}]
+ completed_prompt = {"text": output, "entities": output_spans}
+ return completed_prompt
+
+
+ def process_user_input(
+ model, token, custom_model_path, input, length, temperature, top_p, top_k
+ ):
+ warning = "Please enter a valid prompt."
+ if input == None:
+ generated = warning
+ else:
+ generated = generate(
+ model_name=model,
+ token=token,
+ custom_model_path=custom_model_path,
+ input_sentence=input,
+ length=length,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ )
+ generated = generated.replace("\n", " ")
+ generated_with_spans = auto_complete(input=input, generated=generated)
+
+ return (
+ gr.update(value=generated_with_spans),
+ gr.update(visible=True),
+ gr.update(visible=True),
+ input,
+ generated,
+ )
+

  def pass_to_textbox(input):
+ return gr.update(value=input)
+

  def run_detoxify(text):
+ results = Detoxify("original").predict(text)
+ json_ready_results = {cat: float(score) for (cat, score) in results.items()}
+ return json_ready_results
+

  def compute_toxi_output(output_text):
+ scores = run_detoxify(output_text)
+ return (gr.update(value=scores, visible=True), gr.update(visible=True))
+

  def compute_change(input, output):
+ change_percent = round(((float(output) - input) / input) * 100, 2)
+ return change_percent
+

  def compare_toxi_scores(input_text, output_scores):
+ input_scores = run_detoxify(input_text)
+ json_ready_results = {cat: float(score) for (cat, score) in input_scores.items()}

+ compare_scores = {
+ cat: compute_change(json_ready_results[cat], output_scores[cat])
+ for cat in json_ready_results
+ for cat in output_scores
+ }
+
+ return (
+ gr.update(value=json_ready_results, visible=True),
+ gr.update(value=compare_scores, visible=True),
+ )


  def show_flag_choices():
+ return gr.update(visible=True)
+
+
+ def update_flag(flag_value):
+ return (
+ flag_value,
+ gr.update(visible=True),
+ gr.update(visible=True),
+ gr.update(visible=False),
+ )
+
+
  def upload_flag(*args):
+ flags = list(args)
+ flags[1] = bytes(flags[1], "utf-8")
+ flagging_callback.flag(flags)
+ return gr.update(visible=True)
+

  def forward_model_choice_multi(model_choice_path):
+ CHOICES.append(model_choice_path)
+ return gr.update(choices=CHOICES)
+
+
+ def process_user_input_multi(models, input, token, length, temperature, top_p, top_k):
+ warning = "Please enter a valid prompt."
+ if input == None:
+ generated = warning
+ else:
+ generated_dict = {
+ model: generate(
+ model_name=model,
+ token=token,
+ custom_model_path=None,
+ input_sentence=input,
+ length=length,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ )
+ for model in sorted(models)
+ }
+ generated_with_spans_dict = {
+ model: auto_complete(input, generated)
+ for model, generated in generated_dict.items()
+ }
+
+ update_outputs = [
+ gr.HighlightedText.update(value=output, label=model)
+ for model, output in generated_with_spans_dict.items()
+ ]
+ update_hide = [
+ gr.HighlightedText.update(visible=False) for i in range(10 - len(models))
+ ]
+ return update_outputs + update_hide
+

  def show_choices_multi(models):
+ update_show = [gr.HighlightedText.update(visible=True) for model in sorted(models)]
+ update_hide = [
+ gr.HighlightedText.update(visible=False, value=None, label=None)
+ for i in range(10 - len(models))
+ ]
+
+ return update_show + update_hide


  def show_params(checkbox):
+ if checkbox == True:
+ return gr.update(visible=True)
+ else:
+ return gr.update(visible=False)
+

  CSS = """
  #inside_group {
  """

  with gr.Blocks(css=CSS) as demo:
+ dataset = gr.Variable(value=DATASET)
+ prompts_var = gr.Variable(value=None)
+ input_var = gr.Variable(label="Input Prompt", value=None)
+ output_var = gr.Variable(label="Output", value=None)
+ model_choice = gr.Variable(label="Model", value=None)
+ custom_model_path = gr.Variable(value=None)
+ flag_choice = gr.Variable(label="Flag", value=None)
+
+ flagging_callback = gr.HuggingFaceDatasetSaver(
+ hf_token=HF_AUTH_TOKEN,
+ dataset_name="fsdlredteam/flagged_3",
+ private=True,
+ )

+ gr.Markdown("<p align='center'><img src='https://i.imgur.com/ZxbbLUQ.png>'/></p>")
+ gr.Markdown("<h1 align='center'>BuggingSpace</h1>")
+ gr.Markdown(
+ "<h2 align='center'>FSDL 2022 Red-Teaming Open-Source Models Project</h2>"
+ )
+ gr.Markdown(
+ "### Pick a text generation model below, write a prompt and explore the output"
+ )
+ gr.Markdown("### Or compare the output of multiple models at the same time")
+
+ choose_mode = gr.Radio(
+ choices=["Single Model", "Multi-Model"],
+ value="Single Model",
+ interactive=True,
+ visible=True,
+ show_label=False,
+ )
+
+ with gr.Group() as single_model:
+ gr.Markdown(
+ "You can upload any model from the Hugging Face hub -even private ones, \
  provided you use your private key! "
+ "Write your prompt or alternatively use one from the \
+ [RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset."
+ )
+ gr.Markdown(
+ "Use it to audit the model for potential failure modes, \
+ analyse its output with the Detoxify suite and contribute by reporting any problematic result."
+ )
+ gr.Markdown(
+ "Beware ! Generation can take up to a few minutes with very large models."
+ )
+
+ with gr.Row():
+ with gr.Column(scale=1): # input & prompts dataset exploration
+ gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
+
+ input_text = gr.Textbox(
+ label="Write your prompt below.",
+ interactive=True,
+ lines=4,
+ elem_id="inside_group",
+ )
+
+ gr.Markdown("— or —", elem_id="inside_group")
+
+ inspo_button = gr.Button(
+ "Click here if you need some inspiration", elem_id="inside_group"
+ )
+
+ prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
+
+ randomize_button = gr.Button(
+ "Show another subset", visible=False, elem_id="inside_group"
+ )
+
+ show_params_checkbox_single = gr.Checkbox(
+ label="Set custom params", interactive=True, value=False
+ )
+
+ with gr.Box(visible=False) as params_box_single:
+ length_single = gr.Slider(
+ label="Output length",
+ visible=True,
+ interactive=True,
+ minimum=50,
+ maximum=200,
+ value=75,
+ )
+
+ top_k_single = gr.Slider(
+ label="top_k",
+ visible=True,
+ interactive=True,
+ minimum=1,
+ maximum=100,
+ value=50,
+ )
+
+ top_p_single = gr.Slider(
+ label="top_p",
+ visible=True,
+ interactive=True,
+ minimum=0.1,
+ maximum=1,
+ value=0.95,
+ )
+
+ temperature_single = gr.Slider(
+ label="temperature",
+ visible=True,
+ interactive=True,
+ minimum=0.1,
+ maximum=1,
+ value=0.7,
+ )
+
+ with gr.Column(scale=1): # Model choice & output
+ gr.Markdown("### 2. Evaluate output")
+
+ model_radio = gr.Radio(
+ choices=list(CHECKPOINTS.keys()),
+ label="Model",
+ interactive=True,
+ elem_id="inside_group",
+ )
+
+ search_bar = gr.Textbox(
+ label="Search model",
+ interactive=True,
+ visible=False,
+ elem_id="inside_group",
+ )
+ model_drop = gr.Dropdown(visible=False)
+
+ private_checkbox = gr.Checkbox(
+ visible=True, label="Private Model ?", elem_id="inside_group"
+ )
+
+ api_key_textbox = gr.Textbox(
+ label="Enter your AUTH TOKEN below",
+ value=None,
+ interactive=True,
+ visible=False,
+ elem_id="pw",
+ )
+
+ generate_button = gr.Button(
+ "Submit your prompt", elem_id="inside_group"
+ )
+
+ output_spans = gr.HighlightedText(visible=True, label="Generated text")
+
+ flag_button = gr.Button(
+ "Report output here", visible=False, elem_id="inside_group"
+ )
+
+ with gr.Row(): # Flagging
+ with gr.Column(scale=1):
+ flag_radio = gr.Radio(
+ choices=[
+ "Toxic",
+ "Offensive",
+ "Repetitive",
+ "Incorrect",
+ "Other",
+ ],
+ label="What's wrong with the output ?",
+ interactive=True,
+ visible=False,
+ elem_id="inside_group",
+ )
+
+ user_comment = gr.Textbox(
+ label="(Optional) Briefly describe the issue",
+ visible=False,
+ interactive=True,
+ elem_id="inside_group",
+ )
+
+ confirm_flag_button = gr.Button(
+ "Confirm report", visible=False, elem_id="inside_group"
+ )
+
+ with gr.Row(): # Flagging success
+ success_message = gr.Markdown(
+ "Your report has been successfully registered. Thank you!",
+ visible=False,
+ elem_id="inside_group",
+ )
+
+ with gr.Row(): # Toxicity buttons
+ toxi_button = gr.Button(
+ "Run a toxicity analysis of the model's output",
+ visible=False,
+ elem_id="inside_group",
+ )
+ toxi_button_compare = gr.Button(
+ "Compare toxicity on input and output",
+ visible=False,
+ elem_id="inside_group",
+ )
+
+ with gr.Row(): # Toxicity scores
+ toxi_scores_input = gr.JSON(
+ label="Detoxify classification of your input",
+ visible=False,
+ elem_id="inside_group",
+ )
+ toxi_scores_output = gr.JSON(
+ label="Detoxify classification of the model's output",
+ visible=False,
+ elem_id="inside_group",
+ )
+ toxi_scores_compare = gr.JSON(
+ label="Percentage change between Input and Output",
+ visible=False,
+ elem_id="inside_group",
+ )
+
+ with gr.Group(visible=False) as multi_model:
+ model_list = list()
+
+ gr.Markdown(
+ "#### Run the same input on multiple models and compare the outputs"
+ )
+ gr.Markdown(
+ "You can upload any model from the Hugging Face hub -even private ones, provided you use your private key!"
+ )
+ gr.Markdown(
+ "Use this feature to compare the same model at different checkpoints"
+ )
+ gr.Markdown("Or to benchmark your model against another one as a reference.")
+ gr.Markdown(
+ "Beware ! Generation can take up to a few minutes with very large models."
+ )
+
+ with gr.Row(elem_id="inside_group"):
+ with gr.Column():
+ models_multi = gr.CheckboxGroup(
+ choices=CHOICES,
+ label="Models",
+ interactive=True,
+ elem_id="inside_group",
+ value=None,
+ )
+ with gr.Column():
+ generate_button_multi = gr.Button(
+ "Submit your prompt", elem_id="inside_group"
+ )
+
+ show_params_checkbox_multi = gr.Checkbox(
+ label="Set custom params", interactive=True, value=False
+ )
+
+ with gr.Box(visible=False) as params_box_multi:
+ length_multi = gr.Slider(
+ label="Output length",
+ visible=True,
+ interactive=True,
+ minimum=50,
+ maximum=200,
+ value=75,
+ )
+
+ top_k_multi = gr.Slider(
+ label="top_k",
+ visible=True,
+ interactive=True,
+ minimum=1,
+ maximum=100,
+ value=50,
+ )
+
+ top_p_multi = gr.Slider(
+ label="top_p",
+ visible=True,
+ interactive=True,
+ minimum=0.1,
+ maximum=1,
+ value=0.95,
+ )
+
+ temperature_multi = gr.Slider(
+ label="temperature",
+ visible=True,
+ interactive=True,
+ minimum=0.1,
+ maximum=1,
+ value=0.7,
+ )
+
+ with gr.Row(elem_id="inside_group"):
+ with gr.Column(elem_id="inside_group", scale=1):
+ input_text_multi = gr.Textbox(
+ label="Write your prompt below.",
+ interactive=True,
+ lines=4,
+ elem_id="inside_group",
+ )
+
+ with gr.Column(elem_id="inside_group", scale=1):
+ search_bar_multi = gr.Textbox(
+ label="Search another model",
+ interactive=True,
+ visible=True,
+ elem_id="inside_group",
+ )
+
+ model_drop_multi = gr.Dropdown(visible=False, elem_id="inside_group")
+
+ private_checkbox_multi = gr.Checkbox(
+ visible=True, label="Private Model ?"
+ )
+
+ api_key_textbox_multi = gr.Textbox(
+ label="Enter your AUTH TOKEN below",
+ value=None,
+ interactive=True,
+ visible=False,
+ elem_id="pw",
+ )
+
+ with gr.Row() as outputs_row:
+ for i in range(10):
+ output_spans_multi = gr.HighlightedText(
+ visible=False, elem_id="inside_group"
+ )
+ model_list.append(output_spans_multi)
+
  with gr.Row():
+ gr.Markdown(
+ "App made during the [FSDL course](https://fullstackdeeplearning.com) \
+ by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa"
+ )
+
+ # Single Model
+
+ choose_mode.change(
+ fn=show_mode, inputs=choose_mode, outputs=[single_model, multi_model]
+ )
+
+ inspo_button.click(
+ fn=show_dataset,
+ inputs=dataset,
+ outputs=[prompts_drop, randomize_button, prompts_var],
+ )
+
+ prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)
+
+ randomize_button.click(
+ fn=update_dropdown, inputs=prompts_var, outputs=prompts_drop
+ ),
+
+ model_radio.change(
+ fn=show_search_bar, inputs=model_radio, outputs=[model_choice, search_bar]
+ )
+
+ search_bar.submit(
+ fn=search_model,
+ inputs=[search_bar, api_key_textbox],
+ outputs=model_drop,
+ show_progress=True,
+ )
+
+ private_checkbox.change(
+ fn=show_api_key_textbox, inputs=private_checkbox, outputs=api_key_textbox
+ )
+
+ model_drop.change(
+ fn=forward_model_choice,
+ inputs=model_drop,
+ outputs=[model_choice, custom_model_path],
+ )
+
+ generate_button.click(
+ fn=process_user_input,
+ inputs=[
+ model_choice,
+ api_key_textbox,
+ custom_model_path,
+ input_text,
+ length_single,
+ temperature_single,
+ top_p_single,
+ top_k_single,
+ ],
+ outputs=[output_spans, toxi_button, flag_button, input_var, output_var],
+ show_progress=True,
+ )
+
+ toxi_button.click(
+ fn=compute_toxi_output,
+ inputs=output_var,
+ outputs=[toxi_scores_output, toxi_button_compare],
+ show_progress=True,
+ )
+
+ toxi_button_compare.click(
+ fn=compare_toxi_scores,
+ inputs=[input_text, toxi_scores_output],
+ outputs=[toxi_scores_input, toxi_scores_compare],
+ show_progress=True,
+ )
+
+ flag_button.click(fn=show_flag_choices, inputs=None, outputs=flag_radio)
+
+ flag_radio.change(
+ fn=update_flag,
+ inputs=flag_radio,
+ outputs=[flag_choice, confirm_flag_button, user_comment, flag_button],
+ )
+
+ flagging_callback.setup(
+ [input_var, output_var, model_choice, user_comment, flag_choice],
+ "flagged_data_points",
+ )
+
+ confirm_flag_button.click(
+ fn=upload_flag,
+ inputs=[input_var, output_var, model_choice, user_comment, flag_choice],
+ outputs=success_message,
+ )
+
+ show_params_checkbox_single.change(
+ fn=show_params, inputs=show_params_checkbox_single, outputs=params_box_single
+ )
+
+ # Model comparison
+
+ search_bar_multi.submit(
+ fn=search_model,
+ inputs=[search_bar_multi, api_key_textbox_multi],
+ outputs=model_drop_multi,
+ show_progress=True,
+ )
+
+ show_params_checkbox_multi.change(
+ fn=show_params, inputs=show_params_checkbox_multi, outputs=params_box_multi
+ )
+
+ private_checkbox_multi.change(
+ fn=show_api_key_textbox,
+ inputs=private_checkbox_multi,
+ outputs=api_key_textbox_multi,
+ )
+
+ model_drop_multi.change(
+ fn=forward_model_choice_multi, inputs=model_drop_multi, outputs=[models_multi]
+ )
+
+ models_multi.change(fn=show_choices_multi, inputs=models_multi, outputs=model_list)
+
+ generate_button_multi.click(
+ fn=process_user_input_multi,
+ inputs=[
+ models_multi,
+ input_text_multi,
+ api_key_textbox_multi,
+ length_multi,
+ temperature_multi,
+ top_p_multi,
+ top_k_multi,
+ ],
+ outputs=model_list,
+ show_progress=True,
+ )

  if __name__ == "__main__":
+ # demo.queue(concurrency_count=3)
+ demo.launch(debug=True)