Obtain alignment information

#8
by kp1234 - opened

I am trying to obtain the alignment between source and target words so that I can place XML tags in the translations.
To do this I wanted to get the attention scores, but when I call model.generate(**encoded_hi, forced_bos_token_id=tokenizer.lang_code_to_id["de_DE"], output_attentions=True, output_hidden_states=True)
and then print(generated_tokens), only the generated tokens are returned, without any attentions.

Does anyone have an idea how this could be achieved?
Best regards, Kai

AI at Meta org

When running generate, target tokens are added one at a time, and the attention from previous tokens is not preserved.

Thus, to obtain the full cross-attention map, you'll have to run inference once more with the full generated translation, like this:

import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

article_hi = "संयुक्त राष्ट्र के प्रमुख का कहना है कि सीरिया में कोई सैन्य समाधान नहीं है"
tokenizer.src_lang = "hi_IN"
encoded_hi = tokenizer(article_hi, return_tensors="pt")
generated_tokens = model.generate(
    **encoded_hi,
    forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"]
)
translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
print(translated)  # Le chef de l 'ONU affirme qu 'il n 'y a pas de solution militaire en Syria.

# Re-tokenize the translation as French so it can be fed to the decoder
tokenizer.src_lang = "fr_XX"
encoded_fr = tokenizer(translated, return_tensors="pt")
with torch.inference_mode():
    result = model(**encoded_hi, decoder_input_ids=encoded_fr.input_ids, output_attentions=True)

# result.cross_attentions is a tuple with one tensor per decoder layer
all_cross_attentions = torch.cat(result.cross_attentions)
print(all_cross_attentions.shape)  # torch.Size([12, 16, 23, 18])
# 12 layers x 16 heads x 23 target (French) tokens x 18 source (Hindi) tokens

Now you can average all_cross_attentions over its first two dimensions (layers and heads), or pick the best layer and attention head, to compute the alignment.
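As a rough sketch of both options, the snippet below averages over layers and heads (or selects one head) and then takes, for each target token, the source token with the highest weight. The function name align_from_cross_attention and the small hand-built tensor are illustrative assumptions, not part of transformers; the toy tensor only stands in for all_cross_attentions so the example runs without downloading the model.

```python
import torch


def align_from_cross_attention(cross_attentions, layer=None, head=None):
    """Hypothetical helper: most-attended source index for each target position.

    cross_attentions has shape (layers, heads, target_len, source_len).
    If both layer and head are given, only that head is used; otherwise the
    weights are averaged over all layers and heads.
    """
    if layer is not None and head is not None:
        attention_map = cross_attentions[layer, head]
    else:
        attention_map = cross_attentions.mean(dim=(0, 1))
    # attention_map: (target_len, source_len); argmax over the source axis
    return attention_map.argmax(dim=-1)


# Toy stand-in for all_cross_attentions:
# 2 layers, 2 heads, 3 target tokens, 4 source tokens, each row sums to 1.
toy = torch.full((2, 2, 3, 4), 0.1)
toy[:, :, 0, 2] = 0.7  # target token 0 attends mostly to source token 2
toy[:, :, 1, 0] = 0.7  # target token 1 attends mostly to source token 0
toy[:, :, 2, 3] = 0.7  # target token 2 attends mostly to source token 3

alignment = align_from_cross_attention(toy)
print(alignment.tolist())  # [2, 0, 3]

# Same computation restricted to a single (arbitrarily chosen) layer and head
single_head = align_from_cross_attention(toy, layer=1, head=0)
print(single_head.tolist())  # [2, 0, 3]
```

With the real tensor, each sequence also contains a language code token and a closing </s>, so you will probably want to ignore those positions before reading off the alignment. Taking a hard argmax is only one possible choice; keeping the full averaged map and thresholding it is another.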

However, this alignment is not guaranteed to be meaningful, because (1) token attributions are sometimes indirect, and (2) neural networks sometimes compensate for high attention weights with low attention values, or vice versa. So if you are not satisfied with the attention-based alignment, consider taking a look at its extended version, ALTI+ by Ferrando et al. (2022).
