tomer-nv committed on
Commit
d71a214
·
verified ·
1 Parent(s): 20cc7f1

Patching an HF bug that computes a wrong cache length when only inputs_embeds are passed to the model

Browse files
Files changed (1) hide show
  1. modeling_decilm.py +19 -0
modeling_decilm.py CHANGED
@@ -1311,6 +1311,25 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
1311
  )
1312
  return model_inputs
1313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1314
 
1315
  @add_start_docstrings(
1316
  """
 
1311
  )
1312
  return model_inputs
1313
 
1314
+ def _maybe_initialize_input_ids_for_generation(
1315
+ self,
1316
+ inputs: Optional[torch.Tensor] = None,
1317
+ bos_token_id: Optional[torch.Tensor] = None,
1318
+ model_kwargs: Optional[dict[str, torch.Tensor]] = None,
1319
+ ) -> torch.LongTensor:
1320
+ """
1321
+ Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
1322
+ """
1323
+ input_ids = super()._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
1324
+ if (
1325
+ "inputs_embeds" in model_kwargs
1326
+ and input_ids is not None
1327
+ and input_ids.shape[1] == 0
1328
+ ):
1329
+ batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
1330
+ input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
1331
+ return input_ids
1332
+
1333
 
1334
  @add_start_docstrings(
1335
  """