emilylearning committed on
Commit
e4f334a
1 Parent(s): a1d9fca

Support user-added models, return n-fit, doc improvement, rand output display text

Browse files
Files changed (1) hide show
  1. app.py +133 -49
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
  from matplotlib.ticker import MaxNLocator
@@ -5,15 +7,17 @@ import pandas as pd
5
  import numpy as np
6
  import matplotlib.pyplot as plt
7
 
 
8
  MODEL_NAMES = ["bert-base-uncased",
9
  "distilbert-base-uncased", "xlm-roberta-base"]
 
10
 
11
  DECIMAL_PLACES = 1
12
  EPS = 1e-5 # to avoid /0 errors
13
 
14
  # Example date constants
15
  DATE_SPLIT_KEY = "DATE"
16
- START_YEAR = 1800
17
  STOP_YEAR = 1999
18
  NUM_PTS = 20
19
  DATES = np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
@@ -131,18 +135,14 @@ GENDERED_LIST = [
131
  ["actor", "actress"],
132
  ]
133
 
134
-
135
  # Fire up the models
136
- # TODO: Make it so models can be added in the future
137
- models_paths = dict()
138
  models = dict()
139
 
 
 
140
 
141
  # %%
142
- for bert_like in MODEL_NAMES:
143
- models_paths[bert_like] = bert_like
144
- models[bert_like] = pipeline(
145
- "fill-mask", model=models_paths[bert_like])
146
 
147
 
148
  def get_gendered_token_ids():
@@ -171,6 +171,8 @@ def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_pre
171
  return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
172
 
173
  # %%
 
 
174
  def get_figure(df, gender, n_fit=1):
175
  df = df.set_index('x-axis')
176
  cols = df.columns
@@ -179,7 +181,7 @@ def get_figure(df, gender, n_fit=1):
179
  fig, ax = plt.subplots()
180
  # Trying small fig due to rendering issues on HF, not on VS Code
181
  fig.set_figheight(3)
182
- fig.set_figwidth(8)
183
 
184
  # find stackoverflow reference
185
  p, C_p = np.polyfit(xs, ys, n_fit, cov=1)
@@ -197,30 +199,33 @@ def get_figure(df, gender, n_fit=1):
197
  ax.legend(list(df.columns))
198
 
199
  ax.axis('tight')
200
-
201
- # fig.canvas.draw()
202
-
203
  ax.set_xlabel("Value injected into input text")
204
  ax.set_title(
205
  f"Probability of predicting {gender} pronouns.")
206
  ax.set_ylabel(f"Softmax prob for pronouns")
207
  ax.xaxis.set_major_locator(MaxNLocator(6))
208
- ax.tick_params(axis='x', labelrotation=15)
209
  return fig
210
 
211
 
212
  # %%
213
  def predict_gender_pronouns(
214
- model_type,
 
215
  indie_vars,
216
  split_key,
217
  normalizing,
 
218
  input_text,
219
  ):
220
  """Run inference on input_text for each model type, returning df and plots of percentage
221
  of gender pronouns predicted as female and male in each target text.
222
  """
223
- model = models[model_type]
 
 
 
 
224
  mask_token = model.tokenizer.mask_token
225
 
226
  indie_vars_list = indie_vars.split(',')
@@ -267,17 +272,19 @@ def predict_gender_pronouns(
267
  results_df['female_pronouns'] = female_pronoun_preds
268
  results_df['male_pronouns'] = male_pronoun_preds
269
  female_fig = get_figure(results_df.drop(
270
- 'male_pronouns', axis=1), 'female')
271
  male_fig = get_figure(results_df.drop(
272
- 'female_pronouns', axis=1), 'male')
 
273
 
274
  return (
275
- target_text,
276
  female_fig,
277
  male_fig,
278
  results_df,
279
  )
280
 
 
281
  # %%
282
  title = "Causing Gender Pronouns"
283
  description = """
@@ -287,74 +294,159 @@ description = """
287
 
288
  place_example = [
289
  MODEL_NAMES[0],
 
290
  ', '.join(PLACES),
291
  'PLACE',
292
  "False",
 
293
  'Born in PLACE, she was a teacher.'
294
  ]
295
 
296
  date_example = [
297
  MODEL_NAMES[0],
 
298
  ', '.join(DATES),
299
  'DATE',
300
  "False",
 
301
  'Born in DATE, she was a doctor.'
302
  ]
303
 
304
 
305
  subreddit_example = [
306
  MODEL_NAMES[2],
 
307
  ', '.join(SUBREDDITS),
308
  'SUBREDDIT',
309
  "False",
310
- 'I saw on r/SUBREDDIT that she is a hacker.'
 
 
 
 
 
 
 
 
 
 
 
311
  ]
312
 
313
 
314
  def date_fn():
315
  return date_example
 
 
316
  def place_fn():
317
  return place_example
 
 
318
  def reddit_fn():
319
  return subreddit_example
320
 
321
 
 
 
 
 
322
  # %%
323
  demo = gr.Blocks()
324
  with demo:
325
- gr.Markdown("## Hunt for spurious correlations in our LLMs.")
326
  gr.Markdown("Although genders are relatively evenly distributed across time, place and interests, there are also known gender disparities in terms of access to resources. Here we demonstrate that this access disparity can result in dataset selection bias, causing models to learn a surprising range of spurious associations.")
327
  gr.Markdown("These spurious associations are often considered undesirable, as they do not match our intuition about the real-world domain from which we derive samples for inference-time prediction.")
328
  gr.Markdown("Selection of samples into datasets is a zero-sum-game, with even our high quality datasets forced to trade off one for another, thus inducing selection bias into the learned associations of the model.")
329
 
330
- gr.Markdown("### Dose-response Relationship.")
331
- gr.Markdown("One intuitive way to see the impact that changing one variable may have upon another is to look for a dose-response relationship, in which a larger intervention in the treatment (the value in text form injected in the otherwise unchanged text sample) produces a larger response in the output (the softmax probability of a gendered pronoun). Specifically, below are examples of sweeping through a spectrum of place, date and subreddit interest. We encourage you to try your own!")
 
 
 
 
 
 
 
 
332
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  with gr.Row():
334
  x_axis = gr.Textbox(
335
  lines=5,
336
- label="Pick a spectrum of values for text injection and x-axis",
337
  )
 
 
 
 
 
338
  with gr.Row():
339
  model_name = gr.Radio(
340
- MODEL_NAMES,
341
  type="value",
342
- label="Pick a BERT-like model.",
343
  )
344
- place_holder = gr.Textbox(
345
- label="Special token used in input text that will be replaced with the above spectrum of values.",
346
- type="index",
347
  )
 
 
 
 
 
 
 
 
348
  to_normalize = gr.Dropdown(
349
  ["False", "True"],
350
- label="Normalize?",
351
  type="index",
352
  )
 
 
 
 
 
 
 
 
 
 
 
 
353
  with gr.Row():
354
  input_text = gr.Textbox(
355
- lines=5,
356
- label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
357
  )
 
 
 
 
358
  with gr.Row():
359
  sample_text = gr.Textbox(
360
  type="auto", label="Output text: Sample of text fed to model")
@@ -367,32 +459,24 @@ with demo:
367
  overflow_row_behaviour="show_ends",
368
  label="Table of softmax probability for pronouns predictions",
369
  )
370
- gr.Markdown("X-axis sorted by older to more recent dates:")
371
- place_gen = gr.Button('Populate fields with a location example')
372
-
373
- gr.Markdown("X-axis sorted by bottom 10 and top 10 Global Gender Gap ranked countries:")
374
- date_gen = gr.Button('Populate fields with a date example')
375
-
376
- gr.Markdown("X-axis sorted in order of increasing self-identified female participation (see [bburky demo](http://bburky.com/subredditgenderratios/)): ")
377
- subreddit_gen = gr.Button('Populate fields with a subreddit example')
378
 
379
  with gr.Row():
380
 
381
- date_gen.click(date_fn, inputs=[], outputs=[model_name,
382
- x_axis, place_holder, to_normalize, input_text])
383
  place_gen.click(place_fn, inputs=[], outputs=[
384
- model_name, x_axis, place_holder, to_normalize, input_text])
385
  subreddit_gen.click(reddit_fn, inputs=[], outputs=[
386
- model_name, x_axis, place_holder, to_normalize, input_text])
 
 
 
387
  with gr.Row():
388
  btn = gr.Button("Hit submit")
389
  btn.click(
390
  predict_gender_pronouns,
391
- inputs=[model_name, x_axis, place_holder,
392
- to_normalize, input_text],
393
  outputs=[sample_text, female_fig, male_fig, df])
394
 
395
- demo.launch(debug=True)
396
-
397
-
398
- # %%
 
1
+ # %%
2
+ from random import random
3
  import gradio as gr
4
  from transformers import pipeline
5
  from matplotlib.ticker import MaxNLocator
 
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
 
10
+
11
  MODEL_NAMES = ["bert-base-uncased",
12
  "distilbert-base-uncased", "xlm-roberta-base"]
13
+ OWN_MODEL_NAME = 'add-your-own'
14
 
15
  DECIMAL_PLACES = 1
16
  EPS = 1e-5 # to avoid /0 errors
17
 
18
  # Example date constants
19
  DATE_SPLIT_KEY = "DATE"
20
+ START_YEAR = 1801
21
  STOP_YEAR = 1999
22
  NUM_PTS = 20
23
  DATES = np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
 
135
  ["actor", "actress"],
136
  ]
137
 
138
+ # %%
139
  # Fire up the models
 
 
140
  models = dict()
141
 
142
+ for bert_like in MODEL_NAMES:
143
+ models[bert_like] = pipeline("fill-mask", model=bert_like)
144
 
145
  # %%
 
 
 
 
146
 
147
 
148
  def get_gendered_token_ids():
 
171
  return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
172
 
173
  # %%
174
+
175
+
176
  def get_figure(df, gender, n_fit=1):
177
  df = df.set_index('x-axis')
178
  cols = df.columns
 
181
  fig, ax = plt.subplots()
182
  # Trying small fig due to rendering issues on HF, not on VS Code
183
  fig.set_figheight(3)
184
+ fig.set_figwidth(9)
185
 
186
  # find stackoverflow reference
187
  p, C_p = np.polyfit(xs, ys, n_fit, cov=1)
 
199
  ax.legend(list(df.columns))
200
 
201
  ax.axis('tight')
 
 
 
202
  ax.set_xlabel("Value injected into input text")
203
  ax.set_title(
204
  f"Probability of predicting {gender} pronouns.")
205
  ax.set_ylabel(f"Softmax prob for pronouns")
206
  ax.xaxis.set_major_locator(MaxNLocator(6))
207
+ ax.tick_params(axis='x', labelrotation=5)
208
  return fig
209
 
210
 
211
  # %%
212
  def predict_gender_pronouns(
213
+ model_name,
214
+ own_model_name,
215
  indie_vars,
216
  split_key,
217
  normalizing,
218
+ n_fit,
219
  input_text,
220
  ):
221
  """Run inference on input_text for each model type, returning df and plots of percentage
222
  of gender pronouns predicted as female and male in each target text.
223
  """
224
+ if model_name not in MODEL_NAMES:
225
+ model = pipeline("fill-mask", model=own_model_name)
226
+ else:
227
+ model = models[model_name]
228
+
229
  mask_token = model.tokenizer.mask_token
230
 
231
  indie_vars_list = indie_vars.split(',')
 
272
  results_df['female_pronouns'] = female_pronoun_preds
273
  results_df['male_pronouns'] = male_pronoun_preds
274
  female_fig = get_figure(results_df.drop(
275
+ 'male_pronouns', axis=1), 'female', n_fit,)
276
  male_fig = get_figure(results_df.drop(
277
+ 'female_pronouns', axis=1), 'male', n_fit,)
278
+ display_text = f"{random.choice(indie_vars_list)}".join(text_segments)
279
 
280
  return (
281
+ display_text,
282
  female_fig,
283
  male_fig,
284
  results_df,
285
  )
286
 
287
+
288
  # %%
289
  title = "Causing Gender Pronouns"
290
  description = """
 
294
 
295
  place_example = [
296
  MODEL_NAMES[0],
297
+ '',
298
  ', '.join(PLACES),
299
  'PLACE',
300
  "False",
301
+ 1,
302
  'Born in PLACE, she was a teacher.'
303
  ]
304
 
305
  date_example = [
306
  MODEL_NAMES[0],
307
+ '',
308
  ', '.join(DATES),
309
  'DATE',
310
  "False",
311
+ 3,
312
  'Born in DATE, she was a doctor.'
313
  ]
314
 
315
 
316
  subreddit_example = [
317
  MODEL_NAMES[2],
318
+ '',
319
  ', '.join(SUBREDDITS),
320
  'SUBREDDIT',
321
  "False",
322
+ 1,
323
+ 'I saw in r/SUBREDDIT that she is a hacker.'
324
+ ]
325
+
326
+ own_model_example = [
327
+ OWN_MODEL_NAME,
328
+ 'lordtt13/COVID-SciBERT',
329
+ ', '.join(DATES),
330
+ 'DATE',
331
+ "False",
332
+ 3,
333
+ 'Ending her professorship in DATE, she was instrumental in developing the COVID vaccine.'
334
  ]
335
 
336
 
337
  def date_fn():
338
  return date_example
339
+
340
+
341
  def place_fn():
342
  return place_example
343
+
344
+
345
  def reddit_fn():
346
  return subreddit_example
347
 
348
 
349
+ def your_fn():
350
+ return own_model_example
351
+
352
+
353
  # %%
354
  demo = gr.Blocks()
355
  with demo:
356
+ gr.Markdown("## Spurious Correlation Evaluation for our LLMs")
357
  gr.Markdown("Although genders are relatively evenly distributed across time, place and interests, there are also known gender disparities in terms of access to resources. Here we demonstrate that this access disparity can result in dataset selection bias, causing models to learn a surprising range of spurious associations.")
358
  gr.Markdown("These spurious associations are often considered undesirable, as they do not match our intuition about the real-world domain from which we derive samples for inference-time prediction.")
359
  gr.Markdown("Selection of samples into datasets is a zero-sum-game, with even our high quality datasets forced to trade off one for another, thus inducing selection bias into the learned associations of the model.")
360
 
361
+ gr.Markdown("### Data Generating Process")
362
+ gr.Markdown("To pick values below that are most likely to cause spurious correlations, it helps to make some assumptions about the training datasets' likely data generating process, and where selection bias may come in.")
363
+
364
+ gr.Markdown("A plausible data generating processes for both Wikipedia and Reddit sourced data is shown as a DAG below. These DAGs are prone to collider bias when conditioning on `access`. In other words, although in real life `place`, `date`, (subreddit) `interest` and gender are all unconditionally independent, when we condition on their common effect, `access`, they become unconditionally dependent. Composing a dataset often requires the dataset maintainers to condition on `access`. Thus LLMs learn these dataset induced dependencies, appearing to us as spurious correlations.")
365
+ gr.Markdown("""
366
+ <center>
367
+ <img src="https://www.dropbox.com/s/f0numpllywdd271/combo_dag_block_party.png?raw=1"
368
+ alt="DAG of possible data generating process for datasets used in training some of our LLMs.">
369
+ </center>
370
+ """)
371
 
372
+ gr.Markdown("There may be misassumptions in our DAG above, which you can explore below.")
373
+ gr.Markdown("Or you may be interested in applying this demo to your own model of interest. This demo _should_ work with any Hugging Face model that supports the [fill-mask](https://huggingface.co/models?pipeline_tag=fill-mask) task.")
374
+
375
+ gr.Markdown("### Dose-response Relationship")
376
+ gr.Markdown("One intuitive way to see the impact that changing one variable may have upon another is to look for a dose-response relationship, in which a larger intervention in the treatment (the value in text form injected in the otherwise unchanged text sample) produces a larger response in the output (the softmax probability of a gendered pronoun).")
377
+
378
+ gr.Markdown("### This Demo")
379
+ gr.Markdown("This type of plot requires a range of values along which we may see a spectrum of gender representation (or misrepresentation) in our datasets.")
380
+ gr.Markdown("Click on one of the examples below (where we sweep through a spectrum of `places`, `date` and `subreddit` interest) to get an idea of whats intended here. Then try your own!")
381
+
382
+ with gr.Row():
383
+ gr.Markdown("X-axis sorted by older to more recent dates:")
384
+ place_gen = gr.Button('Country example')
385
+
386
+ gr.Markdown(
387
+ "X-axis sorted by bottom 10 and top 10 [Global Gender Gap](https://www3.weforum.org/docs/WEF_GGGR_2021.pdf) ranked countries by World Economic Forum in 2021:")
388
+ date_gen = gr.Button('Date example')
389
+
390
+ gr.Markdown(
391
+ "X-axis sorted in order of increasing self-identified female participation (see [bburky demo](http://bburky.com/subredditgenderratios/)): ")
392
+ subreddit_gen = gr.Button('Subreddit example')
393
+
394
+ gr.Markdown("Date example with your own model loaded! (We recommend you try after seeing how others work. It can take a while to load new model.)")
395
+ your_gen = gr.Button('Your model example')
396
+
397
  with gr.Row():
398
  x_axis = gr.Textbox(
399
  lines=5,
400
+ label="Pick a spectrum of comma separated values for text injection and x-axis",
401
  )
402
+
403
+
404
+ gr.Markdown(
405
+ "Pick a pre-loaded BERT-family model of interest, or add another Hugging Face model that supports the [fill-mask](https://huggingface.co/models?pipeline_tag=fill-mask) task (this may take some time to load).")
406
+
407
  with gr.Row():
408
  model_name = gr.Radio(
409
+ MODEL_NAMES + [OWN_MODEL_NAME],
410
  type="value",
411
+ label="Model: Pick a BERT-like model.",
412
  )
413
+ own_model_name = gr.Textbox(
414
+ label="If you selected an 'add-your-own' model, put your models Hugging Face pipeline name here. We think it should work with any model that supports the fill-mask task.",
 
415
  )
416
+
417
+ gr.Markdown(
418
+ "We are able to test the pre-trained LLMs without any modification to the models, as the gender-pronoun prediction task is simply a special case of the masked language modeling (MLM) task, with which all these models were pre-trained. Rather than random masking, the gender-pronoun prediction task masks only non-gender-neutral terms (listed in prior [Space](https://huggingface.co/spaces/emilylearning/causing_gender_pronouns_two)).")
419
+ gr.Markdown("For the pre-trained LLMs the final prediction is a softmax over the entire tokenizer's vocabulary, from which we sum up the portion of the probability mass from the top five prediction words that are gendered terms. Pick if you want to the predictions normalied to these gendered terms only.")
420
+ gr.Markdown("Also tell the demo what special token you will use in your input text, that you would like replaced with the spectrum of values you listed above, and the degree of polynomial fit used for high-lighting possible dose response trend ")
421
+
422
+
423
+ with gr.Row():
424
  to_normalize = gr.Dropdown(
425
  ["False", "True"],
426
+ label="Normalize model's predictions to only the gendered ones?",
427
  type="index",
428
  )
429
+ place_holder = gr.Textbox(
430
+ label="Special token place-holder that used in input text that will be replaced with the above spectrum of values.",
431
+ )
432
+ n_fit = gr.Dropdown(
433
+ list(range(1, 5)),
434
+ label="Degree of polynomial fit for high-lighting possible dose response trend",
435
+ type="value",
436
+ )
437
+
438
+ gr.Markdown(
439
+ "Finally, add input text that includes at least one gendered pronouns and one place-holder token specified above.")
440
+
441
  with gr.Row():
442
  input_text = gr.Textbox(
443
+ lines=3,
444
+ label="Input Text: Sentence that includes gendered pronouns and your place-holder token specified above.",
445
  )
446
+
447
+ gr.Markdown("### Outputs!")
448
+ gr.Markdown("Scroll down and 'Hit Submit'!")
449
+
450
  with gr.Row():
451
  sample_text = gr.Textbox(
452
  type="auto", label="Output text: Sample of text fed to model")
 
459
  overflow_row_behaviour="show_ends",
460
  label="Table of softmax probability for pronouns predictions",
461
  )
 
 
 
 
 
 
 
 
462
 
463
  with gr.Row():
464
 
465
+ date_gen.click(date_fn, inputs=[], outputs=[model_name, own_model_name,
466
+ x_axis, place_holder, to_normalize, n_fit, input_text])
467
  place_gen.click(place_fn, inputs=[], outputs=[
468
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
469
  subreddit_gen.click(reddit_fn, inputs=[], outputs=[
470
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
471
+ your_gen.click(your_fn, inputs=[], outputs=[
472
+ model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
473
+
474
  with gr.Row():
475
  btn = gr.Button("Hit submit")
476
  btn.click(
477
  predict_gender_pronouns,
478
+ inputs=[model_name, own_model_name, x_axis, place_holder,
479
+ to_normalize, n_fit, input_text],
480
  outputs=[sample_text, female_fig, male_fig, df])
481
 
482
+ demo.launch(debug=True)