# Spaces:
# Runtime error
# Runtime error
# %% | |
import random
import string

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from transformers import pipeline
# Pre-loaded fill-mask checkpoints offered as radio choices in the UI.
MODEL_NAMES = [
    "bert-base-uncased",
    "roberta-base",
    "bert-large-uncased",
    "roberta-large",
]
# Extra radio choice meaning "load the checkpoint typed into the textbox instead".
OWN_MODEL_NAME = 'add-a-model'
DECIMAL_PLACES = 1  # rounding applied to every reported percentage
EPS = 1e-5  # added to denominators so a zero count never divides by zero

# Example date constants: NUM_PTS evenly spaced years between START_YEAR and
# STOP_YEAR (inclusive), truncated to whole years and stored as strings.
DATE_SPLIT_KEY = "DATE"
START_YEAR, STOP_YEAR = 1801, 1999
NUM_PTS = 20
DATES = [
    str(year)
    for year in np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
]
# Example place constants.
# Source: https://www3.weforum.org/docs/WEF_GGGR_2021.pdf
# The ten bottom-ranked Global Gender Gap countries followed by the ten
# top-ranked ones, in the order they appear on the plot's x-axis.
PLACE_SPLIT_KEY = "PLACE"
_GGGR_BOTTOM_10 = [
    "Afghanistan",
    "Yemen",
    "Iraq",
    "Pakistan",
    "Syria",
    "Democratic Republic of Congo",
    "Iran",
    "Mali",
    "Chad",
    "Saudi Arabia",
]
_GGGR_TOP_10 = [
    "Switzerland",
    "Ireland",
    "Lithuania",
    "Rwanda",
    "Namibia",
    "Sweden",
    "New Zealand",
    "Norway",
    "Finland",
    "Iceland",
]
PLACES = _GGGR_BOTTOM_10 + _GGGR_TOP_10
# Example Reddit interest constants, in order of increasing self-identified
# female participation.
# See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 400000
#
# "fffffffuuuuuuuuuuuu" (which would sit between "news" and "Fitness") is left
# out on purpose: most of the tokens would be taken up by it, e.g.
# ['ff', '##ff', '##ff', '##fu', '##u', '##u', '##u', '##u', '##u', ...].
# No subreddit name contains whitespace, so splitting the block below on
# whitespace yields the names in order.
SUBREDDITS = """
    GlobalOffensive pcmasterrace nfl sports The_Donald leagueoflegends
    Overwatch gonewild Futurology space technology gaming Jokes
    dataisbeautiful woahdude askscience wow anime BlackPeopleTwitter
    politics pokemon worldnews reddit.com interestingasfuck videos
    nottheonion television science atheism movies gifs Music trees
    EarthPorn GetMotivated pokemongo news
    Fitness Showerthoughts OldSchoolCool explainlikeimfive todayilearned
    gameofthrones AdviceAnimals DIY WTF IAmA cringepics tifu
    mildlyinteresting funny pics LifeProTips creepy personalfinance food
    AskReddit books aww sex relationships
""".split()
# Gendered word pairs used for masking: element 0 of each pair is the male
# form and element 1 the female form (get_gendered_token_ids relies on this).
_MALE_FORMS = (
    'he', 'him', 'his', 'himself', 'male', 'man',
    'men', 'husband', 'father', 'boyfriend', 'brother', 'actor',
)
_FEMALE_FORMS = (
    'she', 'her', 'hers', 'herself', 'female', 'woman',
    'women', 'wife', 'mother', 'girlfriend', 'sister', 'actress',
)
GENDERED_LIST = [[male, female] for male, female in zip(_MALE_FORMS, _FEMALE_FORMS)]
# %% | |
# Fire up the models: one fill-mask pipeline per pre-loaded checkpoint, keyed
# by checkpoint name. Every pipeline is constructed eagerly at module import,
# in MODEL_NAMES order.
models = {checkpoint: pipeline("fill-mask", model=checkpoint) for checkpoint in MODEL_NAMES}
# %% | |
def get_gendered_token_ids(gendered_list=None):
    """Split gendered word pairs into parallel male and female word lists.

    Despite the name, this returns the words themselves (strings), not
    tokenizer ids.

    Args:
        gendered_list: Sequence of ``[male, female]`` pairs. Defaults to the
            module-level ``GENDERED_LIST``.

    Returns:
        Tuple ``(male_gendered_tokens, female_gendered_tokens)`` of lists, in
        the same order as the input pairs.
    """
    if gendered_list is None:
        gendered_list = GENDERED_LIST
    # `pair` rather than `list`: avoid shadowing the builtin.
    male_gendered_tokens = [pair[0] for pair in gendered_list]
    female_gendered_tokens = [pair[1] for pair in gendered_list]
    return male_gendered_tokens, female_gendered_tokens
def prepare_text_for_masking(input_text, mask_token, gendered_tokens, split_key):
    """Mask every gendered word in ``input_text`` and split it on ``split_key``.

    Words are compared case-insensitively after stripping leading/trailing
    punctuation, so "him." or '"Her' are masked too; the punctuation itself is
    kept around the mask token. Runs of whitespace collapse to single spaces.

    Args:
        input_text: Raw user text.
        mask_token: The model tokenizer's mask token.
        gendered_tokens: Words to replace with ``mask_token``.
        split_key: Placeholder to split on. If empty, the text is not split.

    Returns:
        ``(text_portions, num_masks)``: the masked text split on ``split_key``
        (joining the portions with a value injects that value), and the number
        of masked positions, including words that already equal ``mask_token``.
    """
    gendered = {token.lower() for token in gendered_tokens}
    masked_words = []
    num_masks = 0
    for word in input_text.split():
        core = word.strip(string.punctuation)
        if core and core.lower() in gendered:
            # Replace only the word itself, keeping e.g. a trailing '.' or ','.
            lead = len(word) - len(word.lstrip(string.punctuation))
            word = word[:lead] + mask_token + word[lead + len(core):]
            num_masks += 1
        elif word == mask_token:
            # A mask token typed by the user is also predicted by the model.
            num_masks += 1
        masked_words.append(word)
    masked_text = ' '.join(masked_words)
    # str.split('') raises ValueError; with no placeholder there is nothing to inject.
    text_portions = masked_text.split(split_key) if split_key else [masked_text]
    return text_portions, num_masks
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_preds):
    """Average, over masked positions, the score mass the model puts on gendered words.

    Args:
        mask_filled_text: One entry per masked position; each entry is the
            list of candidate predictions (dicts with "token_str" and "score")
            produced by the fill-mask model for that position.
        gendered_token: Words whose candidate scores are counted. Candidate
            strings are whitespace-stripped and lower-cased before the check.
        num_preds: Number of masked positions to average over.

    Returns:
        The mean per-position summed score, as a percentage rounded to
        DECIMAL_PLACES. EPS in the denominator keeps num_preds == 0 from
        dividing by zero.
    """
    per_mask_mass = []
    for candidates in mask_filled_text:
        per_mask_mass.append(sum(
            candidate["score"]
            for candidate in candidates
            if candidate["token_str"].strip().lower() in gendered_token
        ))
    mean_mass = sum(per_mask_mass) / (EPS + num_preds)
    return round(mean_mass * 100, DECIMAL_PLACES)
# %% | |
def get_figure(df, gender, n_fit=1):
    """Plot one probability column against the injected values, with a
    polynomial fit and a +/- one standard deviation band around it.

    Args:
        df: DataFrame with an ``'x-axis'`` column of injected values and one
            further column of percentages to plot. That column's name is used
            as the legend label.
        gender: Word used in the plot title, e.g. ``'female'``.
        n_fit: Requested polynomial degree. It is lowered when there are too
            few points to fit it; ``None`` is treated as 1.

    Returns:
        A ``matplotlib.figure.Figure``. It is built directly rather than via
        ``pyplot`` so figures created on every request are not kept alive by
        pyplot's global figure registry.
    """
    df = df.set_index('x-axis')
    col = df.columns[0]
    ys = df[col].to_numpy(dtype=float)
    xs = np.arange(len(ys))  # fit on positions 0..n-1; the x labels are categorical

    fig = Figure()
    ax = fig.subplots()
    # Trying small fig due to rendering issues on HF, not on VS Code
    fig.set_figheight(3)
    fig.set_figwidth(9)

    # np.polyfit raises when asked for the scaled covariance without more
    # points than coefficients, and a degree >= the number of points is
    # underdetermined. Cap the degree and fall back to the unscaled covariance
    # for very short inputs instead of crashing.
    deg = 1 if n_fit is None else int(n_fit)
    deg = max(0, min(deg, len(xs) - 1))
    cov = True if len(xs) > deg + 1 else 'unscaled'
    p, C_p = np.polyfit(xs, ys, deg, cov=cov)

    t = np.linspace(xs.min() - 1, xs.max() + 1, 10 * len(xs))
    # Powers of t, highest first, matching the coefficient order of p.
    TT = np.vstack([t ** (deg - i) for i in range(deg + 1)]).T
    yi = TT @ p
    # Variance of each fitted value is diag(TT @ C_p @ TT.T); compute only the
    # diagonal, and clip tiny negative round-off before the square root.
    var_yi = np.sum((TT @ C_p) * TT, axis=1)
    sig_yi = np.sqrt(np.clip(var_yi, 0.0, None))

    ax.fill_between(t, yi + sig_yi, yi - sig_yi, alpha=.25)
    ax.plot(t, yi, '-')
    # String x values become categorical ticks at positions 0..n-1, which line
    # up with xs (assuming the injected values are distinct). Label the data
    # points explicitly so the legend entry is attached to them.
    ax.plot(df.index.to_numpy(), ys, 'ro', label=col)
    ax.legend()
    ax.axis('tight')
    ax.set_xlabel("Value injected into input text")
    ax.set_title(f"Probability of predicting {gender} pronouns.")
    ax.set_ylabel("Softmax prob for pronouns")
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.tick_params(axis='x', labelrotation=5)
    return fig
# %% | |
def predict_gender_pronouns(
    model_name,
    own_model_name,
    indie_vars,
    split_key,
    normalizing,
    n_fit,
    input_text,
):
    """Run inference on input_text for each injected value, returning a sample
    input, female/male plots, and the table of gendered-word percentages.

    Every gendered word from GENDERED_LIST in ``input_text`` is replaced by the
    model's mask token. Then each comma-separated value in ``indie_vars`` is
    substituted for ``split_key`` in turn and run through the model.

    Args:
        model_name: One of MODEL_NAMES, or any other value (e.g.
            OWN_MODEL_NAME) to load ``own_model_name`` instead.
        own_model_name: Hugging Face fill-mask checkpoint to load when
            ``model_name`` is not a pre-loaded model.
        indie_vars: Comma-separated values to inject. Whitespace around each
            value is ignored and empty values are dropped.
        split_key: Placeholder in ``input_text`` replaced by each value.
        normalizing: Truthy to rescale the female/male percentages by their
            sum, so each pair adds up to (about) 100.
        n_fit: Degree of the polynomial fit drawn on the plots.
        input_text: Text containing gendered words and the placeholder.

    Returns:
        Tuple ``(display_text, female_fig, male_fig, results_df)``.

    Raises:
        ValueError: if ``indie_vars`` holds no values, or ``input_text``
            contains no gendered word to mask.
    """
    if model_name not in MODEL_NAMES:
        model = pipeline("fill-mask", model=own_model_name)
    else:
        model = models[model_name]
    mask_token = model.tokenizer.mask_token

    # The bundled examples use ', '.join(...), which leaves a leading space on
    # every value; strip it so it neither doubles spaces in the model input
    # nor shows up in the x-axis labels.
    indie_vars_list = [v.strip() for v in indie_vars.split(',') if v.strip()]
    if not indie_vars_list:
        raise ValueError("Provide at least one comma separated value to inject.")

    male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()
    text_segments, num_preds = prepare_text_for_masking(
        input_text, mask_token, male_gendered_tokens + female_gendered_tokens, split_key)
    if num_preds == 0:
        # Nothing is masked, so there is nothing to predict; fail early with
        # an actionable message.
        raise ValueError(
            "Input text must contain at least one gendered word (e.g. 'she', 'him').")

    male_pronoun_preds = []
    female_pronoun_preds = []
    for indie_var in indie_vars_list:
        target_text = indie_var.join(text_segments)
        mask_filled_text = model(target_text)
        # The return type depends on how many MASKs are in the text: a flat
        # list of predictions for one mask, a list of such lists for several.
        # Normalize to the nested form.
        if not isinstance(mask_filled_text[0], list):
            mask_filled_text = [mask_filled_text]
        female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            female_gendered_tokens,
            num_preds,
        ))
        male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            male_gendered_tokens,
            num_preds,
        ))

    if normalizing:
        total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds)
        female_pronoun_preds = np.around(
            np.divide(female_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES,
        )
        male_pronoun_preds = np.around(
            np.divide(male_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES,
        )

    results_df = pd.DataFrame({'x-axis': indie_vars_list})
    results_df['female_pronouns'] = female_pronoun_preds
    results_df['male_pronouns'] = male_pronoun_preds
    female_fig = get_figure(results_df.drop('male_pronouns', axis=1), 'female', n_fit)
    male_fig = get_figure(results_df.drop('female_pronouns', axis=1), 'male', n_fit)
    # Show one randomly chosen injected variant as the sample model input.
    display_text = random.choice(indie_vars_list).join(text_segments)
    return (
        display_text,
        female_fig,
        male_fig,
        results_df,
    )
# %% | |
title = "Causing Gender Pronouns"
description = """
## Intro
"""


def _example_inputs(model, own_model, values, placeholder, text):
    """Build one example row, ordered like the outputs wired to the example
    buttons: [model_name, own_model_name, x_axis, place_holder,
    to_normalize, n_fit, input_text]."""
    return [model, own_model, ', '.join(values), placeholder, "False", 1, text]


date_example = _example_inputs(
    MODEL_NAMES[1], '', DATES, 'DATE', 'She was a teenager in DATE.')
place_example = _example_inputs(
    MODEL_NAMES[0], '', PLACES, 'PLACE', 'She became an adult in PLACE.')
subreddit_example = _example_inputs(
    MODEL_NAMES[3], '', SUBREDDITS, 'SUBREDDIT', 'She was a kid. SUBREDDIT.')
own_model_example = _example_inputs(
    OWN_MODEL_NAME, 'emilyalsentzer/Bio_ClinicalBERT', DATES, 'DATE',
    'She was exposed to the virus in DATE.')


def date_fn():
    """Example-button callback: return the date-sweep inputs."""
    return date_example


def place_fn():
    """Example-button callback: return the country-sweep inputs."""
    return place_example


def reddit_fn():
    """Example-button callback: return the subreddit-sweep inputs."""
    return subreddit_example


def your_fn():
    """Example-button callback: return the add-a-model (date sweep) inputs."""
    return own_model_example
# %% | |
# Gradio UI: example buttons pre-populate the input fields; the submit button
# runs predict_gender_pronouns and fills the output text, plots and table.
with gr.Blocks() as demo:
    gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
    gr.Markdown("Find spurious correlations between seemingly independent variables (for example between `gender` and `time`) in almost any BERT-like LLM on Hugging Face, below.")
    gr.Markdown("See why this happens in our paper, [Selection Bias Induced Spurious Correlations in Large Language Models](https://arxiv.org/pdf/2207.08982.pdf), presented at [ICML 2022 Workshop on Spurious Correlations, Invariance, and Stability](https://sites.google.com/view/scis-workshop/home).")
    gr.Markdown("## Instructions for this Demo")
    gr.Markdown("1) Click on one of the examples below (where we sweep through a spectrum of `places`, `dates` and `subreddits`) to pre-populate the input fields.")
    gr.Markdown("2) Check out the pre-populated fields as you scroll down to the ['Hit Submit...'] button!")
    gr.Markdown("3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!")
    gr.Markdown("## Example inputs")
    gr.Markdown("Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions.")
    with gr.Row():
        date_gen = gr.Button('Click for date example inputs')
        gr.Markdown("<-- x-axis sorted from older to more recent dates:")
        place_gen = gr.Button('Click for country example inputs')
        gr.Markdown(
            "<-- x-axis sorted by bottom 10 and top 10 [Global Gender Gap](https://www3.weforum.org/docs/WEF_GGGR_2021.pdf) ranked countries:")
        subreddit_gen = gr.Button('Click for Subreddit example inputs')
        gr.Markdown(
            "<-- x-axis sorted in order of increasing self-identified female participation (see [bburky](http://bburky.com/subredditgenderratios/)): ")
        your_gen = gr.Button('Add-a-model example inputs')
        gr.Markdown("<-- x-axis dates, with your own model loaded! (The first time, it can take a while to load the new model; try another example meanwhile.)")

    gr.Markdown("## Input fields")
    gr.Markdown(
        "A) Pick a spectrum of comma separated values for text injection and x-axis.")
    with gr.Row():
        x_axis = gr.Textbox(
            lines=3,
            label="A) Comma separated values for text injection and x-axis",
        )
    gr.Markdown("B) Pick a pre-loaded BERT-family model of interest on the right.")
    gr.Markdown(f"Or C) select `{OWN_MODEL_NAME}`, then add the name of any other Hugging Face model that supports the [fill-mask](https://huggingface.co/models?pipeline_tag=fill-mask) task on the right (note: this may take some time to load).")
    with gr.Row():
        model_name = gr.Radio(
            MODEL_NAMES + [OWN_MODEL_NAME],
            type="value",
            label="B) BERT-like model.",
        )
        own_model_name = gr.Textbox(
            label="C) If you selected an 'add-a-model' model, put any Hugging Face pipeline model name (that supports the fill-mask task) here.",
        )
    gr.Markdown("D) Pick whether you want the predictions normalized to these gendered terms only.")
    gr.Markdown("E) Also tell the demo what special token you will use in your input text, that you would like replaced with the spectrum of values you listed above.")
    gr.Markdown("And F) the degree of polynomial fit used for highlighting potential spurious association.")
    with gr.Row():
        # type="index": the callback receives 0 for "False" and 1 for "True".
        to_normalize = gr.Dropdown(
            ["False", "True"],
            label="D) Normalize model's predictions to only the gendered ones?",
            type="index",
        )
        place_holder = gr.Textbox(
            label="E) Special token place-holder",
        )
        n_fit = gr.Dropdown(
            list(range(1, 5)),
            label="F) Degree of polynomial fit",
            type="value",
        )
    gr.Markdown(
        "G) Finally, add input text that includes at least one gendered pronoun and one place-holder token specified above.")
    with gr.Row():
        input_text = gr.Textbox(
            lines=2,
            label="G) Input text with pronouns and place-holder token",
        )

    gr.Markdown("## Outputs!")
    with gr.Row():
        btn = gr.Button("Hit submit to generate predictions!")
    with gr.Row():
        # No `type="auto"`: it is not a valid Textbox type, and the default
        # plain-text box is what is wanted here.
        sample_text = gr.Textbox(
            label="Output text: Sample of text fed to model")
    with gr.Row():
        female_fig = gr.Plot()
        male_fig = gr.Plot()
    with gr.Row():
        df = gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability for pronouns predictions",
        )

    # Each example button fills every input field, in the order that the
    # *_example lists are built.
    example_outputs = [
        model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text]
    date_gen.click(date_fn, inputs=[], outputs=example_outputs)
    place_gen.click(place_fn, inputs=[], outputs=example_outputs)
    subreddit_gen.click(reddit_fn, inputs=[], outputs=example_outputs)
    your_gen.click(your_fn, inputs=[], outputs=example_outputs)
    btn.click(
        predict_gender_pronouns,
        inputs=[model_name, own_model_name, x_axis, place_holder,
                to_normalize, n_fit, input_text],
        outputs=[sample_text, female_fig, male_fig, df])

if __name__ == "__main__":
    demo.launch(debug=True)
# %% | |