Example code throws an error only with *-896
The code provided in the model card does not cause an error when using the google/paligemma2-10b-pt-448 model, but does cause an error when using the google/paligemma2-10b-pt-896 model.
Output summary:
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [16386,0,0], thread: [64,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [16386,0,0], thread: [65,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [16386,0,0], thread: [66,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
...(more than 100 lines)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 11
8 input_len = model_inputs["input_ids"].shape[-1]
10 with torch.inference_mode():
---> 11 generation = model.generate(**model_inputs, max_new_tokens=128, do_sample=False)
12 generation = generation[0][input_len:]
13 decoded = processor.decode(generation, skip_special_tokens=True)
File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:2255, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
2247 input_ids, model_kwargs = self._expand_inputs_for_generation(
2248 input_ids=input_ids,
2249 expand_size=generation_config.num_return_sequences,
2250 is_encoder_decoder=self.config.is_encoder_decoder,
2251 **model_kwargs,
2252 )
2254 # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2255 result = self._sample(
2256 input_ids,
2257 logits_processor=prepared_logits_processor,
2258 stopping_criteria=prepared_stopping_criteria,
2259 generation_config=generation_config,
2260 synced_gpus=synced_gpus,
2261 streamer=streamer,
2262 **model_kwargs,
2263 )
2265 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2266 # 11. prepare beam search scorer
2267 beam_scorer = BeamSearchScorer(
2268 batch_size=batch_size,
2269 num_beams=generation_config.num_beams,
(...)
2274 max_length=generation_config.max_length,
2275 )
File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:3254, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
3251 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
3253 if is_prefill:
-> 3254 outputs = self(**model_inputs, return_dict=True)
3255 is_prefill = False
3256 else:
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /opt/conda/lib/python3.10/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
525 labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
527 causal_mask = self._update_causal_mask(
528 attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
529 )
--> 530 outputs = self.language_model(
531 attention_mask=causal_mask,
532 position_ids=position_ids,
533 past_key_values=past_key_values,
534 inputs_embeds=inputs_embeds,
535 use_cache=use_cache,
536 output_attentions=output_attentions,
537 output_hidden_states=output_hidden_states,
538 return_dict=return_dict,
539 cache_position=cache_position,
540 num_logits_to_keep=num_logits_to_keep,
541 )
543 logits = outputs.logits
544 loss = None
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
843 input_ids=input_ids,
844 attention_mask=attention_mask,
845 position_ids=position_ids,
846 past_key_values=past_key_values,
847 inputs_embeds=inputs_embeds,
848 use_cache=use_cache,
849 output_attentions=output_attentions,
850 output_hidden_states=output_hidden_states,
851 return_dict=return_dict,
852 cache_position=cache_position,
853 )
855 hidden_states = outputs[0]
856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
617 layer_outputs = self._gradient_checkpointing_func(
618 decoder_layer.__call__,
619 hidden_states,
(...)
626 cache_position,
627 )
628 else:
--> 629 layer_outputs = decoder_layer(
630 hidden_states,
631 position_embeddings=position_embeddings,
632 attention_mask=causal_mask,
633 position_ids=position_ids,
634 past_key_value=past_key_values,
635 output_attentions=output_attentions,
636 use_cache=use_cache,
637 cache_position=cache_position,
638 **flash_attn_kwargs,
639 )
641 hidden_states = layer_outputs[0]
643 if output_attentions:
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
296 hidden_states = self.input_layernorm(hidden_states)
298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
300 hidden_states=hidden_states,
301 position_embeddings=position_embeddings,
302 attention_mask=attention_mask,
303 position_ids=position_ids,
304 past_key_value=past_key_value,
305 output_attentions=output_attentions,
306 use_cache=use_cache,
307 cache_position=cache_position,
308 )
309 hidden_states = self.post_attention_layernorm(hidden_states)
310 hidden_states = residual + hidden_states
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
221 if past_key_value is not None:
222 # sin and cos are specific to RoPE models; cache_position needed for the static cache
223 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 224 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
226 attention_interface: Callable = eager_attention_forward
227 if self.config._attn_implementation != "eager":
File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
1714 else:
1715 update_fn = self._static_update
-> 1717 return update_fn(
1718 cache_position,
1719 layer_idx,
1720 key_states,
1721 value_states,
1722 k_out,
1723 v_out,
1724 k_out.shape[2],
1725 )
File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694 k_out[:, :, cache_position] = key_states
1695 v_out[:, :, cache_position] = value_states
1697 self.key_cache[layer_idx] = k_out
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
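The traceback ends at `k_out[:, :, cache_position] = key_states`, and the assertion text says an index fell outside `-sizes[i] <= index < sizes[i]`. A minimal pure-Python sketch of that bounds check follows. The helper names and the small numbers are hypothetical, chosen only for illustration; this is not the library's code and does not claim to identify the actual root cause.

```python
# Sketch of the bounds condition quoted in the CUDA assertion.
# Helper names and numbers below are hypothetical, for illustration only.

def index_in_bounds(index: int, size: int) -> bool:
    """Same condition as the assertion: -size <= index < size."""
    return -size <= index < size


def out_of_bounds_positions(cache_positions, max_cache_len):
    """Return the positions that would fail the check for a cache
    whose sequence dimension has length max_cache_len."""
    return [p for p in cache_positions if not index_in_bounds(p, max_cache_len)]


# Hypothetical example: a cache with 8 slots and a prefill of 10 positions.
max_cache_len = 8
prefill_positions = list(range(10))
bad = out_of_bounds_positions(prefill_positions, max_cache_len)
print(bad)  # [8, 9]
```

Running the same kind of check on the real `cache_position` values and `k_out.shape[2]` before the assignment (for example on CPU, where the error is raised as an ordinary Python exception) may help narrow down which positions are out of range.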
Hi, the model card code for google/paligemma2-10b-pt-896 ran fine when I tried to reproduce the error on a 4x NVIDIA T4 GPU setup (transformers 4.47.1, torch 2.5.1+cu121). Please try again and, if the issue still persists, share some more details such as the torch version and the environment setup used to run the above code. Please have a look at the model output and execution in the screenshot below.
Thanks for your quick reply. The code is exactly the same and the problem still persists.
Environment
- torch: 2.5.1+cu124
- transformers: 4.48.0
- GPU: NVIDIA A100-PCIE-40GB
There was no problem with the 448-resolution model on the same GPU, in the same container, with the same code.
I can reproduce this with transformers 4.48.0 and will take a closer look. In the meantime, would it be possible for you to use 4.47.1, which works fine?
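To see which case an environment falls into, a small check like the following may help. It is only a sketch: the two version strings come from the reports in this thread, the helper names are hypothetical, and other versions are simply treated as not covered here.

```python
# Sketch: compare the installed transformers version against the two
# versions reported in this thread. Helper names are hypothetical.
from importlib import metadata

KNOWN_GOOD = "4.47.1"  # reported working in this thread
KNOWN_BAD = "4.48.0"   # reported failing in this thread


def installed_version(package: str = "transformers"):
    """Return the installed version string, or None if not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def describe(version):
    """Map a version string to what this thread says about it."""
    if version is None:
        return "not installed"
    if version == KNOWN_GOOD:
        return "reported working"
    if version == KNOWN_BAD:
        return "reported failing"
    return "not covered by this thread"


print(describe(installed_version()))
```

If the result is "reported failing", installing the known-good release (for example `pip install transformers==4.47.1`) is the workaround suggested above.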
For reference, the fix is going to be part of https://github.com/huggingface/transformers/pull/35681