Update modeling_cpmbee.py
modeling_cpmbee.py  +2 −1
@@ -1472,6 +1472,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[Union[int, List[int]]] = None,
         bos_token_id: Optional[Union[int, List[int]]] = None,
+        vocab_size: Optional[int] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         output_scores: Optional[bool] = None,
@@ -1487,6 +1488,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
         pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
         bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
+        vocab_size = vocab_size if vocab_size is not None else self.generation_config.vocab_size
         max_length = max_length if max_length is not None else self.generation_config.max_length
         output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
         output_attentions = (
@@ -1589,7 +1591,6 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
                 break
            # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
            # cannot be generated both before and after the `nn.functional.log_softmax` operation.
-            vocab_size = next_token_logits.shape[-1]
            next_token_logits = self.adjust_logits_during_generation(
                next_token_logits, batch_size, num_beams, vocab_size, ext_table_ids_cpu, **model_kwargs
            )