J-Antoine ZAGATO committed on
Commit
5962754
•
1 Parent(s): 09ef45c

Multiple UI changes + added modelsearch + added more flagging options and user feedback

Browse files
Files changed (1) hide show
  1. app.py +166 -60
app.py CHANGED
@@ -7,28 +7,38 @@ import gradio as gr
7
  from random import sample
8
  from detoxify import Detoxify
9
  from datasets import load_dataset
 
 
10
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
11
  from transformers import BloomTokenizerFast, BloomForCausalLM
12
 
13
  HF_AUTH_TOKEN = os.environ.get('hf_token' or True)
 
14
  DATASET = "allenai/real-toxicity-prompts"
15
 
16
  CHECKPOINTS = {
17
  "DistilGPT2 by HuggingFace 🤗" : "distilgpt2",
18
  "GPT-Neo 125M by EleutherAI 🤖" : "EleutherAI/gpt-neo-125M",
19
- "BLOOM 560M by BigScience 🌸" : "bigscience/bloom-560m"
 
20
  }
21
 
22
  MODEL_CLASSES = {
23
  "DistilGPT2 by HuggingFace 🤗" : (GPT2LMHeadModel, GPT2Tokenizer),
24
  "GPT-Neo 125M by EleutherAI 🤖" : (GPTNeoForCausalLM, GPT2Tokenizer),
25
  "BLOOM 560M by BigScience 🌸" : (BloomForCausalLM, BloomTokenizerFast),
 
26
  }
27
 
28
- def load_model(model_name):
29
- model_class, tokenizer_class = MODEL_CLASSES[model_name]
30
-
31
- model_path = CHECKPOINTS[model_name]
 
 
 
 
 
32
  model = model_class.from_pretrained(model_path)
33
  tokenizer = tokenizer_class.from_pretrained(model_path)
34
 
@@ -57,6 +67,7 @@ def adjust_length_to_model(length, max_sequence_length):
57
  return length
58
 
59
  def generate(model_name,
 
60
  input_sentence,
61
  length = 75,
62
  temperature = 0.7,
@@ -77,7 +88,7 @@ def generate(model_name,
77
  set_seed(seed, n_gpu)
78
 
79
  # Load model
80
- model, tokenizer = load_model(model_name)
81
  model.to(device)
82
 
83
  #length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)
@@ -116,7 +127,6 @@ def generate(model_name,
116
  return generated_sequences[0]
117
 
118
  def prepare_dataset(dataset):
119
-
120
  dataset = load_dataset(dataset, split='train')
121
  return dataset
122
 
@@ -142,17 +152,52 @@ def show_dataset(dataset):
142
  def update_dropdown(prompts):
143
  return gr.update(choices=random_sample(prompts))
144
 
145
- def process_user_input(model, input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  warning = 'Please enter a valid prompt.'
147
  if input == None:
148
  generated = warning
149
  else:
150
- generated = generate(model, input)
 
151
 
152
  return (
153
- gr.update(visible = True, value=generated),
154
- gr.update(visible=True),
155
- gr.update(visible=True),
156
  gr.update(visible=True),
157
  gr.update(visible=True),
158
  input,
@@ -193,103 +238,164 @@ def compare_toxi_scores(input_text, output_scores):
193
  gr.update(value=compare_scores, visible=True)
194
  )
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  with gr.Blocks() as demo:
197
  gr.Markdown("# Project Interface proposal")
198
- gr.Markdown("### Write description and user instructions here")
 
199
  dataset = gr.Variable(value=DATASET)
200
  prompts_var = gr.Variable(value=None)
201
  input_var = gr.Variable(label="Input Prompt", value=None)
202
  output_var = gr.Variable(label="Output",value=None)
 
 
 
 
 
203
  flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
204
- dataset_name = "fsdlredteam/flagged",
205
  organization = "fsdlredteam",
206
  private = True )
207
 
208
- with gr.Row(equal_height=True):
209
 
210
- with gr.Column(): # input & prompts dataset exploration
211
  gr.Markdown("### 1. Select a prompt")
212
 
213
  input_text = gr.Textbox(label="Write your prompt below.", interactive=True, lines=4)
 
214
  gr.Markdown("— or —")
 
215
  inspo_button = gr.Button('Click here if you need some inspiration')
216
 
217
  prompts_drop = gr.Dropdown(visible=False)
218
- prompts_drop.change(fn=pass_to_textbox, inputs=prompts_drop, outputs=input_text)
219
 
220
  randomize_button = gr.Button('Show another subset', visible=False)
221
 
222
 
223
- with gr.Column(): # Model choice & output
224
  gr.Markdown("### 2. Evaluate output")
225
 
226
- generate_button = gr.Button('Pick a model below and submit your prompt')
227
  model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
228
  label='Model',
229
  interactive=True)
230
- model_choice = gr.Variable(value=None)
231
- model_radio.change(fn=lambda value: value, inputs=model_radio, outputs=model_choice)
 
 
 
 
 
232
 
233
- output_text = gr.Textbox(label="Generated prompt.", visible=False)
234
 
235
- with gr.Row(equal_height=True): # Flagging
236
- flagging_callback.setup([input_text, output_text, model_radio], "flagged_data_points")
237
 
238
- toxi_flag_button = gr.Button("Report toxic output here", visible=False)
239
- unexpected_flag_button = gr.Button("Report incorrect output here", visible=False)
240
- other_flag_button = gr.Button("Report other inappropriate output here", visible=False)
 
 
 
 
 
 
241
 
242
- with gr.Row(equal_height=True): # Toxicity buttons
 
 
 
 
 
 
243
  toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
244
  toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False)
245
 
246
- with gr.Row(equal_height=True): # Toxicity scores
247
- toxi_scores_input = gr.JSON(label = "Detoxify classification of your input", visible=False)
248
- toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output", visible=False)
249
- toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output", visible=False)
 
 
 
250
 
251
 
252
  inspo_button.click(fn=show_dataset,
253
  inputs=dataset,
254
  outputs=[prompts_drop, randomize_button, prompts_var])
255
 
 
 
 
 
256
  randomize_button.click(fn=update_dropdown,
257
  inputs=prompts_var,
258
- outputs=prompts_drop)
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  generate_button.click(fn=process_user_input,
261
- inputs=[model_choice, input_text],
262
- outputs=[output_text,
263
- toxi_button,
264
- toxi_flag_button,
265
- unexpected_flag_button,
266
- other_flag_button,
267
  input_var,
268
- output_var])
 
269
 
270
  toxi_button.click(fn=compute_toxi_output,
271
- inputs=output_text,
272
- outputs=[toxi_scores_output, toxi_button_compare])
 
273
 
274
  toxi_button_compare.click(fn=compare_toxi_scores,
275
  inputs=[input_text, toxi_scores_output],
276
- outputs=[toxi_scores_input, toxi_scores_compare])
277
-
278
- toxi_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option = "toxic"),
279
- inputs=[input_text, output_text, model_radio],
280
- outputs=None,)
281
- #preprocess=False) #preprocess throws an error on HF space
282
-
283
- unexpected_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option = "unexpected"),
284
- inputs=[input_text, output_text, model_radio],
285
- outputs=None,)
286
- #preprocess=False)
287
-
288
- other_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option = "other"),
289
- inputs=[input_text, output_text, model_radio],
290
- outputs=None,)
291
- #preprocess=False)
292
-
 
 
 
 
293
  #demo.launch(debug=True)
294
  if __name__ == "__main__":
295
- demo.launch(enable_queue=False)
 
7
  from random import sample
8
  from detoxify import Detoxify
9
  from datasets import load_dataset
10
+ from huggingface_hub import HfApi, ModelFilter, ModelSearchArguments
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer
12
  from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
13
  from transformers import BloomTokenizerFast, BloomForCausalLM
14
 
15
  HF_AUTH_TOKEN = os.environ.get('hf_token' or True)
16
+
17
  DATASET = "allenai/real-toxicity-prompts"
18
 
19
  CHECKPOINTS = {
20
  "DistilGPT2 by HuggingFace 🤗" : "distilgpt2",
21
  "GPT-Neo 125M by EleutherAI 🤖" : "EleutherAI/gpt-neo-125M",
22
+ "BLOOM 560M by BigScience 🌸" : "bigscience/bloom-560m",
23
+ "Custom Model" : None
24
  }
25
 
26
  MODEL_CLASSES = {
27
  "DistilGPT2 by HuggingFace 🤗" : (GPT2LMHeadModel, GPT2Tokenizer),
28
  "GPT-Neo 125M by EleutherAI 🤖" : (GPTNeoForCausalLM, GPT2Tokenizer),
29
  "BLOOM 560M by BigScience 🌸" : (BloomForCausalLM, BloomTokenizerFast),
30
+ "Custom Model" : (AutoModelForCausalLM, AutoTokenizer),
31
  }
32
 
33
+ def load_model(model_name, custom_model_path):
34
+ try:
35
+ model_class, tokenizer_class = MODEL_CLASSES[model_name]
36
+ model_path = CHECKPOINTS[model_name]
37
+
38
+ except KeyError:
39
+ model_class, tokenizer_class = MODEL_CLASSES['Custom Model']
40
+ model_path = custom_model_path
41
+
42
  model = model_class.from_pretrained(model_path)
43
  tokenizer = tokenizer_class.from_pretrained(model_path)
44
 
 
67
  return length
68
 
69
  def generate(model_name,
70
+ custom_model_path,
71
  input_sentence,
72
  length = 75,
73
  temperature = 0.7,
 
88
  set_seed(seed, n_gpu)
89
 
90
  # Load model
91
+ model, tokenizer = load_model(model_name, custom_model_path)
92
  model.to(device)
93
 
94
  #length = adjust_length_to_model(length, max_sequence_length=model.config.max_position_embeddings)
 
127
  return generated_sequences[0]
128
 
129
  def prepare_dataset(dataset):
 
130
  dataset = load_dataset(dataset, split='train')
131
  return dataset
132
 
 
152
  def update_dropdown(prompts):
153
  return gr.update(choices=random_sample(prompts))
154
 
155
+ def show_search_bar(value):
156
+ if value == 'Custom Model':
157
+ return (value,
158
+ gr.update(visible=True)
159
+ )
160
+ else:
161
+ return (value,
162
+ gr.update(visible=False)
163
+ )
164
+
165
+ def search_model(model_name):
166
+ api = HfApi()
167
+
168
+ model_args = ModelSearchArguments()
169
+ filt = ModelFilter(
170
+ task=model_args.pipeline_tag.TextGeneration,
171
+ library=model_args.library.PyTorch)
172
+
173
+ results = api.list_models(filter=filt, search=model_name)
174
+ model_list = [model.modelId for model in results]
175
+
176
+ return gr.update(visible=True,
177
+ choices=model_list,
178
+ label='Choose the model',
179
+ )
180
+
181
+ def forward_model_choice(model_choice_path):
182
+ return (model_choice_path,
183
+ model_choice_path)
184
+
185
+ def auto_complete(input, generated):
186
+ output = input + ' ' + generated
187
+ output_spans = [{'entity': 'OUTPUT', 'start': len(input), 'end': len(output)}]
188
+ completed_prompt = {"text": output, "entities": output_spans}
189
+ return completed_prompt
190
+
191
+ def process_user_input(model, custom_model_path, input):
192
  warning = 'Please enter a valid prompt.'
193
  if input == None:
194
  generated = warning
195
  else:
196
+ generated = generate(model, custom_model_path, input)
197
+ generated_with_spans = auto_complete(input, generated)
198
 
199
  return (
200
+ generated_with_spans,
 
 
201
  gr.update(visible=True),
202
  gr.update(visible=True),
203
  input,
 
238
  gr.update(value=compare_scores, visible=True)
239
  )
240
 
241
+ def show_flag_choices():
242
+ return gr.update(visible=True)
243
+
244
+ def update_flag(flag_value):
245
+ return (flag_value,
246
+ gr.update(visible=True),
247
+ gr.update(visible=True),
248
+ gr.update(visible=False)
249
+ )
250
+
251
+ def upload_flag(*args):
252
+ if flagging_callback.flag(list(args), flag_option = None):
253
+ return gr.update(visible=True)
254
+
255
  with gr.Blocks() as demo:
256
  gr.Markdown("# Project Interface proposal")
257
+ gr.Markdown("### Pick a text generation model below, write a prompt and explore the output")
258
+
259
  dataset = gr.Variable(value=DATASET)
260
  prompts_var = gr.Variable(value=None)
261
  input_var = gr.Variable(label="Input Prompt", value=None)
262
  output_var = gr.Variable(label="Output",value=None)
263
+ model_choice = gr.Variable(label="Model", value=None)
264
+ custom_model_path = gr.Variable(value=None)
265
+ flag_choice = gr.Variable(label = "Flag", value=None)
266
+
267
+
268
  flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
269
+ dataset_name = "fsdlredteam/flagged_2",
270
  organization = "fsdlredteam",
271
  private = True )
272
 
273
+ with gr.Row():
274
 
275
+ with gr.Column(scale=1): # input & prompts dataset exploration
276
  gr.Markdown("### 1. Select a prompt")
277
 
278
  input_text = gr.Textbox(label="Write your prompt below.", interactive=True, lines=4)
279
+
280
  gr.Markdown("— or —")
281
+
282
  inspo_button = gr.Button('Click here if you need some inspiration')
283
 
284
  prompts_drop = gr.Dropdown(visible=False)
 
285
 
286
  randomize_button = gr.Button('Show another subset', visible=False)
287
 
288
 
289
+ with gr.Column(scale=1): # Model choice & output
290
  gr.Markdown("### 2. Evaluate output")
291
 
292
+
293
  model_radio = gr.Radio(choices=list(CHECKPOINTS.keys()),
294
  label='Model',
295
  interactive=True)
296
+
297
+ search_bar = gr.Textbox(label="Search model", interactive=True, visible=False)
298
+ model_drop = gr.Dropdown(visible=False)
299
+
300
+ generate_button = gr.Button('Submit your prompt')
301
+
302
+ output_spans = gr.HighlightedText(visible=True, label="Generated text")
303
 
304
+ flag_button = gr.Button("Report output here", visible=False)
305
 
306
+ with gr.Row(): # Flagging
 
307
 
308
+ with gr.Column(scale=1):
309
+ flag_radio = gr.Radio(choices=["Toxic", "Offensive", "Repetitive", "Incorrect", "Other",],
310
+ label="What's wrong with the output ?",
311
+ interactive=True,
312
+ visible=False)
313
+
314
+ user_comment = gr.Textbox(label="(Optional) Briefly describe the issue",
315
+ visible=False,
316
+ interactive=True)
317
 
318
+ confirm_flag_button = gr.Button("Confirm report", visible=False)
319
+
320
+ with gr.Row(): # Flagging success
321
+ success_message = gr.Markdown("Your report has been successfully registered. Thank you!",
322
+ visible=False,)
323
+
324
+ with gr.Row(): # Toxicity buttons
325
  toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
326
  toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False)
327
 
328
+ with gr.Row(): # Toxicity scores
329
+ toxi_scores_input = gr.JSON(label = "Detoxify classification of your input",
330
+ visible=False)
331
+ toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output",
332
+ visible=False)
333
+ toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output",
334
+ visible=False)
335
 
336
 
337
  inspo_button.click(fn=show_dataset,
338
  inputs=dataset,
339
  outputs=[prompts_drop, randomize_button, prompts_var])
340
 
341
+ prompts_drop.change(fn=pass_to_textbox,
342
+ inputs=prompts_drop,
343
+ outputs=input_text)
344
+
345
  randomize_button.click(fn=update_dropdown,
346
  inputs=prompts_var,
347
+ outputs=prompts_drop),
348
+
349
+ model_radio.change(fn=show_search_bar,
350
+ inputs=model_radio,
351
+ outputs=[model_choice,search_bar])
352
+
353
+ search_bar.submit(fn=search_model,
354
+ inputs=search_bar,
355
+ outputs=model_drop,
356
+ show_progress=True)
357
+
358
+ model_drop.change(fn=forward_model_choice,
359
+ inputs=model_drop,
360
+ outputs=[model_choice,custom_model_path])
361
 
362
  generate_button.click(fn=process_user_input,
363
+ inputs=[model_choice, custom_model_path, input_text],
364
+ outputs=[output_spans,
365
+ toxi_button,
366
+ flag_button,
 
 
367
  input_var,
368
+ output_var],
369
+ show_progress=True)
370
 
371
  toxi_button.click(fn=compute_toxi_output,
372
+ inputs=output_var,
373
+ outputs=[toxi_scores_output, toxi_button_compare],
374
+ show_progress=True)
375
 
376
  toxi_button_compare.click(fn=compare_toxi_scores,
377
  inputs=[input_text, toxi_scores_output],
378
+ outputs=[toxi_scores_input, toxi_scores_compare],
379
+ show_progress=True)
380
+
381
+ flag_button.click(fn=show_flag_choices,
382
+ inputs=None,
383
+ outputs=flag_radio)
384
+
385
+ flag_radio.change(fn=update_flag,
386
+ inputs=flag_radio,
387
+ outputs=[flag_choice, confirm_flag_button, user_comment, flag_button])
388
+
389
+ flagging_callback.setup([input_var, output_var, model_choice, user_comment, flag_choice], "flagged_data_points")
390
+
391
+ confirm_flag_button.click(fn = upload_flag,
392
+ inputs = [input_var,
393
+ output_var,
394
+ model_choice,
395
+ user_comment,
396
+ flag_choice],
397
+ outputs=success_message)
398
+
399
  #demo.launch(debug=True)
400
  if __name__ == "__main__":
401
+ demo.launch(enable_queue=False)