Problems with sampling when using left padding and sampling is enabled
[Solved with transformers==4.45.2]
My code structure is:
```python
generated_ids = model.generate(
    model_inputs.input_ids,
    attention_mask=model_inputs.attention_mask,
    pad_token_id=pad_token_id,
    max_length=200,
    repetition_penalty=1.0,
    do_sample=True,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
)
```
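For context, `model_inputs` is built with left padding, roughly along the lines below. This is a sketch to make the setup reproducible; the model name, prompts, and variable names are placeholders, not the exact code from my script:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint name; the real script loads a Gemma model
# sharded across 2x H100 with device_map="auto".
model_name = "google/gemma-2-9b"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id

# batch_size > 1 triggers the error below; batch_size = 1 works fine.
prompts = ["Write a haiku about GPUs.", "Explain left padding in one sentence."]
model_inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(model.device)
```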
I get the following error, but only when batch_size > 1:
```
File */lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File */lib/python3.10/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2017         input_ids=input_ids,
   2018         expand_size=generation_config.num_return_sequences,
   2019         is_encoder_decoder=self.config.is_encoder_decoder,
   2020         **model_kwargs,
   2021     )
   2023     # 13. run sample (it degenerates to greedy search when generation_config.do_sample=False)
-> 2024     result = self._sample(
   2025         input_ids,
   2026         logits_processor=prepared_logits_processor,
   2027         logits_warper=prepared_logits_warper,
   2028         stopping_criteria=prepared_stopping_criteria,
   2029         generation_config=generation_config,
   2030         synced_gpus=synced_gpus,
   2031         streamer=streamer,
   2032         **model_kwargs,
   2033     )
   2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2036     # 11. prepare logits warper
   2037     prepared_logits_warper = (
   2038         self._get_logits_warper(generation_config, device=input_ids.device)
   2039         if generation_config.do_sample
   2040         else None
   2041     )

File */lib/python3.10/site-packages/transformers/generation/utils.py:3020, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   3018     probs = nn.functional.softmax(next_token_scores, dim=-1)
   3019     # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
-> 3020     next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
   3021 else:
   3022     next_tokens = torch.argmax(next_token_scores, dim=-1)

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
```
I didn't have this problem with other Gemma models, and the output is fine when batch_size = 1.
My environment: transformers==4.44.0, torch==2.4.0, 2× H100, loading the model with device_map='auto'.
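While this was unresolved, one possible stopgap (a sketch based on my setup above, not the actual fix) is `generate`'s built-in `remove_invalid_values` flag, which attaches an `InfNanRemoveLogitsProcessor` so `torch.multinomial` never sees an invalid probability tensor. Note this masks the symptom rather than fixing the root cause:

```python
# Workaround sketch: remove_invalid_values=True replaces inf/nan logits
# before sampling, so the "probability tensor contains inf/nan" crash
# is avoided. It hides the symptom; the real fix landed in 4.45.2.
generated_ids = model.generate(
    model_inputs.input_ids,
    attention_mask=model_inputs.attention_mask,
    pad_token_id=pad_token_id,
    max_length=200,
    do_sample=True,
    remove_invalid_values=True,
)
```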
@GopiUppari Could you take a look at this issue and provide some feedback? Thank you very much!
This issue is solved with the newest transformers library update (4.45.2).
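If you hit the same RuntimeError on an older version, upgrading with `pip install -U transformers` to 4.45.2 or later should resolve it without any changes to the generation code.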