from typing import Optional
import gradio as gr
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline
import pandas as pd
import numpy as np


# Play with me: user-adjustable constants
CONDITIONING_VARIABLES = ["none", "birth_place", "birth_date", "name"]
FEMALE_WEIGHTS = [1.5, 5]  # About 5x more male than female tokens in dataset
BERT_LIKE_MODELS = ["bert", "distilbert"]

# Internal consts
START_YEAR = 1800
STOP_YEAR = 1999
SPLIT_KEY = "DATE"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Note: defined but not currently used below

MAX_TOKEN_LENGTH = 128
NON_LOSS_TOKEN_ID = -100
NON_GENDERED_TOKEN_ID = 30  # Picked an int that will pop out visually
LABEL_DICT = {"female": 9, "male": -9}  # Likewise picked to pop out visually
CLASSES = list(LABEL_DICT.keys())


# Fire up the models
models_paths = dict()
models = dict()

base_path = "emilylearning/"
for var in CONDITIONING_VARIABLES:
    for f_weight in FEMALE_WEIGHTS:
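        # The two f_weight variants use different repo naming schemes.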
        if f_weight == 1.5:
            models_paths[(var, f_weight)] = (
                base_path
                + f"finetuned_cgp_added_{var}__female_weight_{f_weight}__test_run_False__p_dataset_100"
            )
        else:
            models_paths[(var, f_weight)] = (
                base_path
                + f"finetuned_cgp_add_{var}__f_weight_{f_weight}__p_dataset_100__test_False"
            )
        models[(var, f_weight)] = AutoModelForTokenClassification.from_pretrained(
            models_paths[(var, f_weight)]
        )
for bert_like in BERT_LIKE_MODELS:
    models_paths[(bert_like,)] = f"{bert_like}-base-uncased"
    models[(bert_like,)] = pipeline(
        "fill-mask", model=models_paths[(bert_like,)])


# Tokenizers same for each model, so just grabbing one of them
tokenizer = AutoTokenizer.from_pretrained(
    models_paths[(CONDITIONING_VARIABLES[0], FEMALE_WEIGHTS[0])
                 ], add_prefix_space=True
)
MASK_TOKEN_ID = tokenizer.mask_token_id


# more static stuff
gendered_lists = [
    ["he", "she"],
    ["him", "her"],
    ["his", "hers"],
    ["male", "female"],
    ["man", "woman"],
    ["men", "women"],
    ["husband", "wife"],
]
male_gendered_tokens = [pair[0] for pair in gendered_lists]
female_gendered_tokens = [pair[1] for pair in gendered_lists]

male_gendered_token_ids = tokenizer.convert_tokens_to_ids(male_gendered_tokens)
female_gendered_token_ids = tokenizer.convert_tokens_to_ids(
    female_gendered_tokens)

assert tokenizer.unk_token_id not in male_gendered_token_ids
assert tokenizer.unk_token_id not in female_gendered_token_ids

label_list = list(LABEL_DICT.values())
assert label_list[0] == LABEL_DICT["female"], "LABEL_DICT not an ordered dict"

label2id = {label: idx for idx, label in enumerate(label_list)}
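# e.g. {9: 0, -9: 1}: class index 0 is "female" and 1 is "male", which the
# prediction counting below (preds == 0 / preds == 1) relies on.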


def tokenize_and_append_metadata(text, tokenizer):
    """Tokenize text and mask/flag gendered token ids in input_ids and labels."""
    tokenized = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=MAX_TOKEN_LENGTH,
    )

    # Finding the gender pronouns in the tokens
    token_ids = tokenized["input_ids"]
    female_tags = torch.tensor(
        [
            LABEL_DICT["female"]
            if id in female_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for id in token_ids
        ]
    )
    male_tags = torch.tensor(
        [
            LABEL_DICT["male"]
            if id in male_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for id in token_ids
        ]
    )

    # Labeling and masking out occurrences of gendered pronouns
    labels = torch.where(
        female_tags == LABEL_DICT["female"],
        label2id[LABEL_DICT["female"]],
        NON_LOSS_TOKEN_ID,
    )
    labels = torch.where(
        male_tags == LABEL_DICT["male"], label2id[LABEL_DICT["male"]], labels
    )
    masked_token_ids = torch.where(
        female_tags == LABEL_DICT["female"], MASK_TOKEN_ID, torch.tensor(
            token_ids)
    )
    masked_token_ids = torch.where(
        male_tags == LABEL_DICT["male"], MASK_TOKEN_ID, masked_token_ids
    )

    tokenized["input_ids"] = masked_token_ids
    tokenized["labels"] = labels

    return tokenized
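
# Illustrative behavior (assuming a BERT-style tokenizer): for the input
# "Born in 1900, she was a doctor.", the id for "she" is replaced with
# MASK_TOKEN_ID in input_ids, its label becomes label2id[LABEL_DICT["female"]],
# and all other positions keep the ignored label NON_LOSS_TOKEN_ID (-100).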


def get_tokenized_text_with_years(years, input_text):
    """Construct dict of tokenized texts with each year injected into the text."""
    text_portions = input_text.split(SPLIT_KEY)
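    # e.g. with SPLIT_KEY = "DATE": "Born in DATE, she ..." splits into
    # ["Born in ", ", she ..."]; joining on a year below yields
    # "Born in 1900, she ...".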

    tokenized_w_year = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
    for b_date in years:

        target_text = f"{b_date}".join(text_portions)
        tokenized_sample = tokenize_and_append_metadata(
            target_text,
            tokenizer=tokenizer,
        )

        tokenized_w_year['ids'].append(tokenized_sample["input_ids"])
        tokenized_w_year['atten_mask'].append(
            torch.tensor(tokenized_sample["attention_mask"]))
        tokenized_w_year['toks'].append(
            tokenizer.convert_ids_to_tokens(tokenized_sample["input_ids"]))
        tokenized_w_year['labels'].append(tokenized_sample["labels"])

    return tokenized_w_year


def predict_gender_pronouns(
    num_points, conditioning_variables, f_weights, bert_like_models, input_text
):
    """Run inference on input_text for each model type, returning df and plots of precentage
    of gender pronouns predicted as female and male in each target text.
    """

    years = np.linspace(START_YEAR, STOP_YEAR, int(num_points)).astype(int)

    tokenized = get_tokenized_text_with_years(years, input_text)

    is_masked = tokenized['ids'][0] == MASK_TOKEN_ID
    num_preds = torch.sum(is_masked).item()
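    # Note: assumes every year-variant of the text tokenizes with the mask
    # tokens in the same positions.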

    dfs = []
    dfs.append(pd.DataFrame({"year": years}))
    for f_weight in f_weights:
        for var in conditioning_variables:
            prefix = f"{var}_w{f_weight}"
            model = models[(var, f_weight)]

            p_female = []
            p_male = []
            for year_idx in range(len(tokenized['ids'])):
                ids = tokenized["ids"][year_idx]
                atten_mask = tokenized["atten_mask"][year_idx]
                labels = tokenized["labels"][year_idx]

                with torch.no_grad():
                    outputs = model(ids.unsqueeze(dim=0),
                                    atten_mask.unsqueeze(dim=0))
                    preds = torch.argmax(outputs[0][0].cpu(), dim=1)

                    # Keep predictions only at masked (pronoun) positions.
                    preds = torch.where(is_masked, preds, -100)

                    p_female.append(
                        len(torch.where(preds == 0)[0]) / num_preds * 100)
                    p_male.append(
                        len(torch.where(preds == 1)[0]) / num_preds * 100)

            dfs.append(pd.DataFrame(
                {f"%f_{prefix}": p_female, f"%m_{prefix}": p_male}))

    for bert_like in bert_like_models:

        p_female = []
        p_male = []
        for year_idx in range(len(tokenized['ids'])):
            toks = tokenized["toks"][year_idx]
            target_text_for_bert = ' '.join(
                toks[1:-1])  # Removing [CLS] and [SEP]

            prefix = bert_like
            model = models[(bert_like,)]

            mask_filled_text = model(target_text_for_bert)
            # The fill-mask pipeline returns a flat list of candidates when the
            # text has a single [MASK], but a list of lists when it has several;
            # normalize to the nested form.
            if not isinstance(mask_filled_text[0], list):
                mask_filled_text = [mask_filled_text]
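            # Each inner list holds the pipeline's top-k candidates for one
            # [MASK], e.g. [{"token_str": "she", "score": ...}, ...]; only the
            # top candidate is checked below.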

            female_pronouns = [
                1 if pronoun[0]["token_str"] in female_gendered_tokens else 0
                for pronoun in mask_filled_text
            ]
            male_pronouns = [
                1 if pronoun[0]["token_str"] in male_gendered_tokens else 0
                for pronoun in mask_filled_text
            ]

            p_female.append(sum(female_pronouns) / num_preds * 100)
            p_male.append(sum(male_pronouns) / num_preds * 100)

        dfs.append(pd.DataFrame(
            {f"%f_{prefix}": p_female, f"%m_{prefix}": p_male}))

    # To display to user as an example
    toks = tokenized["toks"][0]
    target_text_w_masks = ' '.join(toks[1:-1])

    results = pd.concat(dfs, axis=1).set_index("year")

    # reset_index because Gradio plots don't appear to pick up the dataframe index.
    female_df = results.filter(regex=".*f_").reset_index()
    female_df_for_plot = female_df

    male_df = results.filter(regex=".*m_").reset_index()
    male_df_for_plot = male_df

    return (
        target_text_w_masks,
        female_df_for_plot,
        female_df,
        male_df_for_plot,
        male_df,
    )
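
# A minimal sketch of calling the inference function directly, bypassing the
# Gradio UI (illustrative only; the argument values below are arbitrary picks):
#
#   text_w_masks, f_plot, f_df, m_plot, m_df = predict_gender_pronouns(
#       num_points=10,
#       conditioning_variables=["birth_date"],
#       f_weights=[1.5],
#       bert_like_models=["bert"],
#       input_text="Born in DATE, she was a teacher.",
#   )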


title = "Causing Gender Pronouns"
description = """
<h2> Intro </h2>

This is a demo for a project exploring possible spurious correlations that have been learned by our models. We first examined the training datasets and learning tasks to hypothesize what spurious correlations may exist. Below we can condition on these variables to determine what effect they may have on the prediction outcomes.

Specifically, in this demo: given a user-provided sentence with at least one reference to a `DATE` and one gender pronoun, we will see how sweeping through a range of `DATE` values can change the predicted pronouns. This effect can be observed both in BERT base models and in our fine-tuned models (fine-tuned on a pronoun-prediction task over the [wiki-bio](https://huggingface.co/datasets/wiki_bio) dataset).

One way to explain this phenomenon is to sketch, as a causal DAG, a likely data-generating process for the biography-like data in both BERT's main training corpus and the `wiki_bio` dataset.
 
<h2> Causal DAG </h2> 

In the DAG, we can see that `birth_place`, `birth_date` and `gender` are all independent elements that have no common cause with the other covariates in the DAG. However, `birth_place`, `birth_date` and `gender` may all play a role in causing one's `access_to_resources`. The general trend is that `access_to_resources` has become less gender-dependent over time, though not in every `birth_place`; recent events in Afghanistan provide a stark counterexample to this trend.

Importantly, `access_to_resources` determines how, **if at all**, you may appear in the dataset's `context_words`.

We argue that although there are complex causal interactions between each word in any given sentence, the `context_words` are more likely to cause the `gender_pronouns`, rather than vice versa. For example, if the subject is a famous doctor and the object is her wealthy father, these context words will determine which person is being referred to, and thus which gendered-pronoun to use.
 
 
In this graph, arrowheads show the assumed direction of causation: e.g. as described above, we are claiming that `context_words` cause the `gender_pronouns`. While causation follows the direction of the arrows, statistical correlation can flow in either direction (it is cause-agnostic).

In the case of this graph, any pink path between `context_words` and `gender_pronouns` will allow the flow of statistical correlation, inviting confounding and thus spurious correlations into the trained model.
 
<center>
<img src="https://www.dropbox.com/s/x60r43h7uwztnru/generic_ds_dag.png?raw=1" 
    alt="DAG of possible data generating process for datasets used in training.">
</center>
  
Those familiar with causal DAGs may note that we can simply condition on `gender` to block any confounding between the `context_words` and the `gender_pronouns`. However, this is not always possible, particularly in generative or mask-filling tasks, common in language modeling and in the demo below, where gender may be unknown.
 
 <h2> How to use this demo </h2> 
 
In this demo, a user can add any sentence that contains at least one gender pronoun and the capitalized word `DATE`. We then sweep through a range of `date` values in the place of `DATE`, while masking (for prediction) the gender pronouns (included in the list below).

```
gendered_lists = [
   ['he', 'she'],
   ['him', 'her'],
   ['his', 'hers'],
   ['male', 'female'],
   ['man', 'woman'],
   ['men', 'women'],
   ["husband", "wife"],
]
```
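
For example (illustrative, using the demo's default input sentence), one target text in the sweep would look like:

```
Input:  She always walked past the building built in DATE ...
Target: [MASK] always walked past the building built in 1800 ...
```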

In addition to choosing the test sentence, we ask that you pick how the fine-tuned model was trained:
- conditioning variable: which conditioning variable, if any, from the three noted in the DAG above was included in the text at train time.
- loss function weight: the weight assigned to the minority class (female pronouns, in this fine-tuning dataset) in the loss function at train time.

You can also optionally pick a bert-like model for comparison.


Some notes: 
- Gradio currently only supports 6 plotting colors (but there are [plans](https://github.com/gradio-app/gradio/issues/1088) to support more!), so best to not select too many models at once for now.
- If the dataframes appear to not update with new fields, it may help to 'Clear' the fields before 'Submitting' new inputs. 


 <h2> What are the results? </h2> 
 
In the resulting plots, we can look for a dose-response relationship between:
- our treatment: the sample text,
- and our outcome: the predicted gender of pronouns in the text.
 
Specifically, we are checking whether (1) a larger-magnitude intervention (an older `DATE` in the text) results in (2) a larger-magnitude effect on the outcome (a higher percentage of predicted female pronouns).

Some trends that appear in the test sentences I have tried:
- Conditioning on `birth_date` metadata in both training and inference text has the largest dose-response relationship. This seems reasonable, as the fine-tuned model is able to 'stratify' a learned relationship between gender pronouns and dates when both are present in the text.
- Conditioning on either no metadata or `birth_place` metadata at train time has a similar, middle-ground effect on this inference task.
- Finally, conditioning on `name` metadata in training (while again conditioning on `date` at inference) shows almost no dose-response relationship. It appears that learning a `name -> gender_pronouns` relationship was successful enough to overwhelm any potentially more nuanced learning, such as that driven by `birth_date` or `birth_place`.

Please feel free to ping me on the Hugging Face discord (I'm 'emily_learner' there), with any feedback/comments/concerns or interesting findings!
"""


article = "Check out [main colab notebook](https://colab.research.google.com/drive/14ce4KD6PrCIL60Eng-t79tEI1UP-DHGz?usp=sharing#scrollTo=Mg1tUeHLRLaG) \
 with a lot more details about this method and implementation."

ceo_example = [
    20,
    ["none", "birth_date", "name"],
    FEMALE_WEIGHTS,
    [],
    'Born in DATE, she was a CEO. Her work was greatly respected, and she was well-regarded in her field.',
]

death_date_example = [
    10,
    ['birth_date'],
    [1.5],
    BERT_LIKE_MODELS,
    'Died in DATE, she was recognized for her great accomplishments to the field of teaching.'
]


no_job_example = [
    20,
    CONDITIONING_VARIABLES,
    [1.5],
    BERT_LIKE_MODELS,
    'Born in DATE, she was a happy child. Her family raised her in a loving environment where she thrived.',
]

coder_example = [
    20,
    ['none', 'birth_date'],
    [1.5],
    ['bert'],
    'Born in DATE, she was a computer scientist. Her work was greatly respected, and she was well-regarded in her field.'
]



gr.Interface(
    fn=predict_gender_pronouns,
    inputs=[
        gr.inputs.Number(
            default=20,
            label="Number of points (years) plotted -- select fewer if slow.",
        ),
        gr.inputs.CheckboxGroup(
            CONDITIONING_VARIABLES,
            default=["birth_date"],
            type="value",
            label="(1) Pick conditioning variable included in text during fine-tuning.",
        ),
        gr.inputs.CheckboxGroup(
            FEMALE_WEIGHTS,
            default=[1.5],
            type="value",
            label="(2) Pick loss function weight placed on female predictions  during fine-tuning.",
        ),
        gr.inputs.CheckboxGroup(
            BERT_LIKE_MODELS,
            default=BERT_LIKE_MODELS,
            type="value",
            label="(Optional) Pick BERT-like base uncased model for comparison.",
        ),
        gr.inputs.Textbox(
            lines=7,
            label="Input Text: Include one or more instance of the word 'DATE' below (to be replaced with a range of `{dates}` in demo), and one or more gender pronoun (to be `[MASK]`ed for prediction).",
            default="She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
        ),
    ],
    outputs=[
        gr.outputs.Textbox(
            type="auto", label="Sample target text fed to model"),
        gr.outputs.Timeseries(
            x="year",
            label="Precent pred female pronoun vs year, per model trained with conditioning and with weight for female preds",
        ),
        gr.outputs.Dataframe(
            overflow_row_behaviour="show_ends",
            label="Precent pred female pronoun vs year, per model trained with conditioning and with weight for female preds",
        ),
        gr.outputs.Timeseries(
            x="year",
            label="Precent pred male pronoun vs year, per model trained with conditioning and with weight for female preds",
        ),
        gr.outputs.Dataframe(
            overflow_row_behaviour="show_ends",
            label="Precent pred male pronoun vs year, per model trained with conditioning and with weight for female preds",
        ),
    ],
    title=title,
    description=description,
    article=article,
    examples=[ceo_example, death_date_example, no_job_example, coder_example]
).launch()