global attention mask for model inference

#3
by pszemraj - opened

hey, thanks for the great upload! my question is related to model inference.

I was able to train using the Trainer API on Colab, but when trying to run inference in a notebook on non-dataset text, I get the following error even with LongT5ForConditionalGeneration:

TypeError: forward() got an unexpected keyword argument 'global_attention_mask'

(perhaps unrelated, but that same notebook works fine with this checkpoint, for example). I then turned to the example in your model card to see if I could replicate it in case I was doing something wrong, but I can't find where global_attention_mask is defined in the example:

import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
input_ids = tokenizer(LONG_ARTICLE, return_tensors="pt").input_ids.to("cuda")
model = LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps", return_dict_in_generate=True).to("cuda")

sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
summary = tokenizer.batch_decode(sequences)

any help would be appreciated :)

Hi, thank you very much for pointing out this issue. It's a mistake on my side: the LongT5 model accepts attention_mask, not global_attention_mask. Sorry for the confusion, I'll fix that :]
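
For reference, a corrected sketch of the snippet (LONG_ARTICLE is still a placeholder for the text you want to summarize, and skip_special_tokens=True in the decode step is an optional addition on my part):

import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
model = LongT5ForConditionalGeneration.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps", return_dict_in_generate=True).to("cuda")

# keep the attention_mask that the tokenizer produces alongside input_ids
inputs = tokenizer(LONG_ARTICLE, return_tensors="pt").to("cuda")

# pass the regular attention_mask; LongT5's transient-global attention builds
# its global tokens internally, so there is no global_attention_mask argument
sequences = model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask).sequences

summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)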

Thanks! Works as expected now 👍

I’ll post my checkpoints once I’m happy with the performance, but DAMN this thing takes forever to train (even compared to LED at 16384)

pszemraj changed discussion status to closed
