zR commited on
Commit
5dd1ab7
·
1 Parent(s): 99a1409
Files changed (4) hide show
  1. README.md +1 -1
  2. config.json +1 -1
  3. generation_config.json +1 -1
  4. modeling_chatglm.py +2 -202
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  license: other
3
  license_name: glm-4
4
- license_link: https://huggingface.co/THUDM/glm-4-9b/LICENSE
5
  language:
6
  - zh
7
  - en
 
1
  ---
2
  license: other
3
  license_name: glm-4
4
+ license_link: https://huggingface.co/THUDM/glm-4-9b/main/LICENSE
5
  language:
6
  - zh
7
  - en
config.json CHANGED
@@ -36,7 +36,7 @@
36
  "seq_length": 8192,
37
  "use_cache": true,
38
  "torch_dtype": "bfloat16",
39
- "transformers_version": "4.40.2",
40
  "tie_word_embeddings": false,
41
  "eos_token_id": [151329, 151336, 151338],
42
  "pad_token_id": 151329
 
36
  "seq_length": 8192,
37
  "use_cache": true,
38
  "torch_dtype": "bfloat16",
39
+ "transformers_version": "4.42.4",
40
  "tie_word_embeddings": false,
41
  "eos_token_id": [151329, 151336, 151338],
42
  "pad_token_id": 151329
generation_config.json CHANGED
@@ -9,5 +9,5 @@
9
  "temperature": 0.8,
10
  "max_length": 8192,
11
  "top_p": 0.8,
12
- "transformers_version": "4.40.2"
13
  }
 
9
  "temperature": 0.8,
10
  "max_length": 8192,
11
  "top_p": 0.8,
12
+ "transformers_version": "4.42.4"
13
  }
modeling_chatglm.py CHANGED
@@ -797,11 +797,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
797
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
798
  return position_ids
799
 
800
- def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
801
- if not self.supports_gradient_checkpointing:
802
- raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
803
-
804
-
805
  class Embedding(torch.nn.Module):
806
  """Language model embeddings."""
807
 
@@ -936,9 +931,10 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
936
  standardize_cache_format: bool = False,
937
  ) -> Dict[str, Any]:
938
  # update past_key_values
939
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
940
  outputs, standardize_cache_format=standardize_cache_format
941
  )
 
942
 
943
  # update attention mask
944
  if "attention_mask" in model_kwargs:
@@ -1063,202 +1059,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1063
  for layer_past in past
1064
  )
1065
 
1066
- def process_response(self, output, history):
1067
- content = ""
1068
- history = deepcopy(history)
1069
- for response in output.split("<|assistant|>"):
1070
- if "\n" in response:
1071
- metadata, content = response.split("\n", maxsplit=1)
1072
- else:
1073
- metadata, content = "", response
1074
- if not metadata.strip():
1075
- content = content.strip()
1076
- history.append({"role": "assistant", "metadata": metadata, "content": content})
1077
- content = content.replace("[[训练时间]]", "2023年")
1078
- else:
1079
- history.append({"role": "assistant", "metadata": metadata, "content": content})
1080
- if history[0]["role"] == "system" and "tools" in history[0]:
1081
- parameters = json.loads(content)
1082
- content = {"name": metadata.strip(), "parameters": parameters}
1083
- else:
1084
- content = {"name": metadata.strip(), "content": content}
1085
- return content, history
1086
-
1087
- @torch.inference_mode()
1088
- def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1089
- max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1090
- **kwargs):
1091
- if history is None:
1092
- history = []
1093
- if logits_processor is None:
1094
- logits_processor = LogitsProcessorList()
1095
- logits_processor.append(InvalidScoreLogitsProcessor())
1096
- gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1097
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1098
- history.append({"role": role, "content": query})
1099
- inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,
1100
- return_tensors="pt", return_dict=True)
1101
- inputs = inputs.to(self.device)
1102
- eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
1103
- tokenizer.convert_tokens_to_ids("<|observation|>")]
1104
- outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1105
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1106
- response = tokenizer.decode(outputs)
1107
- response, history = self.process_response(response, history)
1108
- return response, history
1109
-
1110
- @torch.inference_mode()
1111
- def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
1112
- past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
1113
- logits_processor=None, return_past_key_values=False, **kwargs):
1114
- if history is None:
1115
- history = []
1116
- if logits_processor is None:
1117
- logits_processor = LogitsProcessorList()
1118
- logits_processor.append(InvalidScoreLogitsProcessor())
1119
- eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"),
1120
- tokenizer.convert_tokens_to_ids("<|observation|>")]
1121
- gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1122
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1123
- if past_key_values is None:
1124
- inputs = tokenizer.apply_chat_template(history + [{"role": role, "content": query}],
1125
- add_generation_prompt=True, tokenize=True, return_tensors="pt",
1126
- return_dict=True)
1127
- else:
1128
- inputs = tokenizer.apply_chat_template([{"role": role, "content": query}], add_special_tokens=False,
1129
- add_generation_prompt=True, tokenize=True, return_tensors="pt",
1130
- return_dict=True)
1131
- inputs = inputs.to(self.device)
1132
- if past_key_values is not None:
1133
- past_length = past_key_values[0][0].shape[2]
1134
- inputs.position_ids += past_length
1135
- attention_mask = inputs.attention_mask
1136
- attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1137
- inputs['attention_mask'] = attention_mask
1138
- history.append({"role": role, "content": query})
1139
- for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1140
- eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
1141
- **gen_kwargs):
1142
- if return_past_key_values:
1143
- outputs, past_key_values = outputs
1144
- outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1145
- response = tokenizer.decode(outputs)
1146
- if response and response[-1] != "�":
1147
- response, new_history = self.process_response(response, history)
1148
- if return_past_key_values:
1149
- yield response, new_history, past_key_values
1150
- else:
1151
- yield response, new_history
1152
-
1153
- @torch.inference_mode()
1154
- def stream_generate(
1155
- self,
1156
- input_ids,
1157
- generation_config: Optional[GenerationConfig] = None,
1158
- logits_processor: Optional[LogitsProcessorList] = None,
1159
- stopping_criteria: Optional[StoppingCriteriaList] = None,
1160
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1161
- return_past_key_values=False,
1162
- **kwargs,
1163
- ):
1164
- batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1165
-
1166
- if generation_config is None:
1167
- generation_config = self.generation_config
1168
- generation_config = copy.deepcopy(generation_config)
1169
- model_kwargs = generation_config.update(**kwargs)
1170
- model_kwargs["use_cache"] = generation_config.use_cache
1171
- bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1172
-
1173
- if isinstance(eos_token_id, int):
1174
- eos_token_id = [eos_token_id]
1175
- eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
1176
-
1177
- has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1178
- if has_default_max_length and generation_config.max_new_tokens is None:
1179
- warnings.warn(
1180
- f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1181
- "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1182
- " recommend using `max_new_tokens` to control the maximum length of the generation.",
1183
- UserWarning,
1184
- )
1185
- elif generation_config.max_new_tokens is not None:
1186
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1187
- if not has_default_max_length:
1188
- logger.warn(
1189
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1190
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1191
- "Please refer to the documentation for more information. "
1192
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1193
- UserWarning,
1194
- )
1195
-
1196
- if input_ids_seq_length >= generation_config.max_length:
1197
- input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1198
- logger.warning(
1199
- f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1200
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1201
- " increasing `max_new_tokens`."
1202
- )
1203
-
1204
- # 2. Set generation parameters if not already defined
1205
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1206
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1207
-
1208
- logits_processor = self._get_logits_processor(
1209
- generation_config=generation_config,
1210
- input_ids_seq_length=input_ids_seq_length,
1211
- encoder_input_ids=input_ids,
1212
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1213
- logits_processor=logits_processor,
1214
- )
1215
-
1216
- stopping_criteria = self._get_stopping_criteria(
1217
- generation_config=generation_config, stopping_criteria=stopping_criteria
1218
- )
1219
- logits_warper = self._get_logits_warper(generation_config)
1220
-
1221
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1222
- scores = None
1223
- while True:
1224
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1225
- # forward pass to get next token
1226
- outputs = self(
1227
- **model_inputs,
1228
- return_dict=True,
1229
- output_attentions=False,
1230
- output_hidden_states=False,
1231
- )
1232
-
1233
- next_token_logits = outputs.logits[:, -1, :]
1234
-
1235
- # pre-process distribution
1236
- next_token_scores = logits_processor(input_ids, next_token_logits)
1237
- next_token_scores = logits_warper(input_ids, next_token_scores)
1238
-
1239
- # sample
1240
- probs = nn.functional.softmax(next_token_scores, dim=-1)
1241
- if generation_config.do_sample:
1242
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1243
- else:
1244
- next_tokens = torch.argmax(probs, dim=-1)
1245
- # update generated ids, model inputs, and length for next step
1246
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1247
- model_kwargs = self._update_model_kwargs_for_generation(
1248
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1249
- )
1250
- unfinished_sequences = unfinished_sequences.mul(
1251
- next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
1252
- )
1253
- if return_past_key_values:
1254
- yield input_ids, outputs.past_key_values
1255
- else:
1256
- yield input_ids
1257
- # stop when each sentence is finished, or if we exceed the maximum length
1258
- if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1259
- break
1260
-
1261
-
1262
  class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1263
  def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1264
  super().__init__(config)
 
797
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
798
  return position_ids
799
 
 
 
 
 
 
800
  class Embedding(torch.nn.Module):
801
  """Language model embeddings."""
802
 
 
931
  standardize_cache_format: bool = False,
932
  ) -> Dict[str, Any]:
933
  # update past_key_values
934
+ cache_name, cache = self._extract_past_from_model_output(
935
  outputs, standardize_cache_format=standardize_cache_format
936
  )
937
+ model_kwargs[cache_name] = cache
938
 
939
  # update attention mask
940
  if "attention_mask" in model_kwargs:
 
1059
  for layer_past in past
1060
  )
1061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1062
  class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1063
  def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1064
  super().__init__(config)