sylviachency JAZ committed on
Commit
01dff0a
0 Parent(s):

Duplicate from fsdlredteam/BuggingSpace


Co-authored-by: Jean-Antoine Z. <JAZ@users.noreply.huggingface.co>

Files changed (4)
  1. .gitattributes +31 -0
  2. README.md +14 -0
  3. app.py +705 -0
  4. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: BuggingSpace
+ emoji: 🤔
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 3.3.1
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ duplicated_from: fsdlredteam/BuggingSpace
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,705 @@
+ import os
+ import torch
+
+ import numpy as np
+ import gradio as gr
+
+ from random import sample
+ from detoxify import Detoxify
+ from datasets import load_dataset
+ from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
+ from transformers import BloomTokenizerFast, BloomForCausalLM
+
+ HF_AUTH_TOKEN = os.environ.get('hf_token')  # `os.environ.get('hf_token' or True)` always resolved to the same lookup
+
+ DATASET = "allenai/real-toxicity-prompts"
+
+ CHECKPOINTS = {
+     "DistilGPT2 by HuggingFace 🤗" : "distilgpt2",
+     "GPT-Neo 125M by EleutherAI 🤖" : "EleutherAI/gpt-neo-125M",
+     "BLOOM 560M by BigScience 🌸" : "bigscience/bloom-560m",
+     "Custom Model" : None
+ }
+
+ MODEL_CLASSES = {
+     "DistilGPT2 by HuggingFace 🤗" : (GPT2LMHeadModel, GPT2Tokenizer),
+     "GPT-Neo 125M by EleutherAI 🤖" : (GPTNeoForCausalLM, GPT2Tokenizer),
+     "BLOOM 560M by BigScience 🌸" : (BloomForCausalLM, BloomTokenizerFast),
+     "Custom Model" : (AutoModelForCausalLM, AutoTokenizer),
+ }
+
+ CHOICES = sorted(list(CHECKPOINTS.keys())[:3])  # the three predefined checkpoints, excluding "Custom Model"
+
+ def load_model(model_name, custom_model_path, token):
+     try:
+         model_class, tokenizer_class = MODEL_CLASSES[model_name]
+         model_path = CHECKPOINTS[model_name]
+
+     except KeyError:  # any name outside the predefined checkpoints is treated as a custom hub model
+         model_class, tokenizer_class = MODEL_CLASSES['Custom Model']
+         model_path = custom_model_path or model_name
+
+     model = model_class.from_pretrained(model_path, use_auth_token=token)
+     tokenizer = tokenizer_class.from_pretrained(model_path, use_auth_token=token)
+
+     tokenizer.pad_token = tokenizer.eos_token
+     model.config.pad_token_id = model.config.eos_token_id
+
+     model.eval()
+
+     return model, tokenizer
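For illustration, loading one of the predefined checkpoints with this helper might look roughly like the sketch below; public models need no token, and any unknown name falls through to the "Custom Model" path.

model, tokenizer = load_model("DistilGPT2 by HuggingFace 🤗",
                              custom_model_path=None,
                              token=None)  # illustrative call, not part of the commit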
+
+ MAX_LENGTH = 10000  # Hardcoded max length to avoid infinite loop
+
+ def set_seed(seed, n_gpu):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if n_gpu > 0:
+         torch.cuda.manual_seed_all(seed)
+
+ def adjust_length_to_model(length, max_sequence_length):
+     if length < 0 and max_sequence_length > 0:
+         length = max_sequence_length
+     elif 0 < max_sequence_length < length:
+         length = max_sequence_length  # No generation bigger than model size
+     elif length < 0:
+         length = MAX_LENGTH  # avoid infinite loop
+     return length
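A quick illustration of the clamping behaviour, with hypothetical values:

adjust_length_to_model(-1, 1024)    # -> 1024, a negative length falls back to the model size
adjust_length_to_model(2048, 1024)  # -> 1024, clamped to the model's maximum
adjust_length_to_model(75, 1024)    # -> 75, already within bounds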
+
+ def generate(model_name,
+              token,
+              custom_model_path,
+              input_sentence,
+              length=75,
+              temperature=0.7,
+              top_k=50,
+              top_p=0.95,
+              seed=42,
+              no_cuda=False,
+              num_return_sequences=1,
+              stop_token='.'
+              ):
+
+     # load device
+     #if not no_cuda:
+     device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
+     n_gpu = 0 if no_cuda else torch.cuda.device_count()
+
+     # Set seed
+     set_seed(seed, n_gpu)
+
+     # Load model
+     model, tokenizer = load_model(model_name, custom_model_path, token)
+     model.to(device)
+
+     #length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)
+
+     # Tokenize input
+     encoded_prompt = tokenizer.encode(input_sentence,
+                                       add_special_tokens=False,
+                                       return_tensors='pt')
+
+     encoded_prompt = encoded_prompt.to(device)
+
+     input_ids = encoded_prompt
+
+     # Generate output
+     output_sequences = model.generate(input_ids=input_ids,
+                                       max_length=length + len(encoded_prompt[0]),
+                                       temperature=temperature,
+                                       top_k=top_k,
+                                       top_p=top_p,
+                                       do_sample=True,
+                                       num_return_sequences=num_return_sequences
+                                       )
+     generated_sequences = []
+
+     for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
+         generated_sequence = generated_sequence.tolist()
+         text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+         # remove prompt
+         text = text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
+
+         # remove all text after the last occurrence of stop_token
+         if stop_token in text:  # guard: rfind returns -1 when absent, which would empty the whole output
+             text = text[:text.rfind(stop_token) + 1]
+
+         generated_sequences.append(text)
+
+     return generated_sequences[0]
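A rough usage sketch of generate, assuming the DistilGPT2 checkpoint and a made-up prompt:

completion = generate(model_name="DistilGPT2 by HuggingFace 🤗",
                      token=None,
                      custom_model_path=None,
                      input_sentence="The weather today is")
print(completion)  # the continuation only, truncated at its last '.'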
+
+
+ def show_mode(mode):
+     if mode == 'Single Model':
+         return (
+             gr.update(visible=True),
+             gr.update(visible=False)
+         )
+     if mode == 'Multi-Model':
+         return (
+             gr.update(visible=False),
+             gr.update(visible=True)
+         )
+
+ def prepare_dataset(dataset):
+     dataset = load_dataset(dataset, split='train')
+     return dataset
+
+ def load_prompts(dataset):
+     prompts = [dataset[i]['prompt']['text'] for i in range(len(dataset))]
+     return prompts
+
+ def random_sample(prompt_list):
+     random_sample = sample(prompt_list, 10)
+     return random_sample
+
+ def show_dataset(dataset):
+     raw_data = prepare_dataset(dataset)
+     prompts = load_prompts(raw_data)
+
+     return (gr.update(choices=random_sample(prompts),
+                       label='You can find below a random subset from the RealToxicityPrompts dataset',
+                       visible=True),
+             gr.update(visible=True),
+             prompts,
+             )
+
+ def update_dropdown(prompts):
+     return gr.update(choices=random_sample(prompts))
+
+ def show_search_bar(value):
+     if value == 'Custom Model':
+         return (value,
+                 gr.update(visible=True)
+                 )
+     else:
+         return (value,
+                 gr.update(visible=False)
+                 )
+
+ def search_model(model_name, token):
+     api = HfApi()
+
+     model_args = ModelSearchArguments()
+     filt = ModelFilter(
+         task=model_args.pipeline_tag.TextGeneration,
+         library=model_args.library.PyTorch)
+
+     results = api.list_models(filter=filt, search=model_name, use_auth_token=token)
+     model_list = [model.modelId for model in results]
+
+     return gr.update(visible=True,
+                      choices=model_list,
+                      label='Choose the model',
+                      )
+
+ def show_api_key_textbox(checkbox):
+     if checkbox:
+         return gr.update(visible=True)
+     else:
+         return gr.update(visible=False)
+
+ def forward_model_choice(model_choice_path):
+     return (model_choice_path,
+             model_choice_path)
+
+ def auto_complete(input, generated):
+     output = input + ' ' + generated
+     output_spans = [{'entity': 'OUTPUT', 'start': len(input), 'end': len(output)}]
+     completed_prompt = {"text": output, "entities": output_spans}
+     return completed_prompt
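For example, auto_complete marks everything past the prompt as a single OUTPUT span for gr.HighlightedText:

auto_complete("Hello", "world.")
# -> {'text': 'Hello world.', 'entities': [{'entity': 'OUTPUT', 'start': 5, 'end': 12}]}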
+
+ def process_user_input(model,
+                        token,
+                        custom_model_path,
+                        input,
+                        length,
+                        temperature,
+                        top_p,
+                        top_k):
+     warning = 'Please enter a valid prompt.'
+     if input is None:
+         input, generated = '', warning  # fall back to an empty prompt so auto_complete below still works
+     else:
+         generated = generate(model_name=model,
+                              token=token,
+                              custom_model_path=custom_model_path,
+                              input_sentence=input,
+                              length=length,
+                              temperature=temperature,
+                              top_p=top_p,
+                              top_k=top_k)
+     generated_with_spans = auto_complete(input=input, generated=generated)
+
+     return (
+         gr.update(value=generated_with_spans),
+         gr.update(visible=True),
+         gr.update(visible=True),
+         input,
+         generated
+     )
+
+ def pass_to_textbox(input):
+     return gr.update(value=input)
+
+ def run_detoxify(text):
+     results = Detoxify('original').predict(text)
+     json_ready_results = {cat: float(score) for (cat, score) in results.items()}
+     return json_ready_results
+
+ def compute_toxi_output(output_text):
+     scores = run_detoxify(output_text)
+     return (
+         gr.update(value=scores, visible=True),
+         gr.update(visible=True)
+     )
+
+ def compute_change(input, output):
+     change_percent = round(((float(output) - input) / input) * 100, 2)
+     return change_percent
+
+ def compare_toxi_scores(input_text, output_scores):
+     input_scores = run_detoxify(input_text)
+     json_ready_results = {cat: float(score) for (cat, score) in input_scores.items()}
+
+     compare_scores = {
+         cat: compute_change(json_ready_results[cat], output_scores[cat])
+         for cat in json_ready_results
+         if cat in output_scores  # a second `for cat in output_scores` clause here shadowed the loop variable
+     }
+
+     return (
+         gr.update(value=json_ready_results, visible=True),
+         gr.update(value=compare_scores, visible=True)
+     )
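A worked example of the percentage change, with hypothetical Detoxify scores of 0.02 for the input and 0.05 for the output:

compute_change(0.02, 0.05)  # -> 150.0, i.e. toxicity rose by 150% from input to output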
+
+ def show_flag_choices():
+     return gr.update(visible=True)
+
+ def update_flag(flag_value):
+     return (flag_value,
+             gr.update(visible=True),
+             gr.update(visible=True),
+             gr.update(visible=False)
+             )
+
+ def upload_flag(*args):
+     if flagging_callback.flag(list(args), flag_option=None):
+         return gr.update(visible=True)
+
+ def forward_model_choice_multi(model_choice_path):
+     CHOICES.append(model_choice_path)
+     return gr.update(choices=CHOICES)
+
+ def process_user_input_multi(models,
+                              input,
+                              token,
+                              length,
+                              temperature,
+                              top_p,
+                              top_k):
+     warning = 'Please enter a valid prompt.'
+     if input is None:  # previously left generated_dict undefined on an empty prompt
+         input = ''
+         generated_dict = {model: warning for model in sorted(models)}
+     else:
+         generated_dict = {model: generate(model_name=model,
+                                           token=token,
+                                           custom_model_path=None,
+                                           input_sentence=input,
+                                           length=length,
+                                           temperature=temperature,
+                                           top_p=top_p,
+                                           top_k=top_k) for model in sorted(models)}
+     generated_with_spans_dict = {model: auto_complete(input, generated) for model, generated in generated_dict.items()}
+
+     update_outputs = [gr.HighlightedText.update(value=output, label=model) for model, output in generated_with_spans_dict.items()]
+     update_hide = [gr.HighlightedText.update(visible=False) for i in range(10 - len(models))]
+     return update_outputs + update_hide
+
+ def show_choices_multi(models):
+     update_show = [gr.HighlightedText.update(visible=True) for model in sorted(models)]
+     update_hide = [gr.HighlightedText.update(visible=False, value=None, label=None) for i in range(10 - len(models))]
+
+     return update_show + update_hide
+
+ def show_params(checkbox):
+     if checkbox:
+         return gr.update(visible=True)
+     else:
+         return gr.update(visible=False)
+
+ CSS = """
+ #inside_group {
+     padding-top: 0.6em;
+     padding-bottom: 0.6em;
+ }
+ #pw textarea {
+     -webkit-text-security: disc;
+ }
+ """
+
+ with gr.Blocks(css=CSS) as demo:
+
+     dataset = gr.Variable(value=DATASET)
+     prompts_var = gr.Variable(value=None)
+     input_var = gr.Variable(label="Input Prompt", value=None)
+     output_var = gr.Variable(label="Output", value=None)
+     model_choice = gr.Variable(label="Model", value=None)
+     custom_model_path = gr.Variable(value=None)
+     flag_choice = gr.Variable(label="Flag", value=None)
+
+     flagging_callback = gr.HuggingFaceDatasetSaver(hf_token=HF_AUTH_TOKEN,
+                                                    dataset_name="fsdlredteam/flagged_2",
+                                                    organization="fsdlredteam",
+                                                    private=True)
+
+     gr.Markdown("<p align='center'><img src='https://i.imgur.com/ZxbbLUQ.png'/></p>")  # the closing quote previously sat after a stray '>'
+     gr.Markdown("<h1 align='center'>BuggingSpace</h1>")
+     gr.Markdown("<h2 align='center'>FSDL 2022 Red-Teaming Open-Source Models Project</h2>")
+     gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
+     gr.Markdown("### Or compare the output of multiple models at the same time")
+
+
+     choose_mode = gr.Radio(choices=['Single Model', 'Multi-Model'],
+                            value='Single Model',
+                            interactive=True,
+                            visible=True,
+                            show_label=False)
+
+     with gr.Group() as single_model:
+
+         gr.Markdown("You can upload any model from the Hugging Face hub, even private ones, \
+                     provided you use your access token! "
+                     "Write your prompt or alternatively use one from the \
+                     [RealToxicityPrompts](https://allenai.org/data/real-toxicity-prompts) dataset.")
+         gr.Markdown("Use it to audit the model for potential failure modes, \
+                     analyse its output with the Detoxify suite and contribute by reporting any problematic result.")
+         gr.Markdown("Beware! Generation can take up to a few minutes with very large models.")
+
+         with gr.Row():
+
+             with gr.Column(scale=1):  # input & prompts dataset exploration
+                 gr.Markdown("### 1. Select a prompt", elem_id="inside_group")
+
+                 input_text = gr.Textbox(label="Write your prompt below.",
+                                         interactive=True,
+                                         lines=4,
+                                         elem_id="inside_group")
+
+                 gr.Markdown("— or —", elem_id="inside_group")
+
+                 inspo_button = gr.Button('Click here if you need some inspiration', elem_id="inside_group")
+
+                 prompts_drop = gr.Dropdown(visible=False, elem_id="inside_group")
+
+                 randomize_button = gr.Button('Show another subset', visible=False, elem_id="inside_group")
+
+                 show_params_checkbox_single = gr.Checkbox(label='Set custom params',
+                                                           interactive=True,
+                                                           value=False)
+
+                 with gr.Box(visible=False) as params_box_single:
+
+                     length_single = gr.Slider(label='Output length',
+                                               visible=True,
+                                               interactive=True,
+                                               minimum=50,
+                                               maximum=200,
+                                               value=75)
+
+                     top_k_single = gr.Slider(label='top_k',
+                                              visible=True,
+                                              interactive=True,
+                                              minimum=1,
+                                              maximum=100,
+                                              value=50)
+
+                     top_p_single = gr.Slider(label='top_p',
+                                              visible=True,
+                                              interactive=True,
+                                              minimum=0.1,
+                                              maximum=1,
+                                              value=0.95)
+
+                     temperature_single = gr.Slider(label='temperature',
+                                                    visible=True,
+                                                    interactive=True,
+                                                    minimum=0.1,
+                                                    maximum=1,
+                                                    value=0.7)
+
+
+             with gr.Column(scale=1):  # Model choice & output
+                 gr.Markdown("### 2. Evaluate output")
+
+                 model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
+                                        label='Model',
+                                        interactive=True,
+                                        elem_id="inside_group")
+
+                 search_bar = gr.Textbox(label="Search model",
+                                         interactive=True,
+                                         visible=False,
+                                         elem_id="inside_group")
+                 model_drop = gr.Dropdown(visible=False)
+
+                 private_checkbox = gr.Checkbox(visible=True, label="Private model?", elem_id="inside_group")
+
+                 api_key_textbox = gr.Textbox(label="Enter your AUTH TOKEN below",
+                                              value=None,
+                                              interactive=True,
+                                              visible=False,
+                                              elem_id="pw")
+
+                 generate_button = gr.Button('Submit your prompt', elem_id="inside_group")
+
+                 output_spans = gr.HighlightedText(visible=True, label="Generated text")
+
+                 flag_button = gr.Button("Report output here", visible=False, elem_id="inside_group")
+
+         with gr.Row():  # Flagging
+
+             with gr.Column(scale=1):
+                 flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other"],
+                                       label="What's wrong with the output?",
+                                       interactive=True,
+                                       visible=False,
+                                       elem_id="inside_group")
+
+                 user_comment = gr.Textbox(label="(Optional) Briefly describe the issue",
+                                           visible=False,
+                                           interactive=True,
+                                           elem_id="inside_group")
+
+                 confirm_flag_button = gr.Button("Confirm report", visible=False, elem_id="inside_group")
+
+         with gr.Row():  # Flagging success
+             success_message = gr.Markdown("Your report has been successfully registered. Thank you!",
+                                           visible=False,
+                                           elem_id="inside_group")
+
+         with gr.Row():  # Toxicity buttons
+             toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False, elem_id="inside_group")
+             toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False, elem_id="inside_group")
+
+         with gr.Row():  # Toxicity scores
+             toxi_scores_input = gr.JSON(label="Detoxify classification of your input",
+                                         visible=False,
+                                         elem_id="inside_group")
+             toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output",
+                                          visible=False,
+                                          elem_id="inside_group")
+             toxi_scores_compare = gr.JSON(label="Percentage change between input and output",
+                                           visible=False,
+                                           elem_id="inside_group")
+
+     with gr.Group(visible=False) as multi_model:
+         model_list = []
+
+         gr.Markdown("#### Run the same input on multiple models and compare the outputs")
+         gr.Markdown("You can upload any model from the Hugging Face hub, even private ones, provided you use your access token!")
+         gr.Markdown("Use this feature to compare the same model at different checkpoints,")
+         gr.Markdown("or to benchmark your model against another one as a reference.")
+         gr.Markdown("Beware! Generation can take up to a few minutes with very large models.")
+
+         with gr.Row(elem_id="inside_group"):
+             with gr.Column():
+                 models_multi = gr.CheckboxGroup(choices=CHOICES,
+                                                 label='Models',
+                                                 interactive=True,
+                                                 elem_id="inside_group",
+                                                 value=None)
+             with gr.Column():
+                 generate_button_multi = gr.Button('Submit your prompt', elem_id="inside_group")
+
+         show_params_checkbox_multi = gr.Checkbox(label='Set custom params',
+                                                  interactive=True,
+                                                  value=False)
+
+         with gr.Box(visible=False) as params_box_multi:
+
+             length_multi = gr.Slider(label='Output length',
+                                      visible=True,
+                                      interactive=True,
+                                      minimum=50,
+                                      maximum=200,
+                                      value=75)
+
+             top_k_multi = gr.Slider(label='top_k',
+                                     visible=True,
+                                     interactive=True,
+                                     minimum=1,
+                                     maximum=100,
+                                     value=50)
+
+             top_p_multi = gr.Slider(label='top_p',
+                                     visible=True,
+                                     interactive=True,
+                                     minimum=0.1,
+                                     maximum=1,
+                                     value=0.95)
+
+             temperature_multi = gr.Slider(label='temperature',
+                                           visible=True,
+                                           interactive=True,
+                                           minimum=0.1,
+                                           maximum=1,
+                                           value=0.7)
+
+         with gr.Row(elem_id="inside_group"):
+
+             with gr.Column(elem_id="inside_group", scale=1):
+                 input_text_multi = gr.Textbox(label="Write your prompt below.",
+                                               interactive=True,
+                                               lines=4,
+                                               elem_id="inside_group")
+
+             with gr.Column(elem_id="inside_group", scale=1):
+                 search_bar_multi = gr.Textbox(label="Search another model",
+                                               interactive=True,
+                                               visible=True,
+                                               elem_id="inside_group")
+
+                 model_drop_multi = gr.Dropdown(visible=False,
+                                                elem_id="inside_group")  # `show_progress` removed: it is an event option, not a Dropdown argument
+
+                 private_checkbox_multi = gr.Checkbox(visible=True, label="Private model?")
+
+                 api_key_textbox_multi = gr.Textbox(label="Enter your AUTH TOKEN below",
+                                                    value=None,
+                                                    interactive=True,
+                                                    visible=False,
+                                                    elem_id="pw")
+
+         with gr.Row() as outputs_row:
+             for i in range(10):
+                 output_spans_multi = gr.HighlightedText(visible=False, elem_id="inside_group")
+                 model_list.append(output_spans_multi)
+
+
+     with gr.Row():
+         gr.Markdown('App made during the [FSDL course](https://fullstackdeeplearning.com) \
+                     by Team53: Jean-Antoine, Sajenthan, Sashank, Kemp, Srihari, Astitwa')
+
+     # Single Model
+
+     choose_mode.change(fn=show_mode,
+                        inputs=choose_mode,
+                        outputs=[single_model, multi_model])
+
+     inspo_button.click(fn=show_dataset,
+                        inputs=dataset,
+                        outputs=[prompts_drop, randomize_button, prompts_var])
+
+     prompts_drop.change(fn=pass_to_textbox,
+                         inputs=prompts_drop,
+                         outputs=input_text)
+
+     randomize_button.click(fn=update_dropdown,
+                            inputs=prompts_var,
+                            outputs=prompts_drop)  # dropped a stray trailing comma here
+
+     model_radio.change(fn=show_search_bar,
+                        inputs=model_radio,
+                        outputs=[model_choice, search_bar])
+
+     search_bar.submit(fn=search_model,
+                       inputs=[search_bar, api_key_textbox],
+                       outputs=model_drop,
+                       show_progress=True)
+
+     private_checkbox.change(fn=show_api_key_textbox,
+                             inputs=private_checkbox,
+                             outputs=api_key_textbox)
+
+     model_drop.change(fn=forward_model_choice,
+                       inputs=model_drop,
+                       outputs=[model_choice, custom_model_path])
+
+     generate_button.click(fn=process_user_input,
+                           inputs=[model_choice,
+                                   api_key_textbox,
+                                   custom_model_path,
+                                   input_text,
+                                   length_single,
+                                   temperature_single,
+                                   top_p_single,
+                                   top_k_single],
+                           outputs=[output_spans,
+                                    toxi_button,
+                                    flag_button,
+                                    input_var,
+                                    output_var],
+                           show_progress=True)
+
+     toxi_button.click(fn=compute_toxi_output,
+                       inputs=output_var,
+                       outputs=[toxi_scores_output, toxi_button_compare],
+                       show_progress=True)
+
+     toxi_button_compare.click(fn=compare_toxi_scores,
+                               inputs=[input_text, toxi_scores_output],
+                               outputs=[toxi_scores_input, toxi_scores_compare],
+                               show_progress=True)
+
+     flag_button.click(fn=show_flag_choices,
+                       inputs=None,
+                       outputs=flag_radio)
+
+     flag_radio.change(fn=update_flag,
+                       inputs=flag_radio,
+                       outputs=[flag_choice, confirm_flag_button, user_comment, flag_button])
+
+     flagging_callback.setup([input_var, output_var, model_choice, user_comment, flag_choice], "flagged_data_points")
+
+     confirm_flag_button.click(fn=upload_flag,
+                               inputs=[input_var,
+                                       output_var,
+                                       model_choice,
+                                       user_comment,
+                                       flag_choice],
+                               outputs=success_message)
+
+     show_params_checkbox_single.change(fn=show_params,
+                                        inputs=show_params_checkbox_single,
+                                        outputs=params_box_single)
+
+     # Model comparison
+
+     search_bar_multi.submit(fn=search_model,
+                             inputs=[search_bar_multi, api_key_textbox_multi],
+                             outputs=model_drop_multi,
+                             show_progress=True)
+
+     show_params_checkbox_multi.change(fn=show_params,
+                                       inputs=show_params_checkbox_multi,
+                                       outputs=params_box_multi)
+
+     private_checkbox_multi.change(fn=show_api_key_textbox,
+                                   inputs=private_checkbox_multi,
+                                   outputs=api_key_textbox_multi)
+
+     model_drop_multi.change(fn=forward_model_choice_multi,
+                             inputs=model_drop_multi,
+                             outputs=[models_multi])
+
+     models_multi.change(fn=show_choices_multi,
+                         inputs=models_multi,
+                         outputs=model_list)
+
+     generate_button_multi.click(fn=process_user_input_multi,
+                                 inputs=[models_multi,
+                                         input_text_multi,
+                                         api_key_textbox_multi,
+                                         length_multi,
+                                         temperature_multi,
+                                         top_p_multi,
+                                         top_k_multi],
+                                 outputs=model_list,
+                                 show_progress=True)
+
+ #demo.launch(debug=True)
+ if __name__ == "__main__":
+     demo.launch(enable_queue=False, debug=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch==1.12.1
+ numpy==1.21.6
+ gradio==3.3.1
+ detoxify==0.5.0
+ datasets==2.5.1
+ transformers==4.22.1