GuoPD committed
Commit a4a5581
1 Parent(s): 19ef51b

add: update generation code

Files changed (3)
  1. generation_utils.py +82 -0
  2. modeling_baichuan.py +9 -42
  3. requirements.txt +0 -1
generation_utils.py ADDED
@@ -0,0 +1,82 @@
+ from typing import List
+ from queue import Queue
+
+ import torch
+
+
+ def build_chat_input(model, tokenizer, messages: List[dict], max_new_tokens: int=0):
+     def _parse_messages(messages, split_role="user"):
+         system, rounds = "", []
+         round = []
+         for i, message in enumerate(messages):
+             if message["role"] == "system":
+                 assert i == 0
+                 system = message["content"]
+                 continue
+             if message["role"] == split_role and round:
+                 rounds.append(round)
+                 round = []
+             round.append(message)
+         if round:
+             rounds.append(round)
+         return system, rounds
+
+     max_new_tokens = max_new_tokens or model.generation_config.max_new_tokens
+     max_input_tokens = model.config.model_max_length - max_new_tokens
+     system, rounds = _parse_messages(messages, split_role="user")
+     system_tokens = tokenizer.encode(system)
+     max_history_tokens = max_input_tokens - len(system_tokens)
+
+     history_tokens = []
+     for round in rounds[::-1]:
+         round_tokens = []
+         for message in round:
+             if message["role"] == "user":
+                 round_tokens.append(model.generation_config.user_token_id)
+             else:
+                 round_tokens.append(model.generation_config.assistant_token_id)
+             round_tokens.extend(tokenizer.encode(message["content"]))
+         if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens:
+             history_tokens = round_tokens + history_tokens  # concat left
+             if len(history_tokens) < max_history_tokens:
+                 continue
+         break
+
+     input_tokens = system_tokens + history_tokens
+     if messages[-1]["role"] != "assistant":
+         input_tokens.append(model.generation_config.assistant_token_id)
+     input_tokens = input_tokens[-max_input_tokens:]  # truncate left
+     return torch.LongTensor([input_tokens]).to(model.device)
+
+
+ class TextIterStreamer:
+     def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
+         self.tokenizer = tokenizer
+         self.skip_prompt = skip_prompt
+         self.skip_special_tokens = skip_special_tokens
+         self.tokens = []
+         self.text_queue = Queue()
+         self.next_tokens_are_prompt = True
+
+     def put(self, value):
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+         else:
+             if len(value.shape) > 1:
+                 value = value[0]
+             self.tokens.extend(value.tolist())
+             self.text_queue.put(
+                 self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+     def end(self):
+         self.text_queue.put(None)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.text_queue.get()
+         if value is None:
+             raise StopIteration()
+         else:
+             return value
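For context (not part of the commit): a minimal sketch of driving these two helpers by hand. The repo id, the sample messages, and the manual generate() wiring are placeholders/assumptions here; the refactored model.chat below wires them up the same way internally.

# Usage sketch only, assuming a local checkout of this repo so that
# generation_utils.py is importable; "model_path" is a placeholder id.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

from generation_utils import build_chat_input, TextIterStreamer

model_path = "model_path"  # placeholder for this repository's Hub id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model.generation_config = GenerationConfig.from_pretrained(model_path)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# build_chat_input keeps the system prompt plus as many of the most recent
# rounds as fit in model_max_length - max_new_tokens, marks each message with
# user_token_id / assistant_token_id, and appends assistant_token_id so that
# generation starts on the assistant turn.
input_ids = build_chat_input(model, tokenizer, messages)

# TextIterStreamer implements the put()/end() protocol that generate() expects
# from a streamer; iterating it yields the cumulatively decoded response.
streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs=dict(
    inputs=input_ids, streamer=streamer, generation_config=model.generation_config,
)).start()
for text in streamer:
    print(text)  # each item is the full response decoded so far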
modeling_baichuan.py CHANGED
@@ -1,6 +1,7 @@
  # Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
 
  import math
+ from threading import Thread
  from typing import List, Optional, Tuple, Union
 
  import torch
@@ -13,6 +14,7 @@ from transformers.utils import logging
  from transformers.generation.utils import GenerationConfig
 
  from .configuration_baichuan import BaichuanConfig
+ from .generation_utils import build_chat_input, TextIterStreamer
 
  logger = logging.get_logger(__name__)
 
@@ -552,54 +554,19 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
          )
          return self
 
-     def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
-         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
-         max_input_tokens = self.config.model_max_length - max_new_tokens
-         max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
-         total_input, round_input = [], []
-         for i, message in enumerate(messages[::-1]):
-             content_tokens = tokenizer.encode(message['content'])
-             if message['role'] == 'user':
-                 round_input = [self.generation_config.user_token_id] + content_tokens + round_input
-                 if total_input and len(total_input) + len(round_input) > max_input_tokens:
-                     break
-                 else:
-                     total_input = round_input + total_input
-                     if len(total_input) >= max_input_tokens:
-                         break
-                     else:
-                         round_input = []
-             elif message['role'] == 'assistant':
-                 round_input = [
-                     self.generation_config.assistant_token_id
-                 ] + content_tokens + round_input
-             else:
-                 raise ValueError(f"message role not supported yet: {message['role']}")
-         total_input = total_input[-max_input_tokens:]  # truncate left
-         total_input.append(self.generation_config.assistant_token_id)
-         total_input = torch.LongTensor([total_input]).to(self.device)
-         return total_input
-
      @torch.no_grad()
      def chat(self, tokenizer, messages: List[dict], stream=False,
              generation_config: Optional[GenerationConfig]=None):
          generation_config = generation_config or self.generation_config
-         input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
+         input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens)
          if stream:
-             from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
-             self.__class__.generate = NewGenerationMixin.generate
-             self.__class__.sample_stream = NewGenerationMixin.sample_stream
-             stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
-
-             def stream_generator():
-                 outputs = []
-                 for token in self.generate(input_ids, generation_config=stream_config):
-                     outputs.append(token.item())
-                     yield tokenizer.decode(outputs, skip_special_tokens=True)
-
-             return stream_generator()
+             streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+             Thread(target=self.generate, kwargs=dict(
+                 inputs=input_ids, streamer=streamer,
+                 generation_config=generation_config,
+             )).start()
+             return streamer
          else:
-             self.__class__.generate = PreTrainedModel.generate  # disable stream
              outputs = self.generate(input_ids, generation_config=generation_config)
             response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
             return response
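For illustration only, a sketch of the high-level call path after this change, assuming the model and tokenizer are loaded as in the earlier sketch:

# Illustration only: the refactored chat() API.
messages = [{"role": "user", "content": "Hello!"}]

# Non-streaming: generate() runs to completion and the decoded answer is returned.
response = model.chat(tokenizer, messages)
print(response)

# Streaming: chat() starts generate() on a background Thread and returns the
# TextIterStreamer, so the caller just iterates it; the extra dependency
# (transformers_stream_generator) is no longer needed.
for text in model.chat(tokenizer, messages, stream=True):
    print(text)  # cumulative decoded text so far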
requirements.txt CHANGED
@@ -3,4 +3,3 @@ colorama
  cpm_kernels
  sentencepiece
  streamlit
- transformers_stream_generator