# MT-bias-demo / app.py
# Oskar van der Wal
# Load results from previously found html files directly to make the demo faster.
# 4ba9f41
# raw history blame
# No virus
# 5.55 kB
import os
import re
from os.path import exists

import gradio
import inseq
import torch
from inseq.data.aggregator import AggregatorPipeline, SubwordAggregator, SequenceAttributionAggregator, PairAggregator
# Device string passed to the attribution calls below: "cuda" when torch
# reports that CUDA is available, "cpu" otherwise.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Start downloading the Hu-En model
# model_hu_en = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients")
# Mapping used by swap_pronoun: each capitalised pronoun to its counterpart.
_PRONOUN_SWAP = {"He": "She", "She": "He"}


def swap_pronoun(sentence):
    """Return *sentence* with the pronouns "He" and "She" exchanged.

    Only the capitalised forms are handled, and only when they appear as
    whole words, so words that merely start with those letters ("Hello",
    "Her", "Sheila") are left untouched. Every occurrence is swapped in a
    single pass; a sentence containing neither pronoun is returned unchanged.

    Args:
        sentence: The text to rewrite.

    Returns:
        The rewritten text.
    """
    return re.sub(
        r"\b(?:He|She)\b",
        lambda match: _PRONOUN_SWAP[match.group()],
        sentence,
    )
def run_counterfactual(occupation):
    """Build the contrastive "He"/"She" attribution view for one occupation.

    The source sentence ``Ő {occupation}.`` is translated with the
    ``Helsinki-NLP/opus-mt-hu-en`` model. Attributions are then computed for
    two forced targets derived from that translation: one with "She" replaced
    by "He" and one with "He" replaced by "She". The two aggregated
    attributions are rendered as a single paired HTML view.

    The rendered HTML is cached in ``results/counterfactual_{occupation}.html``;
    when that file already exists it is returned directly and no model is
    loaded.

    Args:
        occupation: Dropdown label such as ``"nővér (nurse)"``. Only the part
            before ``" ("`` is used.

    Returns:
        The HTML string of the paired attribution view.
    """
    occupation = occupation.split(" (")[0]
    result_fp = f"results/counterfactual_{occupation}.html"

    # Serve the cached rendering if there is one. The encoding is explicit
    # because both the file name and the content contain non-ASCII text.
    try:
        with open(result_fp, encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        pass

    # "egy" means something like "a", but is used less frequently than in
    # English, so the source sentence omits it.
    source = f"Ő {occupation}."

    model = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients")
    model.device = DEVICE
    target = model.generate(source)[0]

    # Whole-word substitution so that words merely starting with "He"/"She"
    # in the translation are not rewritten.
    masculine_target = re.sub(r"\bShe\b", "He", target)
    feminine_target = re.sub(r"\bHe\b", "She", target)

    out = model.attribute(
        [source, source],
        [masculine_target, feminine_target],
        n_steps=150,
        return_convergence_delta=False,
        attribute_target=False,
        step_scores=["probability"],
        internal_batch_size=100,
        include_eos_baseline=False,
        device=DEVICE,
    )

    squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
    masculine = out.sequence_attributions[0].aggregate(aggregator=squeezesum)
    feminine = out.sequence_attributions[1].aggregate(aggregator=squeezesum)
    html = masculine.show(aggregator=PairAggregator, paired_attr=feminine, return_html=True, display=True)

    # Save the rendering for later calls. Create the cache directory first so
    # a missing folder does not throw away the freshly computed result.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html
def run_simple(occupation, lang, aggregate):
    """Translate ``Ő {occupation}.`` and render its attribution as HTML.

    The model ``Helsinki-NLP/opus-mt-hu-{lang}`` is loaded and the source
    sentence is attributed with ``attribute_target=True``. The rendered HTML
    is cached under ``results/simple_{occupation}_{lang}[_aggregate].html``;
    when that file already exists it is returned directly and no model is
    loaded.

    Args:
        occupation: Dropdown label such as ``"pék (baker)"``. Only the part
            before ``" ("`` is used.
        lang: Target language code used to build the model name.
        aggregate: ``"yes"`` to render with the subword/sequence aggregation
            pipeline; any other value renders the attribution without it.

    Returns:
        The HTML string of the attribution view.
    """
    aggregate = aggregate == "yes"
    occupation = occupation.split(" (")[0]
    result_fp = f"results/simple_{occupation}_{lang}{'_aggregate' if aggregate else ''}.html"

    # Serve the cached rendering if there is one. The encoding is explicit
    # because both the file name and the content contain non-ASCII text.
    try:
        with open(result_fp, encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        pass

    model_name = f"Helsinki-NLP/opus-mt-hu-{lang}"

    # "egy" means something like "a", but is used less frequently than in
    # English, so the source sentence omits it.
    source = f"Ő {occupation}."

    model = inseq.load_model(model_name, "integrated_gradients")
    out = model.attribute([source], attribute_target=True, n_steps=150, device=DEVICE, return_convergence_delta=False)

    if aggregate:
        squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
        html = out.show(return_html=True, display=True, aggregator=squeezesum)
    else:
        html = out.show(return_html=True, display=True)

    # Save the rendering for later calls. Create the cache directory first so
    # a missing folder does not throw away the freshly computed result.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html
def _read_text(path):
    """Return the whole content of the text file at *path*, decoded as UTF-8."""
    with open(path, encoding="utf-8") as fh:
        return fh.read()


# Markdown snippets shown in the interface, read once at start-up. They are
# decoded explicitly as UTF-8 instead of relying on the platform's locale.
desc = _read_text("description.md")
simple_translation = _read_text("simple_translation.md")
contrastive_pair = _read_text("contrastive_pair.md")
notice = _read_text("notice.md")
# (Hungarian word, English gloss) pairs offered in the "Occupation" dropdowns.
_OCCUPATION_GLOSSES = (
    ("nő", "woman"),
    ("férfi", "man"),
    ("nővér", "nurse"),
    ("tudós", "scientist"),
    ("mérnök", "engineer"),
    ("pék", "baker"),
    ("tanár", "teacher"),
    ("esküvőszervező", "wedding organizer"),
    ("vezérigazgató", "CEO"),
)

# Dropdown labels of the form "<Hungarian> (<English>)". The run_* callbacks
# take the text before " (" as the word to translate.
OCCUPATIONS = [f"{hungarian} ({english})" for hungarian, english in _OCCUPATION_GLOSSES]
# Target-language codes offered in the "Target Language" dropdown.
LANGS = ["en", "fr", "de"]
# Interface layout: intro text, two explanatory accordions, one tab per
# analysis mode, and a closing notes section.
# NOTE(review): the indentation of this block was lost in the source and has
# been reconstructed; confirm the placement of the Radio inside the Row and of
# the Button/HTML relative to the Row against the intended layout.
with gradio.Blocks(title="Gender Bias in MT: Hungarian to English") as iface:
    gradio.Markdown(desc)

    with gradio.Accordion("Simple translation", open=True):
        gradio.Markdown(simple_translation)
    with gradio.Accordion("Contrastive pair", open=False):
        gradio.Markdown(contrastive_pair)

    gradio.Markdown("**Does the model seem to rely on gender stereotypes in its translations?**")

    # Tab 1: attribute a single translation into the chosen target language.
    with gradio.Tab("Simple translation"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                occupation_sel = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])
            with gradio.Column(scale=4):
                target_lang = gradio.Dropdown(label="Target Language", choices=LANGS, value=LANGS[0])
            aggregate_subwords = gradio.Radio(
                ["yes", "no"], label="Aggregate subwords?", value="yes"
            )
        but = gradio.Button("Translate & Attribute")
        out = gradio.HTML()
        args = [occupation_sel, target_lang, aggregate_subwords]
        but.click(run_simple, inputs=args, outputs=out)

    # Tab 2: compare attributions for the "He" and "She" variants of the
    # translation. The names occupation_sel/but/out/args are rebound here; the
    # click handler of the first tab was already registered with its own
    # components.
    with gradio.Tab("Contrastive pair"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                occupation_sel = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])
        but = gradio.Button("Translate & Attribute")
        out = gradio.HTML()
        args = [occupation_sel]
        but.click(run_counterfactual, inputs=args, outputs=out)

    with gradio.Accordion("Notes & References", open=False):
        gradio.Markdown(notice)

iface.launch(share=True)