Trouble generating rewritten query

#1
by PaulBrouillette - opened

In my Google Colab notebook, I have everything set up like this:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
model_name = "castorini/t5-base-canard"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

context = 'Frank Zappa ||| Disbandment ||| What group disbanded ||| Zappa and the Mothers of Invention ||| When did they disband?'

encoded_input = tokenizer(
    context,
    padding='max_length',
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
decoder_input = tokenizer(
    context,
    padding='max_length',
    max_length=512,
    truncation=True,
    return_tensors="pt",
)

encoder_output = model.generate(input_ids=encoded_input["input_ids"], decoder_input_ids=decoder_input["input_ids"])
output = tokenizer.decode(
    encoder_output[0],
    skip_special_tokens=True
)
print(output)

However, my output looks like this:
Input length of decoder_input_ids is 512, but max_length is set to 20. This can lead to unexpected behavior. You should consider increasing config.max_length or max_length.
Frank Zappa ||| Disbandment ||| What group disbanded ||| Zappa and the Mothers of Invention ||| When did they disband? When

I have been trying to change the max_length values so that the fully rewritten query is printed. For example, when I set max_length for encoded_input and decoder_input to 59 and set the max_length passed to model.generate to 89, the output is:
Frank Zappa ||| Disbandment ||| What group disbanded ||| Zappa and the Mothers of Invention ||| When did they disband? When did Frank Zappa and the Mothers of Invention disband?

I have to hardcode those values, which isn't convenient when there are several 'context' phrases. Does anyone know how to get around this problem?

Castorini org

Thanks for your question. I guess you would like to use T5ForConditionalGeneration.from_pretrained in this case. Here is our implementation in chatty-goose for your reference: https://github.com/castorini/chatty-goose/blob/master/chatty_goose/cqr/ntr.py
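
Roughly, the pattern in that file looks like the following (a paraphrased sketch, not the exact source; the generation parameters here are illustrative and may differ from what ntr.py actually uses):

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("castorini/t5-base-canard")
model = T5ForConditionalGeneration.from_pretrained("castorini/t5-base-canard")

# history is a list of prior conversation turns; the model expects them
# joined with " ||| ", as in the context string above
history = ['Frank Zappa', 'Disbandment', 'What group disbanded', 'Zappa and the Mothers of Invention', 'When did they disband?']
src_text = " ||| ".join(history)

input_ids = tokenizer(src_text, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=64, num_beams=10, early_stopping=True)  # illustrative values
rewritten_query = tokenizer.decode(output_ids[0], skip_special_tokens=True)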

Thank you for your reply. I looked at the GitHub link and I did:
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

However, I'm still getting the same output:
Input length of decoder_input_ids is 512, but max_length is set to 20. This can lead to unexpected behavior. You should consider increasing config.max_length or max_length.
Frank Zappa ||| Disbandment ||| What group disbanded ||| Zappa and the Mothers of Invention ||| When did they disband? When

Do I need to change my max_length values, or do I need to change more things to make my code closer to the linked implementation?

Castorini org

I think you don't need decoder_input_ids in model.generate. Just use something like this: encoder_output = model.generate(input_ids=encoded_input["input_ids"], max_length=512, num_beams=1).
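
Putting it together, a minimal sketch of the corrected script (reusing the model and context from above; max_length=512 and num_beams=1 are just the values suggested here, and the loop shows how to handle several context phrases without hardcoding a length per phrase):

from transformers import T5ForConditionalGeneration, T5Tokenizer

model_name = "castorini/t5-base-canard"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

contexts = [
    'Frank Zappa ||| Disbandment ||| What group disbanded ||| Zappa and the Mothers of Invention ||| When did they disband?',
]

for context in contexts:
    # Padding to a fixed length is unnecessary for a single example,
    # and generate builds the decoder input itself, so no decoder_input_ids.
    encoded_input = tokenizer(context, truncation=True, max_length=512, return_tensors="pt")
    output_ids = model.generate(
        input_ids=encoded_input["input_ids"],
        max_length=512,  # upper bound on the length of the generated rewrite
        num_beams=1,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))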

Wow, that worked! Thank you so much! :)

Castorini org

You are welcome :)

jacklin changed discussion status to closed