gemma-2b-it model works but gemma-7b-it model generates errors

#51
opened by saurabhkumar

I tried the same code with google/gemma-2b-it and google/gemma-7b-it. The 2b-it model generates text, but the 7b-it model raises an error. I am running on a node with 8xA100 GPUs (not needed for this, but that is what I had available when trying it out). Switching the device to CPU makes no difference.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

cache_dir = '/path/to/hf_model_cache'

# Load the instruction-tuned model in bfloat16 on the GPU, plus its tokenizer
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it", cache_dir=cache_dir, device_map="cuda", torch_dtype=torch.bfloat16)
gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", cache_dir=cache_dir)

user_request = "Write me the simplest code snippet in python you can think of."

# Format the request with the model's chat template and open a model turn
chat = [
    { "role": "user", "content": user_request },
]
prompt = gemma_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# Tokenize the prompt and generate up to 200 new tokens
inputs = gemma_tokenizer.encode(prompt, add_special_tokens=True, return_tensors="pt")
outputs = gemma.generate(input_ids=inputs.to(gemma.device), max_new_tokens=200)

print(gemma_tokenizer.decode(outputs[0]))

Error:

  File "/myfile.py", line <line with generate>, in <module>
    outputs = gemma.generate(input_ids=inputs.to(gemma.device), max_new_tokens=200)
.....
.....

  File "/my_venv/python3.8/site-packages/transformers/models/gemma/modeling_gemma.py", line 280, in forward
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
RuntimeError: shape '[1, 23, 3072]' is invalid for input of size 94208
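The sizes in the error can be checked by hand: 94208 = 1 × 23 × 4096, so the attention output holds 4096 values per token, while the reshape asks for `hidden_size` = 3072. Assuming the public config values for gemma-7b (16 attention heads × head_dim 256 = 4096, hidden_size 3072) and gemma-2b (8 heads × 256 = 2048, equal to its hidden_size), this would also explain why only the 7b model fails. A small stdlib-only sketch of that arithmetic (the config numbers are my assumption, taken from the published config files):

```python
from math import prod

# Assumed values from the public gemma-2b / gemma-7b config.json files.
CONFIGS = {
    "gemma-2b": {"hidden_size": 2048, "num_attention_heads": 8, "head_dim": 256},
    "gemma-7b": {"hidden_size": 3072, "num_attention_heads": 16, "head_dim": 256},
}


def attn_reshape_check(cfg, bsz=1, q_len=23):
    """Return (elements in attention output, elements in the reshape target)."""
    # Attention output before the failing line: (bsz, q_len, num_heads, head_dim)
    attn_numel = prod((bsz, q_len, cfg["num_attention_heads"], cfg["head_dim"]))
    # Target shape in the failing line: (bsz, q_len, hidden_size)
    target_numel = prod((bsz, q_len, cfg["hidden_size"]))
    return attn_numel, target_numel


for name, cfg in CONFIGS.items():
    attn_numel, target_numel = attn_reshape_check(cfg)
    status = "ok" if attn_numel == target_numel else "RuntimeError"
    print(f"{name}: {attn_numel} elements -> target {target_numel}: {status}")
```

If that reading is right, a reshape that infers the last dimension (for example `reshape(bsz, q_len, -1)`) would not depend on `hidden_size`. Whether that matches the actual change shipped in transformers 4.38.1 is something to confirm in the release notes.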

Swapping only the model and tokenizer from 7b-it to 2b-it works fine:

gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", cache_dir=cache_dir, device_map="cuda",torch_dtype=torch.bfloat16)
gemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it", cache_dir=cache_dir)

Output (rendered as formatted text here because the generated output itself contains "```"):

<bos><bos><start_of_turn>user
Write me the simplest code snippet in python you can think of.<end_of_turn>
<start_of_turn>model
```python
print("Hello, world!")

This code will print the string "Hello, world!" to the console.

Explanation:

  • print() is a built-in Python function that prints a message to the console.
  • "Hello, world!" is the string we want to print.

Output:

Hello, world!
```<eos>
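The decoded text above starts with `<bos><bos>`. A plausible explanation: the Gemma chat template already emits `<bos>` at the start of the prompt, and `encode(..., add_special_tokens=True)` then prepends a second one. The sketch below imitates that with plain strings; `render_gemma_prompt` and `encode_like` are hypothetical stand-ins reconstructed from the printed prompt, not the tokenizer's real implementation.

```python
# Hypothetical helper: a plain-string stand-in for Gemma's chat template,
# reconstructed from the prompt printed above. Real code should call
# tokenizer.apply_chat_template instead.
BOS = "<bos>"


def render_gemma_prompt(messages, add_generation_prompt=True):
    """Render chat turns in the format visible in the printed prompt."""
    parts = [BOS]  # assumption: the template itself emits <bos> first
    for message in messages:
        parts.append(
            f"<start_of_turn>{message['role']}\n{message['content']}<end_of_turn>\n"
        )
    if add_generation_prompt:
        parts.append("<start_of_turn>model\n")  # open the model's turn
    return "".join(parts)


def encode_like(text, add_special_tokens):
    """Mimic encode(): with add_special_tokens=True, one more <bos> is prepended."""
    return BOS + text if add_special_tokens else text


chat = [{"role": "user", "content": "Write me the simplest code snippet in python you can think of."}]
prompt = render_gemma_prompt(chat)

print(encode_like(prompt, add_special_tokens=True).count(BOS))   # 2
print(encode_like(prompt, add_special_tokens=False).count(BOS))  # 1
```

In the real code, the corresponding change would be `encode(prompt, add_special_tokens=False)`, or letting `apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt")` tokenize the prompt directly, which should avoid adding a second `<bos>`.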

I have now updated to transformers 4.38.1, and that solved the issue. I hope this was the right solution.
The model generated the following text (again rendered as formatted text because of the "```" in the generated output):

<bos><bos><start_of_turn>user
Write me the simplest code snippet in python you can think of.<end_of_turn>
<start_of_turn>model
```python
print "Hello, world!"

# This line prints the string "Hello, world!" to the console
```<eos>
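Since upgrading fixed the error, a script could refuse to run on an older release before loading gemma-7b-it. A minimal stdlib-only sketch: the 4.38.1 threshold comes from this thread, while `parse_release` and `transformers_is_new_enough` are hypothetical helper names. The `packaging` library's `Version` class would be a more robust choice for comparing versions.

```python
import re

# Version reported in this thread to fix the gemma-7b-it reshape error.
MIN_VERSION = (4, 38, 1)


def parse_release(version):
    """Turn '4.38.1' or '4.39.0.dev0' into a comparable (major, minor, patch) tuple."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = [int(p) for p in match.group(1).split(".")]
    # Pad so that '4.38' compares as (4, 38, 0); pre-release suffixes are ignored.
    parts += [0] * (3 - len(parts))
    return tuple(parts[:3])


def transformers_is_new_enough(version, minimum=MIN_VERSION):
    return parse_release(version) >= minimum


if __name__ == "__main__":
    try:
        import transformers
    except ImportError:
        print("transformers is not installed")
    else:
        if not transformers_is_new_enough(transformers.__version__):
            print(f"transformers {transformers.__version__} is older than 4.38.1; "
                  "try: pip install -U 'transformers>=4.38.1'")
```

Because pre-release suffixes are dropped, a build such as `4.38.1.dev0` would pass this check even though it predates the 4.38.1 release; `packaging.version.Version` handles that case correctly.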

Nice! Should this be included in the documentation somewhere?
