J-Antoine ZAGATO committed
Commit 9d80551
1 Parent(s): 10d46ff

Added toxicity comparison & flagging + refactoring

Files changed (1): app.py (+102, -22)
app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import torch
 
 import numpy as np
@@ -9,6 +10,7 @@ from datasets import load_dataset
 from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPTNeoForCausalLM
 from transformers import BloomTokenizerFast, BloomForCausalLM
 
+HF_AUTH_TOKEN = os.environ.get('hf_token' or True)
 DATASET = "allenai/real-toxicity-prompts"
 
 CHECKPOINTS = {
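A note on the new token lookup: `'hf_token' or True` is evaluated before the call, so the expression is simply `os.environ.get('hf_token')` and silently yields `None` when the variable is unset. A more explicit equivalent might look like the sketch below (failing fast is an assumption on my part, not part of this commit):

```python
import os

# Equivalent lookup for the Hugging Face token; raising early is an assumed choice,
# the committed code just ends up with None when 'hf_token' is not set.
HF_AUTH_TOKEN = os.environ.get("hf_token")
if HF_AUTH_TOKEN is None:
    raise RuntimeError("Set the 'hf_token' environment variable before launching the app.")
```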
@@ -140,19 +142,21 @@ def show_dataset(dataset):
 def update_dropdown(prompts):
     return gr.update(choices=random_sample(prompts))
 
-def show_text(text):
-    new_text = "lol " + text
-    return gr.update(visible = True, value=new_text)
-
 def process_user_input(model, input):
     warning = 'Please enter a valid prompt.'
     if input == None:
-        input = warning
-    generated = generate(model, input)
+        generated = warning
+    else:
+        generated = generate(model, input)
 
     return (
         gr.update(visible = True, value=generated),
-        gr.update(visible=True)
+        gr.update(visible=True),
+        gr.update(visible=True),
+        gr.update(visible=True),
+        gr.update(visible=True),
+        input,
+        generated
     )
 
 def pass_to_textbox(input):
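`process_user_input` now returns one value per component wired into `generate_button.click(...)` later in the diff: an update for the output textbox, four `gr.update(visible=True)` calls that reveal the toxicity and flagging buttons, and the raw input and generated text for the two `gr.Variable` holders. A minimal sketch of this pattern, with hypothetical component names rather than the ones in app.py:

```python
import gradio as gr

def reveal(text):
    # One return value per entry in outputs=[...], in the same order.
    return (
        gr.update(value=f"Echo: {text}", visible=True),  # fill and show the result box
        gr.update(visible=True),                         # reveal the follow-up button
    )

with gr.Blocks() as toy_demo:  # illustrative mini-app, not part of the commit
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result", visible=False)
    analyse = gr.Button("Analyse", visible=False)
    submit = gr.Button("Submit")
    submit.click(fn=reveal, inputs=prompt, outputs=[result, analyse])
```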
@@ -161,21 +165,52 @@ def pass_to_textbox(input):
 def run_detoxify(text):
     results = Detoxify('original').predict(text)
     json_ready_results = {cat:float(score) for (cat,score) in results.items()}
-
-    return gr.update(value=json_ready_results, visible=True)
+    return json_ready_results
+
+def compute_toxi_output(output_text):
+    scores = run_detoxify(output_text)
+    return (
+        gr.update(value=scores, visible=True),
+        gr.update(visible=True)
+    )
+
+def compute_change(input, output):
+    change_percent = round(((float(output)-input)/input)*100, 2)
+    return change_percent
+
+def compare_toxi_scores(input_text, output_scores):
+    input_scores = run_detoxify(input_text)
+    json_ready_results = {cat:float(score) for (cat,score) in input_scores.items()}
+
+    compare_scores = {
+        cat:compute_change(json_ready_results[cat], output_scores[cat])
+        for cat in json_ready_results
+        for cat in output_scores
+    }
+
+    return (
+        gr.update(value=json_ready_results, visible=True),
+        gr.update(value=compare_scores, visible=True)
+    )
 
-
 with gr.Blocks() as demo:
     gr.Markdown("# Project Interface proposal")
-
+    gr.Markdown("### Write description and user instructions here")
     dataset = gr.Variable(value=DATASET)
     prompts_var = gr.Variable(value=None)
+    input_var = gr.Variable(label="Input Prompt", value=None)
+    output_var = gr.Variable(label="Output",value=None)
+    flagging_callback = gr.HuggingFaceDatasetSaver(hf_token = HF_AUTH_TOKEN,
+                                                   dataset_name = "fsdlredteam/flagged",
+                                                   organization = "fsdlredteam",
+                                                   private = True )
 
     with gr.Row(equal_height=True):
-        with gr.Column():
+
+        with gr.Column(): # input & prompts dataset exploration
             gr.Markdown("### 1. Select a prompt")
 
-            input_text = gr.Textbox(label="Write your prompt below.", interactive=True)
+            input_text = gr.Textbox(label="Write your prompt below.", interactive=True, lines=4)
             gr.Markdown("— or —")
             inspo_button = gr.Button('Click here if you need some inspiration')
 
@@ -184,11 +219,8 @@ with gr.Blocks() as demo:
 
             randomize_button = gr.Button('Show another subset', visible=False)
 
-            inspo_button.click(fn=show_dataset, inputs=dataset, outputs=[prompts_drop, randomize_button, prompts_var])
-            randomize_button.click(fn=update_dropdown, inputs=prompts_var, outputs=prompts_drop)
 
-        with gr.Column():
-
+        with gr.Column(): # Model choice & output
             gr.Markdown("### 2. Evaluate output")
 
             generate_button = gr.Button('Pick a model below and submit your prompt')
@@ -199,16 +231,64 @@ with gr.Blocks() as demo:
             model_radio.change(fn=lambda value: value, inputs=model_radio, outputs=model_choice)
 
             output_text = gr.Textbox(label="Generated prompt.", visible=False)
-
-            toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
-            toxi_scores = gr.JSON(visible=False)
 
+    with gr.Row(equal_height=True): # Flagging
+        flagging_callback.setup([input_text, output_text, model_radio], "flagged_data_points")
+
+        toxi_flag_button = gr.Button("Report toxic output here", visible=False)
+        unexpected_flag_button = gr.Button("Report incorrect output here", visible=False)
+        other_flag_button = gr.Button("Report other inappropriate output here", visible=False)
+
+    with gr.Row(equal_height=True): # Toxicity buttons
+        toxi_button = gr.Button("Run a toxicity analysis of the model's output", visible=False)
+        toxi_button_compare = gr.Button("Compare toxicity on input and output", visible=False)
+
+    with gr.Row(equal_height=True): # Toxicity scores
+        toxi_scores_input = gr.JSON(label = "Detoxify classification of your input", visible=False)
+        toxi_scores_output = gr.JSON(label="Detoxify classification of the model's output", visible=False)
+        toxi_scores_compare = gr.JSON(label = "Percentage change between Input and Output", visible=False)
 
-    generate_button.click(fn=process_user_input,
-                          inputs=[model_choice, input_text],
-                          outputs=[output_text,toxi_button])
+    inspo_button.click(fn=show_dataset,
+                       inputs=dataset,
+                       outputs=[prompts_drop, randomize_button, prompts_var])
+
+    randomize_button.click(fn=update_dropdown,
+                           inputs=prompts_var,
+                           outputs=prompts_drop)
+
+    generate_button.click(fn=process_user_input,
+                          inputs=[model_choice, input_text],
+                          outputs=[output_text,
+                                   toxi_button,
+                                   toxi_flag_button,
+                                   unexpected_flag_button,
+                                   other_flag_button,
+                                   input_var,
+                                   output_var])
+
+    toxi_button.click(fn=compute_toxi_output,
+                      inputs=output_text,
+                      outputs=[toxi_scores_output, toxi_button_compare])
+
+    toxi_button_compare.click(fn=compare_toxi_scores,
+                              inputs=[input_text, toxi_scores_output],
+                              outputs=[toxi_scores_input, toxi_scores_compare])
 
-    toxi_button.click(fn=run_detoxify, inputs=output_text, outputs=toxi_scores)
+    toxi_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option = "toxic"),
+                           inputs=[input_text, output_text, model_radio],
+                           outputs=None,
+                           preprocess=False)
+
+    unexpected_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option = "unexpected"),
+                                 inputs=[input_text, output_text, model_radio],
+                                 outputs=None,
+                                 preprocess=False)
+
+    other_flag_button.click(lambda *args: flagging_callback.flag(args, flag_option = "other"),
+                            inputs=[input_text, output_text, model_radio],
+                            outputs=None,
+                            preprocess=False)
 
 #demo.launch(debug=True)
 if __name__ == "__main__":
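The three report buttons share one `HuggingFaceDatasetSaver`: `setup()` declares which components get serialised, and each click handler calls `flag()` with a different `flag_option`, appending a row (input, output, model, flag) to the private `fsdlredteam/flagged` dataset. A condensed, self-contained sketch of that path, assuming the Gradio 3.x API and a valid write token (the token string and component names below are placeholders):

```python
import gradio as gr

# Placeholder token; in app.py it comes from the HF_AUTH_TOKEN environment lookup.
callback = gr.HuggingFaceDatasetSaver(hf_token="hf_xxx",
                                      dataset_name="fsdlredteam/flagged",
                                      organization="fsdlredteam",
                                      private=True)

with gr.Blocks() as toy:
    inp = gr.Textbox(label="Input prompt")
    out = gr.Textbox(label="Model output")
    model = gr.Radio(choices=["GPT-2", "BLOOM"], label="Model")
    callback.setup([inp, out, model], "flagged_data_points")

    report = gr.Button("Report toxic output")
    # flag() writes the current component values plus the chosen flag_option
    # as one new row in the dataset repository.
    report.click(lambda *args: callback.flag(args, flag_option="toxic"),
                 inputs=[inp, out, model],
                 outputs=None,
                 preprocess=False)
```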
 