MT-bias-demo / app.py
Oskar van der Wal
Disable share=True on Spaces
14b1df3
import os
import re
from os.path import exists

import gradio
import inseq
import torch
from inseq.data.aggregator import AggregatorPipeline, SubwordAggregator, SequenceAttributionAggregator, PairAggregator
# Run on the GPU whenever PyTorch reports a usable CUDA device, else on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
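# Eagerly loading the Hu-En model at startup is disabled (see the commented-out
# line below); each handler loads the model it needs on demand instead.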
# model_hu_en = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients")
def swap_pronoun(sentence):
    """Exchange the English pronouns "He" and "She" in *sentence*.

    Matching is case-sensitive and restricted to whole words, so words that
    merely start with those letters ("Hello", "Sheila") are left alone. Both
    pronouns are swapped in a single pass, which means a sentence containing
    both has each one replaced by the other.

    Args:
        sentence: English text, e.g. a generated translation.

    Returns:
        The text with every standalone "He" turned into "She" and vice versa;
        the input unchanged if it contains neither pronoun.
    """
    counterpart = {"He": "She", "She": "He"}
    # A single substitution pass avoids re-replacing a pronoun that was just
    # produced by the swap (He -> She -> He).
    return re.sub(
        r"\b(?:He|She)\b",
        lambda match: counterpart[match.group(0)],
        sentence,
    )
def run_counterfactual(occupation):
    """Return the contrastive (masculine vs. feminine) attribution HTML.

    The Hungarian sentence "Ő <occupation>." is translated with the
    Helsinki-NLP/opus-mt-hu-en model. The source is then attributed twice with
    integrated gradients, once against the translation forced to "He" and once
    against the translation forced to "She", and the two aggregated
    attributions are rendered together through ``PairAggregator``.

    The rendered HTML is cached in ``results/counterfactual_<occupation>.html``
    and served from there on later calls.

    Args:
        occupation: Dropdown label such as "pék (baker)"; everything from the
            first " (" onwards is dropped before use.

    Returns:
        The visualisation as an HTML string.
    """
    occupation = occupation.split(" (")[0]
    result_fp = f"results/counterfactual_{occupation}.html"

    # Serve a previously rendered result. The HTML contains non-ASCII text
    # (the source sentence starts with "Ő"), so the encoding is pinned instead
    # of depending on the platform default.
    if exists(result_fp):
        with open(result_fp, encoding="utf-8") as f:
            return f.read()

    # "egy" means something like "a", but is used less frequently than in
    # English, so the source sentence is built without it.
    source = f"Ő {occupation}."

    model = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients")
    model.device = DEVICE
    target = model.generate(source)[0]

    out = model.attribute(
        [
            source,
            source,
        ],
        [
            # Masculine variant first, feminine second. "She" contains no
            # capital "He", so the second replace leaves a target that is
            # already feminine unchanged.
            target.replace("She", "He"),
            target.replace("He", "She"),
        ],
        n_steps=150,
        return_convergence_delta=False,
        attribute_target=False,
        step_scores=["probability"],
        internal_batch_size=100,
        include_eos_baseline=False,
        device=DEVICE,
    )

    squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
    masculine = out.sequence_attributions[0].aggregate(aggregator=squeezesum)
    feminine = out.sequence_attributions[1].aggregate(aggregator=squeezesum)
    html = masculine.show(aggregator=PairAggregator, paired_attr=feminine, return_html=True, display=True)

    # Cache the rendering; create the cache directory first so a missing
    # folder does not throw away the attribution that was just computed.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html
def run_simple(occupation, lang, aggregate):
    """Return the attribution HTML for translating "Ő <occupation>.".

    The sentence is attributed with integrated gradients using the
    ``Helsinki-NLP/opus-mt-hu-<lang>`` model, with target-side attribution
    enabled. The rendered HTML is cached under ``results/`` in a file whose
    name encodes the occupation, the language and whether subwords were
    aggregated, and is served from there on later calls.

    Args:
        occupation: Dropdown label such as "pék (baker)"; everything from the
            first " (" onwards is dropped before use.
        lang: Target language code that completes the model name, e.g. "en".
        aggregate: "yes" to aggregate with the subword/sequence aggregator
            pipeline; any other value leaves the output unaggregated.

    Returns:
        The visualisation as an HTML string.
    """
    aggregate = aggregate == "yes"
    occupation = occupation.split(" (")[0]
    result_fp = f"results/simple_{occupation}_{lang}{'_aggregate' if aggregate else ''}.html"

    # Serve a previously rendered result. The HTML contains non-ASCII text
    # (the source sentence starts with "Ő"), so the encoding is pinned instead
    # of depending on the platform default.
    if exists(result_fp):
        with open(result_fp, encoding="utf-8") as f:
            return f.read()

    model_name = f"Helsinki-NLP/opus-mt-hu-{lang}"

    # "egy" means something like "a", but is used less frequently than in
    # English, so the source sentence is built without it.
    source = f"Ő {occupation}."

    model = inseq.load_model(model_name, "integrated_gradients")
    out = model.attribute([source], attribute_target=True, n_steps=150, device=DEVICE, return_convergence_delta=False)

    if aggregate:
        squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
        html = out.show(return_html=True, display=True, aggregator=squeezesum)
    else:
        html = out.show(return_html=True, display=True)

    # Cache the rendering; create the cache directory first so a missing
    # folder does not throw away the attribution that was just computed.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html
def _read_markdown(path):
    """Return the whole content of the text file at *path*.

    The file is decoded explicitly as UTF-8 rather than with the platform
    default encoding, so the page texts load the same way on every host.
    """
    with open(path, encoding="utf-8") as fh:
        return fh.read()


# Static page texts rendered by the UI; read once when the module is loaded.
desc = _read_markdown("description.md")
simple_translation = _read_markdown("simple_translation.md")
contrastive_pair = _read_markdown("contrastive_pair.md")
notice = _read_markdown("notice.md")
# (Hungarian word, English gloss) for every entry offered in the dropdowns.
_OCCUPATION_PAIRS = [
    ("nő", "woman"),
    ("férfi", "man"),
    ("nővér", "nurse"),
    ("tudós", "scientist"),
    ("mérnök", "engineer"),
    ("pék", "baker"),
    ("tanár", "teacher"),
    ("esküvőszervező", "wedding organizer"),
    ("vezérigazgató", "CEO"),
]

# Dropdown labels in the form "<hungarian> (<english>)". The handlers recover
# the Hungarian word by cutting the label at the first " (".
OCCUPATIONS = [f"{hungarian} ({english})" for hungarian, english in _OCCUPATION_PAIRS]

# Target language codes; each one completes a "Helsinki-NLP/opus-mt-hu-<code>"
# model name.
LANGS = ["en", "fr", "de"]
# Build the demo page: introduction, one explanatory accordion per analysis,
# one tab per analysis wired to its handler, and a closing notes section.
# NOTE(review): the nesting of the Row/Column/Tab contexts below was restored
# from flattened source; confirm it against the intended page layout.
with gradio.Blocks(title="Gender Bias in MT: Hungarian to English") as iface:
    gradio.Markdown(desc)
    with gradio.Accordion("Simple translation", open=True):
        gradio.Markdown(simple_translation)
    with gradio.Accordion("Contrastive pair", open=False):
        gradio.Markdown(contrastive_pair)
    gradio.Markdown("**Does the model seem to rely on gender stereotypes in its translations?**")

    # Tab 1: translate a single sentence into the chosen language and show
    # its attributions.
    with gradio.Tab("Simple translation"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                occupation_sel = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])
            with gradio.Column(scale=4):
                target_lang = gradio.Dropdown(label="Target Language", choices=LANGS, value=LANGS[0])
            aggregate_subwords = gradio.Radio(
                ["yes", "no"], label="Aggregate subwords?", value="yes"
            )
        but = gradio.Button("Translate & Attribute")
        out = gradio.HTML()
        # The input order must match run_simple(occupation, lang, aggregate).
        args = [occupation_sel, target_lang, aggregate_subwords]
        but.click(run_simple, inputs=args, outputs=out)

    # Tab 2: contrast the "He" and "She" translations of the same sentence.
    with gradio.Tab("Contrastive pair"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                occupation_sel = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])
        but = gradio.Button("Translate & Attribute")
        out = gradio.HTML()
        args = [occupation_sel]
        but.click(run_counterfactual, inputs=args, outputs=out)

    with gradio.Accordion("Notes & References", open=False):
        gradio.Markdown(notice)

# Start the web server for the assembled page.
iface.launch()