zxdu20 committed on
Commit 42095d4 (1 parent: 220f772)

Add support for streaming output

Files changed (1)
  1. modeling_chatglm.py +120 -42
modeling_chatglm.py CHANGED
@@ -3,7 +3,7 @@
 import math
 import copy
 import os
-import time
+import warnings
 
 import torch
 import torch.utils.checkpoint
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
 from torch.nn.utils import skip_init
-from typing import Optional, Tuple, Union, List
+from typing import Optional, Tuple, Union, List, Callable
 
 from transformers.utils import (
     add_code_sample_docstrings,
@@ -26,7 +26,7 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 from transformers.generation.logits_process import LogitsProcessor
-from transformers.generation.utils import LogitsProcessorList
+from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
 
 from .configuration_chatglm import ChatGLMConfig
 
@@ -1107,7 +1107,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
         input_ids = input_ids.to(self.device)
         outputs = self.generate(**input_ids, **gen_kwargs)
-        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
+        outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
         response = response.strip()
         response = response.replace("[[训练时间]]", "2023年")
@@ -1115,55 +1115,133 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         return response, history
 
     @torch.no_grad()
-    def generate(
+    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
+                    do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if history is None:
+            history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        if not history:
+            prompt = query
+        else:
+            prompt = ""
+            for i, (old_query, response) in enumerate(history):
+                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+        input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
+        input_ids = input_ids.to(self.device)
+        for outputs in self.stream_generate(**input_ids, **gen_kwargs):
+            outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
+            response = tokenizer.decode(outputs)
+            response = response.strip()
+            response = response.replace("[[训练时间]]", "2023年")
+            new_history = history + [(query, response)]
+            yield response, new_history
+
+    @torch.no_grad()
+    def stream_generate(
             self,
+            input_ids,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
             **kwargs,
     ):
-        MASK, gMASK = 150000, 150001
-        bos, eos = 150004, 150005
+        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)
+        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
+            warnings.warn(
+                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
+                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
+                UserWarning,
+            )
+        elif generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+            if not has_default_max_length:
+                logger.warn(
+                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
+                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
+                    "Please refer to the documentation for more information. "
+                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
+                    UserWarning,
+                )
 
-        if "eos_token_id" not in kwargs:
-            kwargs["eos_token_id"] = eos
+        if input_ids_seq_length >= generation_config.max_length:
+            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+            logger.warning(
+                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
+            )
+
+        # 2. Set generation parameters if not already defined
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
 
-        stop = False
+        logits_processor = self._get_logits_processor(
+            generation_config=generation_config,
+            input_ids_seq_length=input_ids_seq_length,
+            encoder_input_ids=input_ids,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            logits_processor=logits_processor,
+        )
 
-        return_seqs = []
+        stopping_criteria = self._get_stopping_criteria(
+            generation_config=generation_config, stopping_criteria=stopping_criteria
+        )
+        logits_warper = self._get_logits_warper(generation_config)
 
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        scores = None
         while True:
-            output_ids = super().generate(**kwargs)
-
-            return_seqs = []
-            max_length = 0
-
-            for i in range(output_ids.shape[0]):
-                output_seq = output_ids[i].tolist()
-                mask_token = MASK if MASK in output_seq else gMASK
-                mask_position = output_seq.index(mask_token)
-                bos_position = output_seq.index(bos)
-                if eos in output_seq:
-                    eos_position = output_seq.index(eos)
-                else:
-                    eos_position = len(output_seq)
-
-                return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
-                    mask_position + 1:bos_position]
-                max_length = max(max_length, len(return_seq))
-                return_seqs.append(return_seq)
-
-            for i in range(output_ids.shape[0]):
-                return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i]  # padding
-                if mask_token not in return_seqs[i]:
-                    stop = True
-
-            if stop:
-                break
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=False,
+                output_hidden_states=False,
+            )
 
-            for return_seq in return_seqs:
-                return_seq += [bos]
+            next_token_logits = outputs.logits[:, -1, :]
 
-            kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+            next_token_scores = logits_warper(input_ids, next_token_scores)
+
+            # sample
+            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            if generation_config.do_sample:
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(probs, dim=-1)
 
-        return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+            )
+            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
+            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+                break
+            yield input_ids
 
     def quantize(self, bits: int):
         from .quantization import quantize
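
For context, a minimal usage sketch of the streaming API this commit adds (not part of the diff): it iterates over the new stream_chat generator, which yields the partially decoded response after every generated token. The checkpoint name THUDM/chatglm-6b, the trust_remote_code loading flags, device placement, and the example query are assumptions for illustration, not taken from this commit.

# Usage sketch only -- not part of this commit.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

history = []
# stream_chat is a generator: each iteration yields the response decoded so far
# together with the updated history, so the reply can be printed incrementally.
for response, history in model.stream_chat(tokenizer, "Hello", history=history):
    print(response)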