emilylearning committed on
Commit
1c81243
1 Parent(s): 01dc8f8

works on vs code...

Browse files
Files changed (1) hide show
  1. app.py +400 -0
app.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from matplotlib.ticker import MaxNLocator
4
+ import pandas as pd
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
# Hugging Face model ids of the fill-mask models offered in the UI.
MODEL_NAMES = [
    "bert-base-uncased",
    "distilbert-base-uncased",
    "xlm-roberta-base",
]

DECIMAL_PLACES = 1  # rounding precision used for reported percentages
EPS = 1e-5  # added to denominators to avoid /0 errors

# Example date consts: NUM_PTS evenly spaced years between START_YEAR and
# STOP_YEAR (inclusive), truncated to ints and rendered as strings.
DATE_SPLIT_KEY = "DATE"
START_YEAR = 1800
STOP_YEAR = 1999
NUM_PTS = 20
DATES = [
    f"{year}"
    for year in np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
]
21
+
22
# Example place consts
# Source: https://www3.weforum.org/docs/WEF_GGGR_2021.pdf
# Bottom 10 and top 10 Global Gender Gap ranked countries.
# NOTE(review): presumably the first ten entries are the bottom 10 and the
# last ten the top 10 -- verify against the report.
PLACE_SPLIT_KEY = "PLACE"
PLACES = [
    "Afghanistan", "Yemen", "Iraq", "Pakistan", "Syria",
    "Democratic Republic of Congo", "Iran", "Mali", "Chad", "Saudi Arabia",
    "Switzerland", "Ireland", "Lithuania", "Rwanda", "Namibia",
    "Sweden", "New Zealand", "Norway", "Finland", "Iceland",
]
47
+
48
+
49
# Example Reddit interest consts,
# in order of increasing self-identified female participation.
# See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 400000
SUBREDDITS = [
    "GlobalOffensive", "pcmasterrace", "nfl", "sports", "The_Donald",
    "leagueoflegends", "Overwatch", "gonewild", "Futurology", "space",
    "technology", "gaming", "Jokes", "dataisbeautiful", "woahdude",
    "askscience", "wow", "anime", "BlackPeopleTwitter", "politics",
    "pokemon", "worldnews", "reddit.com", "interestingasfuck", "videos",
    "nottheonion", "television", "science", "atheism", "movies",
    "gifs", "Music", "trees", "EarthPorn", "GetMotivated",
    "pokemongo", "news",
    # "fffffffuuuuuuuuuuuu" is deliberately left out here: most of the tokens
    # are taken up by it, e.g.
    # ['ff', '##ff', '##ff', '##fu', '##u', '##u', '##u', '##u', '##u', ...]
    "Fitness", "Showerthoughts", "OldSchoolCool", "explainlikeimfive",
    "todayilearned", "gameofthrones", "AdviceAnimals", "DIY", "WTF",
    "IAmA", "cringepics", "tifu", "mildlyinteresting", "funny",
    "pics", "LifeProTips", "creepy", "personalfinance", "food",
    "AskReddit", "books", "aww", "sex", "relationships",
]
118
+
119
# Pairs of gendered words, each written as [male_form, female_form].
GENDERED_LIST = [
    ["he", "she"],
    ["him", "her"],
    ["his", "hers"],
    ["himself", "herself"],
    ["male", "female"],
    ["man", "woman"],
    ["men", "women"],
    ["husband", "wife"],
    ["father", "mother"],
    ["boyfriend", "girlfriend"],
    ["brother", "sister"],
    ["actor", "actress"],
]
133
+
134
+
135
# Fire up the models
# TODO: Make it so models can be added in the future
# %%
# Each model name doubles as its own path/identifier; one fill-mask pipeline
# is built per entry of MODEL_NAMES at import time.
models_paths = {name: name for name in MODEL_NAMES}
models = {
    name: pipeline("fill-mask", model=path)
    for name, path in models_paths.items()
}
146
+
147
+
148
def get_gendered_token_ids():
    """Split GENDERED_LIST into parallel male and female word lists.

    Despite the name, this returns the words themselves (strings), not
    tokenizer ids.

    Returns:
        A ``(male_gendered_tokens, female_gendered_tokens)`` tuple, where the
        first list holds element 0 and the second holds element 1 of every
        pair in GENDERED_LIST, in the same order.
    """
    # Use `pair` rather than `list` as the loop variable so the builtin is
    # not shadowed.
    male_gendered_tokens = [pair[0] for pair in GENDERED_LIST]
    female_gendered_tokens = [pair[1] for pair in GENDERED_LIST]

    return male_gendered_tokens, female_gendered_tokens
153
+
154
+
155
def prepare_text_for_masking(input_text, mask_token, gendered_tokens, split_key):
    """Mask gendered words in ``input_text`` and split it around ``split_key``.

    The text is split on whitespace. Any word whose core (the word with
    leading/trailing punctuation removed) matches one of ``gendered_tokens``,
    ignoring case, is replaced by ``mask_token``; the surrounding punctuation
    is kept. The words are then re-joined with single spaces and the result
    is split on ``split_key``.

    Args:
        input_text: Sentence to process.
        mask_token: Mask string of the target tokenizer (e.g. ``"[MASK]"``).
        gendered_tokens: Iterable of gendered words to mask.
        split_key: Placeholder string marking where values get injected.

    Returns:
        A ``(text_portions, num_masks)`` tuple: the text segments around each
        occurrence of ``split_key``, and the number of ``mask_token``
        occurrences in the masked text.
    """
    # Punctuation that commonly hugs a word ("She," / "(him)" / "her.").
    # Matching only exact whitespace-delimited tokens would miss these.
    edge_punct = ".,;:!?\"'()"
    gendered = {token.lower() for token in gendered_tokens}

    text_w_masks_list = []
    for word in input_text.split():
        lead = len(word) - len(word.lstrip(edge_punct))
        core = word.strip(edge_punct)
        if core and core.lower() in gendered:
            # Swap only the core word, preserving any attached punctuation.
            word = word[:lead] + mask_token + word[lead + len(core):]
        text_w_masks_list.append(word)

    # Count occurrences rather than exact-equal words so that masks with
    # attached punctuation (e.g. "[MASK].") are included.
    num_masks = sum(word.count(mask_token) for word in text_w_masks_list)

    text_portions = ' '.join(text_w_masks_list).split(split_key)
    return text_portions, num_masks
162
+
163
+
164
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_preds):
    """Return the average percentage score given to the listed gendered words.

    Args:
        mask_filled_text: One entry per masked position; each entry is an
            iterable of prediction dicts with ``"score"`` and ``"token_str"``
            keys.
        gendered_token: Collection of gendered words to look for.
        num_preds: Number of masked positions, used as the averaging
            denominator (EPS is added to avoid division by zero).

    Returns:
        The summed score of matching predictions, divided by ``num_preds``,
        as a percentage rounded to DECIMAL_PLACES.
    """
    pronoun_preds = [
        sum(
            pronoun["score"]
            for pronoun in top_preds
            # Strip before comparing: some tokenizers decode tokens with
            # surrounding whitespace, which would otherwise never match.
            if pronoun["token_str"].strip().lower() in gendered_token
        )
        for top_preds in mask_filled_text
    ]
    return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
172
+
173
+
174
def get_figure(df, gender, n_fit=1):
    """Plot the data points with a polynomial trend line and an error band.

    Args:
        df: DataFrame with an ``'x-axis'`` column plus the value column(s);
            the first remaining column is the one that gets fitted.
        gender: Label interpolated into the plot title.
        n_fit: Degree of the polynomial fitted to the points.

    Returns:
        The matplotlib figure.
    """
    df = df.set_index('x-axis')
    cols = df.columns
    # Fit against positions 0..n-1 rather than the (string) x-axis labels.
    xs = list(range(len(df)))
    ys = df[cols[0]]
    fig, ax = plt.subplots()

    # TODO: find stackoverflow reference for this fit/uncertainty recipe.
    try:
        p, C_p = np.polyfit(xs, ys, n_fit, cov=True)
    except ValueError:
        # Too few points to estimate the coefficient covariance for this
        # degree: still draw the trend line, just without the error band.
        p, C_p = np.polyfit(xs, ys, n_fit), None

    t = np.linspace(min(xs) - 1, max(xs) + 1, 10 * len(xs))
    # Design matrix whose columns are t**n_fit, ..., t**1, t**0.
    TT = np.vstack([t ** (n_fit - i) for i in range(n_fit + 1)]).T

    # matrix multiplication calculates the polynomial values
    yi = np.dot(TT, p)

    if C_p is not None:
        # Diagonal of TT @ C_p @ TT.T without building the full square matrix.
        var_yi = np.sum(np.dot(TT, C_p) * TT, axis=1)
        sig_yi = np.sqrt(var_yi)  # Standard deviations are sqrt of diagonal
        ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=.25)

    ax.plot(t, yi, '-')
    points = ax.plot(df, 'ro')
    # Pass the handles explicitly; with labels only, matplotlib would attach
    # the column name to the first artist drawn instead of the data points.
    ax.legend(points, list(df.columns))

    ax.axis('tight')

    ax.set_xlabel("Value injected into input text")
    ax.set_title(
        f"Probability of predicting {gender} pronouns.")
    ax.set_ylabel("Softmax prob for pronouns")
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.tick_params(axis='x', labelrotation=15)
    return fig
207
+
208
+
209
+ # %%
210
def predict_gender_pronouns(
    model_type,
    indie_vars,
    split_key,
    normalizing,
    n_fit,
    input_text,
):
    """Run inference on input_text for each injected value.

    Gendered words in ``input_text`` are masked, each value from
    ``indie_vars`` is injected in place of ``split_key``, and the selected
    model's fill-mask predictions are summarized as the percentage assigned
    to female and to male gendered words.

    Args:
        model_type: Key into ``models`` selecting the fill-mask pipeline.
        indie_vars: Comma-separated values to inject; they also label the
            x-axis.
        split_key: Placeholder in ``input_text`` replaced by each value.
        normalizing: Truthy to rescale female/male percentages so they sum
            to (about) 100. The strings "True"/"False" are also accepted.
        n_fit: Degree of the polynomial trend fitted in the plots.
        input_text: Sentence containing gendered words and ``split_key``.

    Returns:
        A tuple ``(target_text, female_fig, male_fig, results_df)`` where
        ``target_text`` is the text fed to the model for the last value.

    Raises:
        ValueError: If no values are supplied or nothing was masked.
    """
    model = models[model_type]
    mask_token = model.tokenizer.mask_token

    # Tolerate "a, b ,c" style input and stray commas.
    indie_vars_list = [
        value.strip() for value in indie_vars.split(',') if value.strip()
    ]
    if not indie_vars_list:
        raise ValueError(
            "Provide at least one comma-separated value to inject.")

    # A non-empty string such as "False" is truthy, so parse it explicitly.
    if isinstance(normalizing, str):
        normalizing = normalizing.strip().lower() == "true"

    male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()

    text_segments, num_preds = prepare_text_for_masking(
        input_text, mask_token, male_gendered_tokens + female_gendered_tokens, split_key)
    if num_preds == 0:
        # Nothing was masked, so there is nothing to predict or average over.
        raise ValueError(
            "No gendered words found in the input text to mask.")

    male_pronoun_preds = []
    female_pronoun_preds = []
    for indie_var in indie_vars_list:
        target_text = indie_var.join(text_segments)
        mask_filled_text = model(target_text)
        # The pipeline returns a flat list for a single mask and a list of
        # lists for several; normalize to the nested form.
        if not isinstance(mask_filled_text[0], list):
            mask_filled_text = [mask_filled_text]

        female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            female_gendered_tokens,
            num_preds
        ))
        male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            male_gendered_tokens,
            num_preds
        ))

    if normalizing:
        total_gendered_probs = np.add(
            female_pronoun_preds, male_pronoun_preds)
        female_pronoun_preds = np.around(
            np.divide(female_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES
        )
        male_pronoun_preds = np.around(
            np.divide(male_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES
        )

    results_df = pd.DataFrame({'x-axis': indie_vars_list})
    results_df['female_pronouns'] = female_pronoun_preds
    results_df['male_pronouns'] = male_pronoun_preds
    female_fig = get_figure(results_df.drop(
        'male_pronouns', axis=1), 'female', n_fit)
    male_fig = get_figure(results_df.drop(
        'female_pronouns', axis=1), 'male', n_fit)

    return (
        target_text,
        female_fig,
        male_fig,
        results_df,
    )
278
+
279
+ # %%
280
title = "Causing Gender Pronouns"
description = """
## Intro

"""

# Each example lists values in the order the UI wiring expects:
# [model name, injected values, placeholder, normalize?, fit degree, input text]
place_example = [
    MODEL_NAMES[0],
    ",".join(PLACES),
    "PLACE",
    "False",
    1,
    "Born in PLACE, she was a teacher.",
]

date_example = [
    MODEL_NAMES[0],
    ",".join(DATES),
    "DATE",
    "False",
    2,
    "Born in DATE, she was a doctor.",
]

subreddit_example = [
    MODEL_NAMES[2],
    ",".join(SUBREDDITS),
    "SUBREDDIT",
    "False",
    1,
    "I saw on r/SUBREDDIT that she is a hacker.",
]
313
+
314
+
315
def date_fn():
    """Return the field values of the date example."""
    return date_example
317
def place_fn():
    """Return the field values of the place example."""
    return place_example
319
def reddit_fn():
    """Return the field values of the subreddit example."""
    return subreddit_example
321
+
322
+
323
+ # %%
324
# Build the Gradio UI: input widgets, output widgets, buttons that populate
# the inputs from the examples, and the submit wiring.
# NOTE(review): nesting below is reconstructed from syntax -- confirm layout.
demo = gr.Blocks()
with demo:
    gr.Markdown("## Hunt for spurious correlations in our LLMs.")
    gr.Markdown("Please see a better explanation in another [Space](https://huggingface.co/spaces/emilylearning/causing_gender_pronouns_two).")

    with gr.Row():
        x_axis = gr.Textbox(
            lines=5,
            label="Pick a spectrum of values for text injection and x-axis",
        )
    with gr.Row():
        model_name = gr.Radio(
            MODEL_NAMES,
            type="value",
            label="Pick a BERT-like model.",
        )
        # NOTE(review): "index" is an unusual `type` for a Textbox -- confirm
        # it is accepted by the pinned Gradio version.
        place_holder = gr.Textbox(
            label="Special token used in input text that will be replaced with the above spectrum of values.",
            type="index",
        )
        # type="index" passes the position of the selected choice
        # (0 for "False", 1 for "True") to the callback.
        to_normalize = gr.Dropdown(
            ["False", "True"],
            label="Normalize?",
            type="index",
        )
        n_fit = gr.Dropdown(
            list(range(1, 5)),
            label="Degree of polynomial fit for dose response trend",
            type="value",
        )
    with gr.Row():
        input_text = gr.Textbox(
            lines=5,
            label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
        )
    with gr.Row():
        sample_text = gr.Textbox(
            type="auto", label="Output text: Sample of text fed to model")
    with gr.Row():
        female_fig = gr.Plot(
            type="auto", label="Plot of softmax probability pronouns predicted female.")
    with gr.Row():
        male_fig = gr.Plot(
            type="auto", label="Plot of softmax probability pronouns predicted male.")
    with gr.Row():
        df = gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability for pronouns predictions",
        )

    # Each caption describes the example loaded by the button right below it
    # (the date and location captions were previously swapped).
    gr.Markdown("x-axis sorted by bottom 10 and top 10 Global Gender Gap ranked countries:")
    place_gen = gr.Button('Populate fields with a location example')

    gr.Markdown("x-axis sorted by older to more recent dates:")
    date_gen = gr.Button('Populate fields with a date example')

    gr.Markdown("x-axis sorted in order of increasing self-identified female participation (see [bburky demo](http://bburky.com/subredditgenderratios/)): ")
    subreddit_gen = gr.Button('Populate fields with a subreddit example')

    # Buttons fill the input widgets with example values; see
    # https://github.com/gradio-app/gradio/issues/690#issuecomment-1118772919
    with gr.Row():
        date_gen.click(date_fn, inputs=[], outputs=[model_name,
                       x_axis, place_holder, to_normalize, n_fit, input_text])
        place_gen.click(place_fn, inputs=[], outputs=[
                        model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
        subreddit_gen.click(reddit_fn, inputs=[], outputs=[
                            model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
    with gr.Row():
        btn = gr.Button("Hit submit")
        btn.click(
            predict_gender_pronouns,
            inputs=[model_name, x_axis, place_holder,
                    to_normalize, n_fit, input_text],
            outputs=[sample_text, female_fig, male_fig, df])

demo.launch(debug=True)