Spaces:
Running
Running
| import gradio | |
| import inseq | |
| from inseq.data.aggregator import AggregatorPipeline, SubwordAggregator, SequenceAttributionAggregator, PairAggregator | |
| import torch | |
| from os.path import exists | |
# Run attribution on the GPU when torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
| # Start downloading the Hu-En model | |
| # model_hu_en = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients") | |
def swap_pronoun(sentence):
    """Swap the stand-alone pronouns "He" and "She" in *sentence*.

    Only whole-word, capitalised occurrences are exchanged, so words that
    merely contain those letters (e.g. "Hello", "Her", "Sheila") are left
    untouched.

    Args:
        sentence: Text in which the pronouns should be swapped.

    Returns:
        A copy of *sentence* in which every "He" became "She" and every
        "She" became "He". The text is returned unchanged when neither
        pronoun occurs.
    """
    import re  # local import keeps this block self-contained

    def _flip(match):
        return "She" if match.group(0) == "He" else "He"

    # One pass with a callback swaps both pronouns at once; two chained
    # str.replace calls would turn everything into a single gender.
    return re.sub(r"\b(?:He|She)\b", _flip, sentence)
def run_counterfactual(occupation):
    """Return the HTML attribution view for a "He"/"She" contrastive pair.

    The rendered HTML is cached under ``results/``; when a cached file for
    the occupation exists it is returned without loading any model.

    Args:
        occupation: Dropdown entry such as ``"pék (baker)"``. Only the text
            before ``" ("`` is used.

    Returns:
        The HTML string produced by the paired attribution visualisation
        (either freshly computed or read back from the cache file).
    """
    import os  # local import: only needed to create the cache directory

    occupation = occupation.split(" (")[0]
    result_fp = f"results/counterfactual_{occupation}.html"
    if exists(result_fp):
        # The cached HTML contains Hungarian text, so decode it explicitly
        # as UTF-8 instead of relying on the platform default.
        with open(result_fp, "r", encoding="utf-8") as f:
            return f.read()

    # "egy" means something like "a", but is used less frequently than in
    # English, so it is left out of the source sentence.
    source = f"Ő {occupation}."
    model = inseq.load_model("Helsinki-NLP/opus-mt-hu-en", "integrated_gradients")
    model.device = DEVICE
    target = model.generate(source)[0]

    # Attribute the same source against a masculine and a feminine variant
    # of the generated translation.
    out = model.attribute(
        [
            source,
            source,
        ],
        [
            target.replace("She", "He"),
            target.replace("He", "She"),
        ],
        n_steps=150,
        return_convergence_delta=False,
        attribute_target=False,
        step_scores=["probability"],
        internal_batch_size=100,
        include_eos_baseline=False,
        device=DEVICE,
    )

    squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
    masculine = out.sequence_attributions[0].aggregate(aggregator=squeezesum)
    feminine = out.sequence_attributions[1].aggregate(aggregator=squeezesum)
    html = masculine.show(aggregator=PairAggregator, paired_attr=feminine, return_html=True, display=True)

    # Save the HTML for later requests; create the cache directory first so
    # the very first run does not fail with FileNotFoundError.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html
def run_simple(occupation, lang, aggregate):
    """Translate a Hungarian occupation sentence and return its attribution HTML.

    The rendered HTML is cached under ``results/``; when a cached file for
    the same occupation, language and aggregation setting exists it is
    returned without loading any model.

    Args:
        occupation: Dropdown entry such as ``"pék (baker)"``. Only the text
            before ``" ("`` is used.
        lang: Target-language code, interpolated into the model name
            ``Helsinki-NLP/opus-mt-hu-{lang}``.
        aggregate: ``"yes"`` to aggregate sub-word attributions; any other
            value disables aggregation.

    Returns:
        The HTML string of the attribution visualisation (either freshly
        computed or read back from the cache file).
    """
    import os  # local import: only needed to create the cache directory

    aggregate = aggregate == "yes"
    occupation = occupation.split(" (")[0]
    result_fp = f"results/simple_{occupation}_{lang}{'_aggregate' if aggregate else ''}.html"
    if exists(result_fp):
        # The cached HTML contains Hungarian text, so decode it explicitly
        # as UTF-8 instead of relying on the platform default.
        with open(result_fp, "r", encoding="utf-8") as f:
            return f.read()

    model_name = f"Helsinki-NLP/opus-mt-hu-{lang}"

    # "egy" means something like "a", but is used less frequently than in
    # English, so it is left out of the source sentence.
    source = f"Ő {occupation}."

    model = inseq.load_model(model_name, "integrated_gradients")
    out = model.attribute([source], attribute_target=True, n_steps=150, device=DEVICE, return_convergence_delta=False)

    if aggregate:
        squeezesum = AggregatorPipeline([SubwordAggregator, SequenceAttributionAggregator])
        html = out.show(return_html=True, display=True, aggregator=squeezesum)
    else:
        html = out.show(return_html=True, display=True)

    # Save the HTML for later requests; create the cache directory first so
    # the very first run does not fail with FileNotFoundError.
    os.makedirs(os.path.dirname(result_fp), exist_ok=True)
    with open(result_fp, "w", encoding="utf-8") as f:
        f.write(html)
    return html
# Static Markdown copy shown in the UI, read once when the module is loaded.
# The files are decoded explicitly as UTF-8 so the result does not depend on
# the platform's default encoding (the texts presumably contain Hungarian
# characters -- confirm against the .md files).
with open("description.md", encoding="utf-8") as fh:
    desc = fh.read()

with open("simple_translation.md", encoding="utf-8") as fh:
    simple_translation = fh.read()

with open("contrastive_pair.md", encoding="utf-8") as fh:
    contrastive_pair = fh.read()

with open("notice.md", encoding="utf-8") as fh:
    notice = fh.read()
# Dropdown entries in the form "<Hungarian word> (<English gloss>)".
# The handlers keep only the part before " (" when building the sentence.
OCCUPATIONS = [
    "nő (woman)", "férfi (man)",
    "nővér (nurse)", "tudós (scientist)",
    "mérnök (engineer)", "pék (baker)",
    "tanár (teacher)", "esküvőszervező (wedding organizer)",
    "vezérigazgató (CEO)",
]
# Target-language codes offered in the UI; each one is interpolated into the
# model name "Helsinki-NLP/opus-mt-hu-{lang}".
LANGS = ["en", "fr", "de"]
# Build the Gradio interface: introductory text, two explanatory accordions,
# one tab per experiment, and a closing notes section.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the statement order -- confirm it against the deployed layout.
with gradio.Blocks(title="Gender Bias in MT: Hungarian to English") as iface:
    gradio.Markdown(desc)

    with gradio.Accordion("Simple translation", open=True):
        gradio.Markdown(simple_translation)

    with gradio.Accordion("Contrastive pair", open=False):
        gradio.Markdown(contrastive_pair)

    gradio.Markdown("**Does the model seem to rely on gender stereotypes in its translations?**")

    # Tab 1: translate one sentence and attribute it with run_simple.
    with gradio.Tab("Simple translation"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                occupation_sel = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])

            with gradio.Column(scale=4):
                target_lang = gradio.Dropdown(label="Target Language", choices=LANGS, value=LANGS[0])
                aggregate_subwords = gradio.Radio(
                    ["yes", "no"], label="Aggregate subwords?", value="yes"
                )

        but = gradio.Button("Translate & Attribute")
        out = gradio.HTML()
        args = [occupation_sel, target_lang, aggregate_subwords]
        but.click(run_simple, inputs=args, outputs=out)

    # Tab 2: compare the "He"/"She" translations with run_counterfactual.
    with gradio.Tab("Contrastive pair"):
        with gradio.Row(equal_height=True):
            with gradio.Column(scale=4):
                occupation_sel = gradio.Dropdown(label="Occupation", choices=OCCUPATIONS, value=OCCUPATIONS[0])

        but = gradio.Button("Translate & Attribute")
        out = gradio.HTML()
        args = [occupation_sel]
        but.click(run_counterfactual, inputs=args, outputs=out)

    with gradio.Accordion("Notes & References", open=False):
        gradio.Markdown(notice)

# Start the web server for the interface.
iface.launch()