x54-729 committed
Commit: 2667fa7
Parent(s): bfbab49

Fix streaming_chat

Files changed (1)
  1. modeling_internlm.py +53 -21
modeling_internlm.py CHANGED
@@ -20,6 +20,7 @@
 """ PyTorch InternLM model."""
 import math
 from typing import List, Optional, Tuple, Union
+import threading, queue
 
 import torch
 import torch.utils.checkpoint
@@ -784,7 +785,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
              do_sample: bool = True,
              temperature: float = 0.8,
              top_p: float = 0.8,
-             eos_token_id = (2, 103028),
              **kwargs):
         inputs = self.build_inputs(tokenizer, query, history)
         inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
@@ -794,7 +794,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                                 do_sample=do_sample,
                                 temperature=temperature,
                                 top_p=top_p,
-                                eos_token_id=list(eos_token_id),
                                 **kwargs)
         outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
@@ -811,38 +810,71 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                     do_sample: bool = True,
                     temperature: float = 0.8,
                     top_p: float = 0.8,
-                    eos_token_id = (2, 103028),
                     **kwargs):
+        """
+        Return a generator in format: (response, history)
+        Eg.
+        ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
+        ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
+        """
+
+        response_queue = queue.Queue(maxsize=20)
+
         class ChatStreamer(BaseStreamer):
             def __init__(self, tokenizer) -> None:
                 super().__init__()
                 self.tokenizer = tokenizer
-
+                self.queue = response_queue
+                self.query = query
+                self.history = history
+                self.response = ""
+                self.received_inputs = False
+                self.queue.put((self.response, history + [(self.query, self.response)]))
+
             def put(self, value):
                 if len(value.shape) > 1 and value.shape[0] > 1:
                     raise ValueError("ChatStreamer only supports batch size 1")
                 elif len(value.shape) > 1:
                     value = value[0]
+
+                if not self.received_inputs:
+                    # The first received value is input_ids, ignore here
+                    self.received_inputs = True
+                    return
+
                 token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
                 if token.strip() != "<eoa>":
-                    print(token, end="")
-
+                    self.response = self.response + token
+                    history = self.history + [(self.query, self.response)]
+                    self.queue.put((self.response, history))
+
             def end(self):
-                print("")
-
-        return self.chat(
-            tokenizer=tokenizer,
-            query=query,
-            streamer=ChatStreamer(tokenizer=tokenizer),
-            history=history,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            eos_token_id=eos_token_id,
-            **kwargs
-        )
-
+                self.queue.put(None)
+
+        def stream_producer():
+            return self.chat(
+                tokenizer=tokenizer,
+                query=query,
+                streamer=ChatStreamer(tokenizer=tokenizer),
+                history=history,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_p=top_p,
+                **kwargs
+            )
+
+        def consumer():
+            producer = threading.Thread(target=stream_producer)
+            producer.start()
+            while True:
+                res = response_queue.get()
+                if res is None:
+                    return
+                yield res
+
+        return consumer()
+
 
 @add_start_docstrings(
     """