Support multi-GPU and fix some bugs

#16
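
With `is_parallelizable` switched on, a default-on gradient-checkpointing path, and the `.to(device)` calls added inside `forward`, the patched `modeling_chatglm.py` lets the GLM blocks be spread across several GPUs. A minimal usage sketch (not part of this diff; it assumes `accelerate` is installed so `device_map="auto"` can place the layers):

from transformers import AutoModel, AutoTokenizer

# Sketch only: load the patched checkpoint sharded across all visible GPUs.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b",
    trust_remote_code=True,
    device_map="auto",   # let accelerate spread the transformer blocks over the cards
).half().eval()

response, history = model.chat(tokenizer, "Hello", history=[])
print(response)
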
Files changed (1)
  1. modeling_chatglm.py +252 -307
modeling_chatglm.py CHANGED
@@ -3,8 +3,6 @@
3
  import math
4
  import copy
5
  import os
6
- import warnings
7
- import re
8
  import sys
9
 
10
  import torch
@@ -13,7 +11,7 @@ import torch.nn.functional as F
13
  from torch import nn
14
  from torch.nn import CrossEntropyLoss, LayerNorm
15
  from torch.nn.utils import skip_init
16
- from typing import Optional, Tuple, Union, List, Callable
17
 
18
  from transformers.utils import (
19
  add_code_sample_docstrings,
@@ -26,20 +24,17 @@ from transformers.modeling_outputs import (
26
  BaseModelOutputWithPastAndCrossAttentions,
27
  )
28
  from transformers.modeling_utils import PreTrainedModel
29
- from transformers.utils import logging
30
- from transformers.generation.logits_process import LogitsProcessor
31
- from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
32
 
 
33
  from .configuration_chatglm import ChatGLMConfig
34
 
35
- # flags required to enable jit fusion kernels
36
-
37
  if sys.platform != 'darwin':
38
  torch._C._jit_set_profiling_mode(False)
39
  torch._C._jit_set_profiling_executor(False)
40
  torch._C._jit_override_can_fuse_on_cpu(True)
41
  torch._C._jit_override_can_fuse_on_gpu(True)
42
 
 
43
  logger = logging.get_logger(__name__)
44
 
45
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
@@ -51,14 +46,6 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
51
  ]
52
 
53
 
54
- class InvalidScoreLogitsProcessor(LogitsProcessor):
55
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
56
- if torch.isnan(scores).any() or torch.isinf(scores).any():
57
- scores.zero_()
58
- scores[..., 20005] = 5e4
59
- return scores
60
-
61
-
62
  def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
63
  """Load tf checkpoints in a pytorch model."""
64
  try:
@@ -153,7 +140,7 @@ class RotaryEmbedding(torch.nn.Module):
153
  if learnable:
154
  self.inv_freq = torch.nn.Parameter(inv_freq)
155
  self.max_seq_len_cached = None
156
- else:
157
  self.register_buffer('inv_freq', inv_freq)
158
  self.max_seq_len_cached = None
159
  self.cos_cached = None
@@ -169,22 +156,24 @@ class RotaryEmbedding(torch.nn.Module):
169
  seq_len = x.shape[seq_dim]
170
  if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
171
  self.max_seq_len_cached = None if self.learnable else seq_len
172
- t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
173
- freqs = torch.einsum('i,j->ij', t, self.inv_freq)
174
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
175
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
176
- if self.precision == torch.bfloat16:
177
- emb = emb.float()
178
-
179
- # [sx, 1 (b * np), hn]
180
- cos_cached = emb.cos()[:, None, :]
181
- sin_cached = emb.sin()[:, None, :]
182
- if self.precision == torch.bfloat16:
183
- cos_cached = cos_cached.bfloat16()
184
- sin_cached = sin_cached.bfloat16()
185
- if self.learnable:
186
- return cos_cached, sin_cached
187
- self.cos_cached, self.sin_cached = cos_cached, sin_cached
 
 
188
  return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
189
 
190
 
@@ -202,114 +191,6 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
202
  return q, k
203
 
204
 
205
- def attention_fn(
206
- self,
207
- query_layer,
208
- key_layer,
209
- value_layer,
210
- attention_mask,
211
- hidden_size_per_partition,
212
- layer_id,
213
- layer_past=None,
214
- scaling_attention_score=True,
215
- use_cache=False,
216
- ):
217
- if layer_past is not None:
218
- past_key, past_value = layer_past
219
- key_layer = torch.cat((past_key, key_layer), dim=0)
220
- value_layer = torch.cat((past_value, value_layer), dim=0)
221
-
222
- # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
223
- seq_len, b, nh, hidden_size = key_layer.shape
224
-
225
- if use_cache:
226
- present = (key_layer, value_layer)
227
- else:
228
- present = None
229
-
230
- query_key_layer_scaling_coeff = float(layer_id + 1)
231
- if scaling_attention_score:
232
- query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
233
-
234
- # ===================================
235
- # Raw attention scores. [b, np, s, s]
236
- # ===================================
237
-
238
- # [b, np, sq, sk]
239
- output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
240
-
241
- # [sq, b, np, hn] -> [sq, b * np, hn]
242
- query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
243
- # [sk, b, np, hn] -> [sk, b * np, hn]
244
- key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
245
-
246
- matmul_result = torch.empty(
247
- output_size[0] * output_size[1],
248
- output_size[2],
249
- output_size[3],
250
- dtype=query_layer.dtype,
251
- device=query_layer.device,
252
- )
253
-
254
- matmul_result = torch.baddbmm(
255
- matmul_result,
256
- query_layer.transpose(0, 1), # [b * np, sq, hn]
257
- key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
258
- beta=0.0,
259
- alpha=1.0,
260
- )
261
-
262
- # change view to [b, np, sq, sk]
263
- attention_scores = matmul_result.view(*output_size)
264
-
265
- if self.scale_mask_softmax:
266
- self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
267
- attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
268
- else:
269
- if not (attention_mask == 0).all():
270
- # if auto-regressive, skip
271
- attention_scores.masked_fill_(attention_mask, -10000.0)
272
- dtype = attention_scores.dtype
273
- attention_scores = attention_scores.float()
274
- attention_scores = attention_scores * query_key_layer_scaling_coeff
275
-
276
- attention_probs = F.softmax(attention_scores, dim=-1)
277
-
278
- attention_probs = attention_probs.type(dtype)
279
-
280
- # =========================
281
- # Context layer. [sq, b, hp]
282
- # =========================
283
-
284
- # value_layer -> context layer.
285
- # [sk, b, np, hn] --> [b, np, sq, hn]
286
-
287
- # context layer shape: [b, np, sq, hn]
288
- output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
289
-
290
- # change view [sk, b * np, hn]
291
- value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
292
-
293
- # change view [b * np, sq, sk]
294
- attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
295
-
296
- # matmul: [b * np, sq, hn]
297
- context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
298
-
299
- # change view [b, np, sq, hn]
300
- context_layer = context_layer.view(*output_size)
301
-
302
- # [b, np, sq, hn] --> [sq, b, np, hn]
303
- context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
304
-
305
- # [sq, b, np, hn] --> [sq, b, hp]
306
- new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
307
- context_layer = context_layer.view(*new_context_layer_shape)
308
-
309
- outputs = (context_layer, present, attention_probs)
310
-
311
- return outputs
312
-
313
 
314
  class SelfAttention(torch.nn.Module):
315
  def __init__(self, hidden_size, num_attention_heads,
@@ -399,7 +280,7 @@ class SelfAttention(torch.nn.Module):
399
  """
400
 
401
  # [seq_len, batch, 3 * hidden_size]
402
- mixed_raw_layer = self.query_key_value(hidden_states)
403
 
404
  # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
405
  new_tensor_shape = mixed_raw_layer.size()[:-1] + (
@@ -414,6 +295,7 @@ class SelfAttention(torch.nn.Module):
414
  if self.position_encoding_2d:
415
  q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
416
  k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
 
417
  cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
418
  position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
419
  position_ids[:, 1, :].transpose(0, 1).contiguous()
@@ -423,22 +305,25 @@ class SelfAttention(torch.nn.Module):
423
  key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
424
  else:
425
  position_ids = position_ids.transpose(0, 1)
 
426
  cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
427
  # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
428
  query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)
429
 
430
  # [seq_len, batch, hidden_size]
431
- context_layer, present, attention_probs = attention_fn(
432
- self=self,
433
  query_layer=query_layer,
434
  key_layer=key_layer,
435
  value_layer=value_layer,
436
- attention_mask=attention_mask,
437
  hidden_size_per_partition=self.hidden_size_per_partition,
438
  layer_id=layer_id,
439
  layer_past=layer_past,
440
  use_cache=use_cache
441
  )
 
 
 
442
 
443
  output = self.dense(context_layer)
444
 
@@ -449,6 +334,118 @@ class SelfAttention(torch.nn.Module):
449
 
450
  return outputs # output, present, attention_probs
451
 
452
 
453
  class GEGLU(torch.nn.Module):
454
  def __init__(self):
@@ -614,7 +611,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
614
  a simple interface for downloading and loading pretrained models.
615
  """
616
 
617
- is_parallelizable = False
 
618
  supports_gradient_checkpointing = False
619
  config_class = ChatGLMConfig
620
  base_model_prefix = "transformer"
@@ -724,6 +722,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
724
  self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
725
  self.position_encoding_2d = config.position_encoding_2d
726
 
 
 
727
  self.word_embeddings = skip_init(
728
  torch.nn.Embedding,
729
  num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
@@ -757,8 +757,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
757
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
758
  self.word_embeddings = new_embeddings
759
 
760
- def get_masks(self, seq, device):
761
- context_length = seq.index(self.config.bos_token_id) + 1
 
762
 
763
  attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
764
  attention_mask.tril_()
@@ -769,9 +770,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
769
  return attention_mask
770
 
771
  def get_position_ids(self, seq, mask_position, device, gmask=False):
772
- context_length = seq.index(self.config.bos_token_id) + 1
773
  if self.position_encoding_2d:
774
- seq_length = seq.index(self.config.bos_token_id)
775
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
776
  if not gmask:
777
  position_ids[seq_length:] = mask_position
@@ -826,8 +827,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
826
 
827
  if past_key_values is None:
828
  past_key_values = tuple([None] * len(self.layers))
 
 
 
 
829
  seq = input_ids[0].tolist()
830
 
 
 
831
  if attention_mask is None:
832
  attention_mask = self.get_masks(
833
  seq=seq,
@@ -835,11 +842,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
835
  )
836
 
837
  if position_ids is None:
838
- MASK, gMASK = 150000, 150001
839
- mask_token = MASK if MASK in input_ids else gMASK
840
- use_gmask = False if MASK in input_ids else gMASK
841
-
842
- mask_position = seq.index(mask_token)
843
  position_ids = self.get_position_ids(
844
  seq=seq,
845
  mask_position=mask_position,
@@ -848,15 +850,28 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
848
  )
849
 
850
  if inputs_embeds is None:
851
- inputs_embeds = self.word_embeddings(input_ids)
 
 
 
852
 
853
  # [seq_len, batch, hidden_size]
854
  hidden_states = inputs_embeds.transpose(0, 1)
855
 
 
 
856
  presents = () if use_cache else None
857
  all_self_attentions = () if output_attentions else None
858
  all_hidden_states = () if output_hidden_states else None
859
 
 
 
860
  seq_length_with_past = seq_length
861
  past_key_values_length = 0
862
  if past_key_values[0] is not None:
@@ -873,15 +888,39 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
873
  if output_hidden_states:
874
  all_hidden_states = all_hidden_states + (hidden_states,)
875
 
876
- layer_ret = layer(
877
- hidden_states,
878
- position_ids=position_ids,
879
- attention_mask=attention_mask,
880
- layer_id=torch.tensor(i),
881
- layer_past=past_key_values[i],
882
- use_cache=use_cache,
883
- output_attentions=output_attentions
884
- )
 
 
885
 
886
  hidden_states = layer_ret[0]
887
 
@@ -928,6 +967,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
928
  bias=False,
929
  dtype=torch.half
930
  )
 
931
 
932
  def get_output_embeddings(self):
933
  return self.lm_head
@@ -943,7 +983,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
943
  attention_mask = (attention_mask < 0.5).bool()
944
 
945
  if self.position_encoding_2d:
946
- seq_length = seq.index(self.config.bos_token_id)
947
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
948
  if not gmask:
949
  position_ids[seq_length:] = mask_position
@@ -981,7 +1021,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
981
 
982
  # only last token for input_ids if past is not None
983
  if past is not None or past_key_values is not None:
984
- context_length = seq.index(self.config.bos_token_id)
985
  last_token = input_ids[:, -1].unsqueeze(-1)
986
  if self.position_encoding_2d:
987
  position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@@ -1053,10 +1093,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1053
  shift_labels = labels[..., 1:].contiguous()
1054
  # Flatten the tokens
1055
  loss_fct = CrossEntropyLoss()
1056
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1057
 
1058
- lm_logits = lm_logits.to(hidden_states.dtype)
1059
- loss = loss.to(hidden_states.dtype)
1060
 
1061
  if not return_dict:
1062
  output = (lm_logits,) + transformer_outputs[1:]
@@ -1089,31 +1127,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1089
  for layer_past in past
1090
  )
1091
 
1092
- def process_response(self, response):
1093
- response = response.strip()
1094
- response = response.replace("[[训练时间]]", "2023年")
1095
- punkts = [
1096
- [",", ","],
1097
- ["!", "!"],
1098
- [":", ":"],
1099
- [";", ";"],
1100
- ["\?", "?"],
1101
- ]
1102
- for item in punkts:
1103
- response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
1104
- response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
1105
- return response
1106
-
1107
  @torch.no_grad()
1108
  def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
1109
- do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
1110
  if history is None:
1111
  history = []
1112
- if logits_processor is None:
1113
- logits_processor = LogitsProcessorList()
1114
- logits_processor.append(InvalidScoreLogitsProcessor())
1115
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1116
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1117
  if not history:
1118
  prompt = query
1119
  else:
@@ -1124,139 +1144,64 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1124
  input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1125
  input_ids = input_ids.to(self.device)
1126
  outputs = self.generate(**input_ids, **gen_kwargs)
1127
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1128
  response = tokenizer.decode(outputs)
1129
- response = self.process_response(response)
 
1130
  history = history + [(query, response)]
1131
  return response, history
1132
 
1133
  @torch.no_grad()
1134
- def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
1135
- do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
1136
- if history is None:
1137
- history = []
1138
- if logits_processor is None:
1139
- logits_processor = LogitsProcessorList()
1140
- logits_processor.append(InvalidScoreLogitsProcessor())
1141
- gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1142
- "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1143
- if not history:
1144
- prompt = query
1145
- else:
1146
- prompt = ""
1147
- for i, (old_query, response) in enumerate(history):
1148
- prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1149
- prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1150
- input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1151
- input_ids = input_ids.to(self.device)
1152
- for outputs in self.stream_generate(**input_ids, **gen_kwargs):
1153
- outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
1154
- response = tokenizer.decode(outputs)
1155
- response = self.process_response(response)
1156
- new_history = history + [(query, response)]
1157
- yield response, new_history
1158
-
1159
- @torch.no_grad()
1160
- def stream_generate(
1161
  self,
1162
- input_ids,
1163
- generation_config: Optional[GenerationConfig] = None,
1164
- logits_processor: Optional[LogitsProcessorList] = None,
1165
- stopping_criteria: Optional[StoppingCriteriaList] = None,
1166
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1167
  **kwargs,
1168
  ):
1169
- batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1170
-
1171
- if generation_config is None:
1172
- generation_config = self.generation_config
1173
- generation_config = copy.deepcopy(generation_config)
1174
- model_kwargs = generation_config.update(**kwargs)
1175
- bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1176
-
1177
- if isinstance(eos_token_id, int):
1178
- eos_token_id = [eos_token_id]
1179
-
1180
- has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1181
- if has_default_max_length and generation_config.max_new_tokens is None:
1182
- warnings.warn(
1183
- f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1184
- "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1185
- " recommend using `max_new_tokens` to control the maximum length of the generation.",
1186
- UserWarning,
1187
- )
1188
- elif generation_config.max_new_tokens is not None:
1189
- generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1190
- if not has_default_max_length:
1191
- logger.warn(
1192
- f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1193
- f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1194
- "Please refer to the documentation for more information. "
1195
- "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1196
- UserWarning,
1197
- )
1198
-
1199
- if input_ids_seq_length >= generation_config.max_length:
1200
- input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1201
- logger.warning(
1202
- f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1203
- f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1204
- " increasing `max_new_tokens`."
1205
- )
1206
 
1207
- # 2. Set generation parameters if not already defined
1208
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1209
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1210
 
1211
- logits_processor = self._get_logits_processor(
1212
- generation_config=generation_config,
1213
- input_ids_seq_length=input_ids_seq_length,
1214
- encoder_input_ids=input_ids,
1215
- prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1216
- logits_processor=logits_processor,
1217
- )
1218
 
1219
- stopping_criteria = self._get_stopping_criteria(
1220
- generation_config=generation_config, stopping_criteria=stopping_criteria
1221
- )
1222
- logits_warper = self._get_logits_warper(generation_config)
1223
 
1224
- unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1225
- scores = None
1226
  while True:
1227
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1228
- # forward pass to get next token
1229
- outputs = self(
1230
- **model_inputs,
1231
- return_dict=True,
1232
- output_attentions=False,
1233
- output_hidden_states=False,
1234
- )
1235
-
1236
- next_token_logits = outputs.logits[:, -1, :]
1237
-
1238
- # pre-process distribution
1239
- next_token_scores = logits_processor(input_ids, next_token_logits)
1240
- next_token_scores = logits_warper(input_ids, next_token_scores)
 
 
1241
 
1242
- # sample
1243
- probs = nn.functional.softmax(next_token_scores, dim=-1)
1244
- if generation_config.do_sample:
1245
- next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1246
- else:
1247
- next_tokens = torch.argmax(probs, dim=-1)
1248
 
1249
- # update generated ids, model inputs, and length for next step
1250
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1251
- model_kwargs = self._update_model_kwargs_for_generation(
1252
- outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1253
- )
1254
- unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
1255
 
1256
- # stop when each sentence is finished, or if we exceed the maximum length
1257
- if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1258
- break
1259
- yield input_ids
1260
 
1261
  def quantize(self, bits: int):
1262
  from .quantization import quantize
 
3
  import math
4
  import copy
5
  import os
 
 
6
  import sys
7
 
8
  import torch
 
11
  from torch import nn
12
  from torch.nn import CrossEntropyLoss, LayerNorm
13
  from torch.nn.utils import skip_init
14
+ from typing import Optional, Tuple, Union, List
15
 
16
  from transformers.utils import (
17
  add_code_sample_docstrings,
 
24
  BaseModelOutputWithPastAndCrossAttentions,
25
  )
26
  from transformers.modeling_utils import PreTrainedModel
 
 
 
27
 
28
+ from transformers.utils import logging
29
  from .configuration_chatglm import ChatGLMConfig
30
 
 
 
31
  if sys.platform != 'darwin':
32
  torch._C._jit_set_profiling_mode(False)
33
  torch._C._jit_set_profiling_executor(False)
34
  torch._C._jit_override_can_fuse_on_cpu(True)
35
  torch._C._jit_override_can_fuse_on_gpu(True)
36
 
37
+
38
  logger = logging.get_logger(__name__)
39
 
40
  _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
 
46
  ]
47
 
48
 
 
49
  def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
50
  """Load tf checkpoints in a pytorch model."""
51
  try:
 
140
  if learnable:
141
  self.inv_freq = torch.nn.Parameter(inv_freq)
142
  self.max_seq_len_cached = None
143
+ else:
144
  self.register_buffer('inv_freq', inv_freq)
145
  self.max_seq_len_cached = None
146
  self.cos_cached = None
 
156
  seq_len = x.shape[seq_dim]
157
  if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
158
  self.max_seq_len_cached = None if self.learnable else seq_len
159
+
160
+ t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
161
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
162
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
163
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
164
+ if self.precision == torch.bfloat16:
165
+ emb = emb.float()
166
+
167
+ # [sx, 1 (b * np), hn]
168
+ cos_cached = emb.cos()[:, None, :]
169
+ sin_cached = emb.sin()[:, None, :]
170
+ if self.precision == torch.bfloat16:
171
+ cos_cached = cos_cached.bfloat16()
172
+ sin_cached = sin_cached.bfloat16()
173
+ if self.learnable:
174
+ return cos_cached, sin_cached
175
+ self.cos_cached, self.sin_cached = cos_cached, sin_cached
176
+
177
  return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
178
 
179
 
 
191
  return q, k
192
 
193
 
 
 
 
194
 
195
  class SelfAttention(torch.nn.Module):
196
  def __init__(self, hidden_size, num_attention_heads,
 
280
  """
281
 
282
  # [seq_len, batch, 3 * hidden_size]
283
+ mixed_raw_layer = self.query_key_value.to(device=hidden_states.device)(hidden_states)
284
 
285
  # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
286
  new_tensor_shape = mixed_raw_layer.size()[:-1] + (
 
295
  if self.position_encoding_2d:
296
  q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
297
  k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
298
+ position_ids = position_ids.to(q1.device)
299
  cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
300
  position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \
301
  position_ids[:, 1, :].transpose(0, 1).contiguous()
 
305
  key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
306
  else:
307
  position_ids = position_ids.transpose(0, 1)
308
+ position_ids = position_ids.to(value_layer.device)
309
  cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
310
  # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
311
  query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)
312
 
313
  # [seq_len, batch, hidden_size]
314
+ context_layer, present, attention_probs = self.attention_fn(
 
315
  query_layer=query_layer,
316
  key_layer=key_layer,
317
  value_layer=value_layer,
318
+ attention_mask=attention_mask.to(query_layer.device),
319
  hidden_size_per_partition=self.hidden_size_per_partition,
320
  layer_id=layer_id,
321
  layer_past=layer_past,
322
  use_cache=use_cache
323
  )
324
+ # print("*"*80)
325
+ # print(f"{context_layer.device = }")
326
+ # print(f"{self.dense.weight.device = }")
327
 
328
  output = self.dense(context_layer)
329
 
 
334
 
335
  return outputs # output, present, attention_probs
336
 
337
+ def attention_fn(
338
+ self,
339
+ query_layer,
340
+ key_layer,
341
+ value_layer,
342
+ attention_mask,
343
+ hidden_size_per_partition,
344
+ layer_id,
345
+ layer_past=None,
346
+ scaling_attention_score=True,
347
+ use_cache=False,
348
+ ):
349
+ if layer_past is not None:
350
+ past_key, past_value = layer_past
351
+ key_layer = torch.cat((past_key, key_layer), dim=0)
352
+ value_layer = torch.cat((past_value, value_layer), dim=0)
353
+
354
+ # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
355
+ seq_len, b, nh, hidden_size = key_layer.shape
356
+
357
+ if use_cache:
358
+ present = (key_layer, value_layer)
359
+ else:
360
+ present = None
361
+
362
+ query_key_layer_scaling_coeff = float(layer_id + 1)
363
+ if scaling_attention_score:
364
+ query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)
365
+
366
+ # ===================================
367
+ # Raw attention scores. [b, np, s, s]
368
+ # ===================================
369
+
370
+ # [b, np, sq, sk]
371
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
372
+
373
+ # [sq, b, np, hn] -> [sq, b * np, hn]
374
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
375
+ # [sk, b, np, hn] -> [sk, b * np, hn]
376
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
377
+
378
+ matmul_result = torch.empty(
379
+ output_size[0] * output_size[1],
380
+ output_size[2],
381
+ output_size[3],
382
+ dtype=query_layer.dtype,
383
+ device=query_layer.device,
384
+ )
385
+
386
+ matmul_result = torch.baddbmm(
387
+ matmul_result,
388
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
389
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
390
+ beta=0.0,
391
+ alpha=1.0,
392
+ )
393
+
394
+ # change view to [b, np, sq, sk]
395
+ attention_scores = matmul_result.view(*output_size)
396
+
397
+ if self.scale_mask_softmax:
398
+ self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
399
+ attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
400
+ else:
401
+ # print("*"*80)
402
+ # print(f"{attention_mask.device = }")
403
+ # print(f"{attention_scores.device = }")
404
+ if not (attention_mask == 0).all():
405
+ # if auto-regressive, skip
406
+ attention_scores.masked_fill_(attention_mask, -10000.0)
407
+ dtype = attention_scores.type()
408
+ attention_scores = attention_scores.float()
409
+ attention_scores = attention_scores * query_key_layer_scaling_coeff
410
+
411
+ attention_probs = F.softmax(attention_scores, dim=-1)
412
+
413
+ attention_probs = attention_probs.type(dtype)
414
+
415
+ # =========================
416
+ # Context layer. [sq, b, hp]
417
+ # =========================
418
+
419
+ # value_layer -> context layer.
420
+ # [sk, b, np, hn] --> [b, np, sq, hn]
421
+
422
+ # context layer shape: [b, np, sq, hn]
423
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
424
+
425
+ # change view [sk, b * np, hn]
426
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
427
+
428
+ # change view [b * np, sq, sk]
429
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
430
+
431
+ # matmul: [b * np, sq, hn]
432
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
433
+
434
+ # change view [b, np, sq, hn]
435
+ context_layer = context_layer.view(*output_size)
436
+
437
+ # [b, np, sq, hn] --> [sq, b, np, hn]
438
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
439
+
440
+ # [sq, b, np, hn] --> [sq, b, hp]
441
+ new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
442
+ context_layer = context_layer.view(*new_context_layer_shape)
443
+
444
+ outputs = (context_layer, present, attention_probs)
445
+
446
+ return outputs
447
+
448
+
449
 
450
  class GEGLU(torch.nn.Module):
451
  def __init__(self):
 
611
  a simple interface for downloading and loading pretrained models.
612
  """
613
 
614
+ is_parallelizable = True
615
+ model_parallel = False
616
  supports_gradient_checkpointing = False
617
  config_class = ChatGLMConfig
618
  base_model_prefix = "transformer"
 
722
  self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
723
  self.position_encoding_2d = config.position_encoding_2d
724
 
725
+ self.gradient_checkpointing = True # enabled by default to save GPU memory
726
+
727
  self.word_embeddings = skip_init(
728
  torch.nn.Embedding,
729
  num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
 
757
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
758
  self.word_embeddings = new_embeddings
759
 
760
+ @staticmethod
761
+ def get_masks(seq, device):
762
+ context_length = seq.index(150004) + 1
763
 
764
  attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
765
  attention_mask.tril_()
 
770
  return attention_mask
771
 
772
  def get_position_ids(self, seq, mask_position, device, gmask=False):
773
+ context_length = seq.index(150004) + 1
774
  if self.position_encoding_2d:
775
+ seq_length = seq.index(150004)
776
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
777
  if not gmask:
778
  position_ids[seq_length:] = mask_position
 
827
 
828
  if past_key_values is None:
829
  past_key_values = tuple([None] * len(self.layers))
830
+
831
+ MASK, gMASK = 150000, 150001
832
+ mask_token = MASK if MASK in input_ids else gMASK
833
+ use_gmask = False if MASK in input_ids else gMASK
834
  seq = input_ids[0].tolist()
835
 
836
+ mask_position = seq.index(mask_token)
837
+
838
  if attention_mask is None:
839
  attention_mask = self.get_masks(
840
  seq=seq,
 
842
  )
843
 
844
  if position_ids is None:
 
 
 
 
 
845
  position_ids = self.get_position_ids(
846
  seq=seq,
847
  mask_position=mask_position,
 
850
  )
851
 
852
  if inputs_embeds is None:
853
+ # print("*"*80)
854
+ # print(f"{input_ids.device = }")
855
+ # print(f"{self.word_embeddings.weight.device = }")
856
+ inputs_embeds = self.word_embeddings(input_ids.to(self.word_embeddings.weight.device))
857
 
858
  # [seq_len, batch, hidden_size]
859
  hidden_states = inputs_embeds.transpose(0, 1)
860
 
861
+
862
+ if self.gradient_checkpointing and self.training:
863
+ if use_cache:
864
+ logger.warning_once(
865
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
866
+ )
867
+ use_cache = False
868
+
869
  presents = () if use_cache else None
870
  all_self_attentions = () if output_attentions else None
871
  all_hidden_states = () if output_hidden_states else None
872
 
873
+
874
+
875
  seq_length_with_past = seq_length
876
  past_key_values_length = 0
877
  if past_key_values[0] is not None:
 
888
  if output_hidden_states:
889
  all_hidden_states = all_hidden_states + (hidden_states,)
890
 
891
+ if self.gradient_checkpointing and self.training:
892
+ # https://mathpretty.com/11156.html
893
+ use_cache = False
894
+ def create_custom_forward(module):
895
+ def custom_forward(*inputs):
896
+ # None for past_key_value
897
+ return module(*inputs, use_cache, output_attentions)
898
+
899
+ return custom_forward
900
+
901
+ layer_ret = torch.utils.checkpoint.checkpoint(
902
+ create_custom_forward(layer),
903
+ # create_custom_forward(layer),
904
+ hidden_states,
905
+ position_ids,
906
+ attention_mask,
907
+ torch.ones(1, dtype=torch.float32, requires_grad=True) * i,
908
+ # torch.tensor(i, requires_grad=True),
909
+ past_key_values[i],
910
+
911
+ )
912
+
913
+ else:
914
+
915
+ layer_ret = layer(
916
+ hidden_states,
917
+ position_ids=position_ids,
918
+ attention_mask=attention_mask,
919
+ layer_id=torch.tensor(i),
920
+ layer_past=past_key_values[i],
921
+ use_cache=use_cache,
922
+ output_attentions=output_attentions
923
+ )
924
 
925
  hidden_states = layer_ret[0]
926
 
 
967
  bias=False,
968
  dtype=torch.half
969
  )
970
+ self.model_parallel = False
971
 
972
  def get_output_embeddings(self):
973
  return self.lm_head
 
983
  attention_mask = (attention_mask < 0.5).bool()
984
 
985
  if self.position_encoding_2d:
986
+ seq_length = seq.index(150004)
987
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
988
  if not gmask:
989
  position_ids[seq_length:] = mask_position
 
1021
 
1022
  # only last token for input_ids if past is not None
1023
  if past is not None or past_key_values is not None:
1024
+ context_length = seq.index(150004)
1025
  last_token = input_ids[:, -1].unsqueeze(-1)
1026
  if self.position_encoding_2d:
1027
  position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
 
1093
  shift_labels = labels[..., 1:].contiguous()
1094
  # Flatten the tokens
1095
  loss_fct = CrossEntropyLoss()
1096
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)).to(shift_labels.device), shift_labels.view(-1))
1097
 
 
 
1098
 
1099
  if not return_dict:
1100
  output = (lm_logits,) + transformer_outputs[1:]
 
1127
  for layer_past in past
1128
  )
1129
 
 
 
1130
  @torch.no_grad()
1131
  def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
1132
+ do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
1133
  if history is None:
1134
  history = []
 
 
 
1135
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1136
+ "temperature": temperature, **kwargs}
1137
  if not history:
1138
  prompt = query
1139
  else:
 
1144
  input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
1145
  input_ids = input_ids.to(self.device)
1146
  outputs = self.generate(**input_ids, **gen_kwargs)
1147
+ outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]) - 2:]
1148
  response = tokenizer.decode(outputs)
1149
+ response = response.strip()
1150
+ response = response.replace("[[训练时间]]", "2023年")
1151
  history = history + [(query, response)]
1152
  return response, history
1153
 
1154
  @torch.no_grad()
1155
+ def generate(
 
 
 
1156
  self,
 
 
 
 
 
1157
  **kwargs,
1158
  ):
1159
+ MASK, gMASK = 150000, 150001
1160
+ bos, eos = 150004, 150005
 
 
 
1161
 
1162
+ if "eos_token_id" not in kwargs:
1163
+ kwargs["eos_token_id"] = eos
 
1164
 
1165
+ stop = False
 
 
 
 
 
 
1166
 
1167
+ return_seqs = []
 
 
 
1168
 
 
 
1169
  while True:
1170
+ print(kwargs)
1171
+ output_ids = super().generate(**kwargs)
1172
+
1173
+ return_seqs = []
1174
+ max_length = 0
1175
+
1176
+ for i in range(output_ids.shape[0]):
1177
+ output_seq = output_ids[i].tolist()
1178
+ mask_token = MASK if MASK in output_seq else gMASK
1179
+ mask_position = output_seq.index(mask_token)
1180
+ bos_position = output_seq.index(bos)
1181
+ if eos in output_seq:
1182
+ eos_position = output_seq.index(eos)
1183
+ else:
1184
+ eos_position = len(output_seq)
1185
+
1186
+ return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
1187
+ mask_position + 1:bos_position]
1188
+ max_length = max(max_length, len(return_seq))
1189
+ return_seqs.append(return_seq)
1190
+
1191
+ for i in range(output_ids.shape[0]):
1192
+ return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i] # padding
1193
+ if mask_token not in return_seqs[i]:
1194
+ stop = True
1195
+
1196
+ if stop:
1197
+ break
1198
 
1199
+ for return_seq in return_seqs:
1200
+ return_seq += [bos]
 
 
 
 
1201
 
1202
+ kwargs['input_ids'] = torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
 
 
 
 
 
1203
 
1204
+ return torch.tensor(return_seqs, dtype=torch.long, device=kwargs['input_ids'].device)
 
 
 
1205
 
1206
  def quantize(self, bits: int):
1207
  from .quantization import quantize
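
One behavioural note on the diff above: `ChatGLMModel.__init__` now sets `self.gradient_checkpointing = True`, so fine-tuning recomputes activations in the checkpointed branch by default, saving memory at the cost of extra forward passes. A hypothetical opt-out, assuming the usual `transformer` attribute on `ChatGLMForConditionalGeneration`:

from transformers import AutoModel

# Sketch only (not part of this diff): prefer speed over memory during training by
# turning the default-on gradient checkpointing back off after loading.
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model.transformer.gradient_checkpointing = False
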