Upload modeling_cpmbee.py
modeling_cpmbee.py  CHANGED  +4 -2
@@ -1634,8 +1634,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
             )
 
             # reshape for beam search
-            vocab_size = next_token_scores.shape[-1]
-            next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
+            next_token_scores = next_token_scores.view(batch_size, -1)
 
             # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
             next_token_scores, next_tokens = torch.topk(
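With this hunk, PyTorch infers the flattened beam dimension instead of the code computing it from the scores tensor's last dimension. A minimal sketch of why the two forms agree when the sizes match, and why `-1` is more robust when they might not (shapes here are hypothetical, chosen for illustration):

    import torch

    # Hypothetical sizes: 2 sequences, 3 beams, vocab of 5.
    batch_size, num_beams, vocab_size = 2, 3, 5
    next_token_scores = torch.randn(batch_size * num_beams, vocab_size)

    # Old form: breaks if vocab_size disagrees with the tensor's last dimension.
    a = next_token_scores.view(batch_size, num_beams * vocab_size)

    # New form: the second dimension is inferred from the element count,
    # so the reshape stays valid regardless of where vocab_size comes from.
    b = next_token_scores.view(batch_size, -1)

    assert torch.equal(a, b)  # identical result when the sizes do agree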
@@ -1872,6 +1871,7 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
                 logits_processor=logits_processor,
                 pad_token_id=generation_config.pad_token_id,
                 eos_token_id=generation_config.eos_token_id,
+                vocab_size=kwargs.get("vocab_size", None),
                 output_scores=generation_config.output_scores,
                 return_dict_in_generate=generation_config.return_dict_in_generate,
                 synced_gpus=synced_gpus,
@@ -1909,6 +1909,8 @@ class CpmBeeForCausalLM(CpmBeePreTrainedModel):
         input_encoded = tokenizer(data_list, return_tensors="pt", padding=True, device=self.device)
         input_encoded.update(kwargs)
         input_encoded["generation_config"] = generation_config
+        input_encoded["vocab_size"] = tokenizer.vocab_size
+        print(tokenizer.vocab_size)
 
         decode_res = self._generate(**input_encoded)
 
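Taken together, the three hunks thread tokenizer.vocab_size from the top-level generate call through kwargs into the beam-search step. The downstream use is not shown in this diff, but in standard beam-search bookkeeping the vocabulary size is needed to split the flattened top-k indices back into beam indices and token ids; a sketch under that assumption, with hypothetical sizes:

    import torch

    # In the model, vocab_size would arrive via kwargs.get("vocab_size", None)
    # as added in this commit; fixed values here for illustration.
    batch_size, num_beams, vocab_size = 2, 3, 5
    next_token_scores = torch.randn(batch_size, num_beams * vocab_size)

    # Take 2 * num_beams candidates per sequence, matching the diff's comment
    # about keeping some spare tokens for beam search.
    next_token_scores, next_tokens = torch.topk(
        next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
    )

    # Recover which beam and which token each flattened index refers to;
    # this is the step that requires the true vocabulary size.
    next_indices = next_tokens // vocab_size  # beam index
    next_tokens = next_tokens % vocab_size    # token id within the vocab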