Patching an HF bug that creates a wrong cache length when only inputs_embeds are passed to the model
modeling_decilm.py  CHANGED  (+19 -0)
@@ -1311,6 +1311,25 @@ class DeciLMForCausalLM(DeciLMPreTrainedModel, GenerationMixin):
         )
         return model_inputs
 
+    def _maybe_initialize_input_ids_for_generation(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        bos_token_id: Optional[torch.Tensor] = None,
+        model_kwargs: Optional[dict[str, torch.Tensor]] = None,
+    ) -> torch.LongTensor:
+        """
+        Patching hf bug that creates wrong cache length if only inputs_embeds are passed to the model
+        """
+        input_ids = super()._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
+        if (
+            "inputs_embeds" in model_kwargs
+            and input_ids is not None
+            and input_ids.shape[1] == 0
+        ):
+            batch_size, input_sequence_length = model_kwargs["inputs_embeds"].shape[:2]
+            input_ids = torch.zeros((batch_size, input_sequence_length), dtype=torch.long, device=self.device)
+        return input_ids
+
 
 @add_start_docstrings(
     """
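
For context, a minimal sketch of the call path this override affects (not part of the commit; the checkpoint path is a placeholder and loading with trust_remote_code=True is assumed): when generate() receives only inputs_embeds, GenerationMixin calls _maybe_initialize_input_ids_for_generation to build a placeholder input_ids, and the patch sizes that placeholder to the embedding sequence length so the cache length derived from it is correct.

# Minimal sketch, assuming a hypothetical local DeciLM checkpoint path.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("path/to/decilm-checkpoint", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("path/to/decilm-checkpoint", trust_remote_code=True)

prompt_ids = tokenizer("Hello, world", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(prompt_ids)

# Without the override, input_ids is initialized with sequence length 0 when
# only inputs_embeds are given, so the cache length computed from it is wrong.
# The patched method instead returns a zero-filled placeholder of shape
# (batch_size, inputs_embeds.shape[1]).
output_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))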