sharp committed on
Commit
1c2e9a9
1 Parent(s): 41aa7c9

add chat function

Browse files
Files changed (1) hide show
  1. modeling_orion.py +21 -0
modeling_orion.py CHANGED
@@ -30,6 +30,10 @@ from transformers.utils import (
30
  replace_return_docstrings,
31
  )
32
 
 
 
 
 
33
  if is_flash_attn_2_available():
34
  from flash_attn import flash_attn_func, flash_attn_varlen_func
35
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -951,6 +955,23 @@ class OrionForCausalLM(OrionPreTrainedModel):
951
  attentions=outputs.attentions,
952
  )
953
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
954
  def prepare_inputs_for_generation(
955
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
956
  ):
 
30
  replace_return_docstrings,
31
  )
32
 
33
+ from .generation_utils import build_chat_input, TextIterStreamer
34
+ from transformers.generation.utils import GenerationConfig
35
+ from threading import Thread
36
+
37
  if is_flash_attn_2_available():
38
  from flash_attn import flash_attn_func, flash_attn_varlen_func
39
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
 
955
  attentions=outputs.attentions,
956
  )
957
 
958
+ def chat(self, tokenizer, messages: List[dict], streaming=False,generation_config: Optional[GenerationConfig]=None):
959
+ generation_config = generation_config or self.generation_config
960
+ input_tokens = build_chat_input(tokenizer,messages)
961
+ input_ids = torch.LongTensor([input_tokens]).to(self.device)
962
+
963
+ if streaming:
964
+ streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
965
+ Thread(target=self.generate, kwargs=dict(
966
+ inputs=input_ids, streamer=streamer,
967
+ generation_config=generation_config,
968
+ )).start()
969
+ return streamer
970
+ else:
971
+ outputs = self.generate(input_ids, generation_config=generation_config)
972
+ response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
973
+ return response
974
+
975
  def prepare_inputs_for_generation(
976
  self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
977
  ):