nonsense response when bsz>1

#16
by OliverNova - opened

It seems that both 9b-it and 27b-it generate nonsense responses when running inference with a batch size (bsz) greater than 1.
Would you kindly look into this?
Thank you!
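One way to make "nonsense when bsz > 1" concrete is to check whether each prompt's batched completion matches the completion it gets when generated on its own, using greedy decoding so sampling noise is ruled out. The sketch below assumes a hypothetical `GenerateFn` interface (a list of prompts in, one completion per prompt out). `find_batch_mismatches` and `make_hf_generate` are illustrative helpers, not part of transformers, and `make_hf_generate` assumes a model and tokenizer that are already loaded, with the tokenizer set to left padding.

```python
from typing import Callable, List, Sequence

# Hypothetical interface: list of prompts in, one completion string per prompt out.
GenerateFn = Callable[[List[str]], List[str]]


def find_batch_mismatches(generate: GenerateFn, prompts: Sequence[str]) -> List[int]:
    """Return the indices of prompts whose batched completion differs from
    the completion produced when that prompt is generated alone (bsz=1)."""
    batched = generate(list(prompts))
    mismatches = []
    for i, prompt in enumerate(prompts):
        single = generate([prompt])[0]
        if batched[i] != single:
            mismatches.append(i)
    return mismatches


def make_hf_generate(model, tokenizer, max_new_tokens: int = 32) -> GenerateFn:
    """Wrap an already-loaded transformers model/tokenizer pair as a GenerateFn.

    Assumes tokenizer.padding_side == "left", so every row's prompt ends at the
    same position and the prompt can be sliced off with one shared offset.
    """
    def generate(prompts: List[str]) -> List[str]:
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        # do_sample=False -> greedy decoding, so runs are comparable.
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Keep only the newly generated tokens.
        new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
        return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

    return generate


# Usage (not executed here; needs a loaded model and tokenizer):
# mismatches = find_batch_mismatches(make_hf_generate(model, tokenizer), input_text)
```

Exact string equality is a strict test: in bfloat16, batched and single-prompt runs can drift apart after many tokens because of numerical differences, so a mismatch that is still fluent text is weaker evidence of a bug than a batched output that is gibberish.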

Google org

Hey @OliverNova, thanks for your report! Can you please make sure you're using the latest transformers version (v4.42.3)?

If it still happens, do you mind sharing a reproducible code snippet for us to take a look at?

Hi @lysandre, here is my code snippet:

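from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Pad on the left so each prompt ends where generation begins
# (right padding is not recommended for decoder-only generation).
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b", padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

input_text = ["Write me a poem about Machine Learning.", "I want to eat a"]
# Tokenize and pad the batch, then move input_ids and attention_mask
# to the device reported by model.device.
inputs = tokenizer(input_text, return_tensors="pt", padding=True).to(model.device)

# max_new_tokens=64 is an illustrative cap on the generated length.
outputs = model.generate(**inputs, max_new_tokens=64)
# decode() expects a single sequence; batch_decode() returns one string per prompt.
for text in tokenizer.batch_decode(outputs, skip_special_tokens=True):
    print(text)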

With a batch size of 2, the second input doesn't seem to get processed at all. This wasn't the case with Gemma 1.
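One detail worth ruling out in a report like this (it is not confirmed as the cause here) is which side the shorter prompt gets padded on. With `padding=True`, "I want to eat a" is shorter than the first prompt, so it receives pad tokens. With right padding those pads sit between the prompt and the first generated token, which decoder-only models generally handle poorly; with left padding every prompt ends exactly where generation continues. The pure-Python sketch below mimics what a tokenizer's padding produces. `pad_batch` is a hypothetical helper, not a transformers API, and the token ids and `pad_id=0` are made-up toy values.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_id: int, side: str = "left"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-id lists to a common length and build the attention mask.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens
    and 0 for padding, mirroring what a tokenizer returns.
    """
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        pad = [pad_id] * (max_len - len(seq))
        ones, zeros = [1] * len(seq), [0] * len(pad)
        if side == "left":
            input_ids.append(pad + seq)
            attention_mask.append(zeros + ones)
        else:
            input_ids.append(seq + pad)
            attention_mask.append(ones + zeros)
    return input_ids, attention_mask


long_prompt = [2, 11, 12, 13, 14]  # toy ids standing in for the first prompt
short_prompt = [2, 21, 22]         # toy ids standing in for "I want to eat a"
for side in ("right", "left"):
    ids, mask = pad_batch([long_prompt, short_prompt], pad_id=0, side=side)
    # Generation continues from the last column, for every row.
    print(side, ids[1], "last token is real:", mask[1][-1] == 1)
```

With `side="right"` the short row becomes `[2, 21, 22, 0, 0]`, so its last position is padding; with `side="left"` it becomes `[0, 0, 2, 21, 22]`. In transformers the same choice is made with the tokenizer's `padding_side` attribute (for example `tokenizer.padding_side = "left"`).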
