x54-729 committed on
Commit dd2fa16
Parent: 4217652

Update modeling_internlm.py

Files changed (1)
  1. modeling_internlm.py +4 -0
modeling_internlm.py CHANGED

@@ -784,6 +784,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
              do_sample: bool = True,
              temperature: float = 0.8,
              top_p: float = 0.8,
+             eos_token_id = (2, 103028),
              **kwargs):
         inputs = self.build_inputs(tokenizer, query, history)
         inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
@@ -793,6 +794,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                                 do_sample=do_sample,
                                 temperature=temperature,
                                 top_p=top_p,
+                                eos_token_id=list(eos_token_id),
                                 **kwargs)
         outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
@@ -809,6 +811,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                     do_sample: bool = True,
                     temperature: float = 0.8,
                     top_p: float = 0.8,
+                    eos_token_id = (2, 103028),
                     **kwargs):
         class ChatStreamer(BaseStreamer):
             def __init__(self, tokenizer) -> None:
@@ -836,6 +839,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                 do_sample=do_sample,
                 temperature=temperature,
                 top_p=top_p,
+                eos_token_id=eos_token_id,
                 **kwargs
             )
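For reference, a minimal usage sketch of the patched chat() method. The model repo ID and prompt below are placeholders, and (2, 103028) is simply the default stop-token pair this commit introduces (presumably the tokenizer's regular EOS plus InternLM's end-of-assistant token), so passing it explicitly is optional:

# Sketch only: shows the updated chat() signature in use; not part of this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "internlm/internlm-chat-7b"  # placeholder repo assumed to ship this modeling file
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.float16
).cuda().eval()

# Generation now stops on either of the two configured stop tokens.
response, history = model.chat(
    tokenizer,
    "Hello! Who are you?",
    history=[],
    eos_token_id=(2, 103028),  # matches the new default; any iterable of token ids works
)
print(response)

stream_chat() gains the same parameter and, as the last hunk shows, forwards it to generate() unchanged.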