Bug in modeling_gemma.py in transformers 4.38.0

#45

```
Traceback (most recent call last):
  File "/mnt/bn/motor-nlp-team/mlx/users/zhangkaiqi.zlkqz/repo/5355/Personal_repo/test.py", line 15, in <module>
    outputs = model.generate(**input_ids, max_new_tokens=1024)
  File "/usr/local/lib/python3.9/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/generation/utils.py", line 1544, in generate
    return self.greedy_search(
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/generation/utils.py", line 2404, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 1068, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 906, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 626, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tiger/.local/lib/python3.9/site-packages/transformers/models/gemma/modeling_gemma.py", line 280, in forward
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
RuntimeError: shape '[1, 13, 3072]' is invalid for input of size 53248
```
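For context, here is a minimal sketch of the kind of script that hits this; the checkpoint, prompt, and attention setting below are assumptions, since the original test.py is not shown:

```python
# Hypothetical repro sketch: the model id, prompt, and attn_implementation
# are assumptions, not taken from the original test.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b",
    torch_dtype=torch.float16,
    # The traceback goes through the eager GemmaAttention.forward path.
    attn_implementation="eager",
)

input_ids = tokenizer("Hello Gemma", return_tensors="pt")
# On transformers 4.38.0 this raises:
# RuntimeError: shape '[1, 13, 3072]' is invalid for input of size 53248
outputs = model.generate(**input_ids, max_new_tokens=1024)
```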

In modeling_gemma.py in transformers 4.38.0, line 280, the source code should be changed from:

```python
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
```

to:

```python
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
```

The next line is `attn_output = self.o_proj(attn_output)`, and `o_proj` is a linear layer with in_features = self.num_heads * self.head_dim and out_features = self.hidden_size, so it expects self.num_heads * self.head_dim features per token. In Gemma-7B the two values differ (16 heads × 256 head_dim = 4096, while hidden_size = 3072), so the attention output holds 1 × 13 × 4096 = 53248 elements and the reshape to [1, 13, 3072] fails before the projection is even applied.
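To see the arithmetic concretely, here is a standalone sketch; the geometry values are Gemma-7B's published config (hidden_size=3072, 16 heads, head_dim=256), which also match the numbers in the error message:

```python
import torch

# Gemma-7B attention geometry: num_heads * head_dim (16 * 256 = 4096)
# is NOT equal to hidden_size (3072).
bsz, q_len = 1, 13
hidden_size, num_heads, head_dim = 3072, 16, 256

# The concatenated per-head outputs hold bsz * q_len * num_heads * head_dim
# = 1 * 13 * 4096 = 53248 elements.
attn_output = torch.randn(bsz, q_len, num_heads * head_dim)

# The 4.38.0 reshape raises:
# RuntimeError: shape '[1, 13, 3072]' is invalid for input of size 53248
try:
    attn_output.reshape(bsz, q_len, hidden_size)
except RuntimeError as e:
    print(e)

# The proposed fix keeps the element count intact and matches o_proj,
# which is nn.Linear(num_heads * head_dim, hidden_size).
fixed = attn_output.reshape(bsz, q_len, num_heads * head_dim)
o_proj = torch.nn.Linear(num_heads * head_dim, hidden_size, bias=False)
print(o_proj(fixed).shape)  # torch.Size([1, 13, 3072])
```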


zlk changed pull request status to open
zlk changed pull request status to closed
zlk changed pull request status to open
zlk changed pull request status to closed
