Qubitium committed
Commit 6c22367
1 Parent(s): 4d86323

sync with main

Files changed (1): modeling_chatglm.py (+24 -16)
modeling_chatglm.py CHANGED
@@ -40,12 +40,6 @@ logger = logging.get_logger(__name__)
 _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
 _CONFIG_FOR_DOC = "ChatGLMConfig"
 
-CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "THUDM/chatglm3-6b",
-    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
-]
-
-
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
 
@@ -253,15 +247,12 @@ class CoreAttention(torch.nn.Module):
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.attention_dropout(attention_probs)
-        # =========================
-        # Context layer. [sq, b, hp]
-        # =========================
-
-        # value_layer -> context layer.
-        # [sk, b, np, hn] --> [b, np, sq, hn]
 
+        # query layer shape: [b * np, sq, hn]
+        # value layer shape: [b, np, sk, hn]
+        # attention shape: [b, np, sq, sk]
         # context layer shape: [b, np, sq, hn]
-        output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
+        output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))
         # change view [b * np, sk, hn]
         value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
         # change view [b * np, sq, sk]
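Note: per the added comments, the new output_size follows the batch-first layouts (value_layer as [b, np, sk, hn], query_layer as [b * np, sq, hn]) rather than the old sequence-first [sk, b, np, hn]. A minimal sketch of the shape bookkeeping, with made-up dimensions and variable names rather than anything read from the model config:

    import torch

    # Toy dimensions, purely illustrative: b=batch, np=heads, sq/sk=query/key length, hn=head dim.
    b, np_, sq, sk, hn = 2, 4, 5, 7, 16

    attention_probs = torch.rand(b * np_, sq, sk)   # [b * np, sq, sk]
    value_layer = torch.rand(b, np_, sk, hn)        # [b, np, sk, hn] (batch-first layout)
    query_layer = torch.rand(b * np_, sq, hn)       # [b * np, sq, hn]

    # Mirrors the new output_size: (b, np, sq, hn) read off value_layer and query_layer.
    output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3))

    # Collapse batch and heads so a single bmm produces the context layer.
    value_view = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)  # [b*np, sk, hn]
    context_layer = torch.bmm(attention_probs, value_view)                                   # [b*np, sq, hn]
    context_layer = context_layer.view(*output_size)                                         # [b, np, sq, hn]
    print(context_layer.shape)  # torch.Size([2, 4, 5, 16])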
@@ -386,7 +377,10 @@ class SelfAttention(torch.nn.Module):
             key_layer = torch.cat((cache_k, key_layer), dim=2)
             value_layer = torch.cat((cache_v, value_layer), dim=2)
         if use_cache:
-            kv_cache = (key_layer, value_layer)
+            if kv_cache is None:
+                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
+            else:
+                kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
 
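Note: when there is no existing cache (the prefill step), the layer now returns its key/value pair packed into a single tensor with two leading axes, presumably [1, 2, b, np, sk, hn]; during token-by-token decoding it keeps the usual tuple. A small sketch of that packing and how it unpacks, using toy tensors whose names and sizes are not taken from the repo:

    import torch

    # Toy key/value tensors in the batch-first layout [b, np, sk, hn]; sizes are made up.
    b, np_, sk, hn = 2, 4, 7, 16
    key_layer = torch.rand(b, np_, sk, hn)
    value_layer = torch.rand(b, np_, sk, hn)

    # Prefill path: pack key and value into one tensor with two leading axes,
    # [1, 2, b, np, sk, hn], so the per-layer caches can later be stacked along dim 0.
    packed = torch.cat((key_layer.unsqueeze(0).unsqueeze(0),
                        value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
    print(packed.shape)  # torch.Size([1, 2, 2, 4, 7, 16])

    # Unpacking recovers the familiar (key, value) pair used on the decode path.
    k, v = packed[0, 0], packed[0, 1]
    assert torch.equal(k, key_layer) and torch.equal(v, value_layer)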
 
@@ -605,7 +599,7 @@ class GLMTransformer(torch.nn.Module):
                     hidden_states,
                     attention_mask=attention_mask,
                     rotary_pos_emb=rotary_pos_emb,
-                    kv_cache=kv_caches[index],
+                    kv_caches=kv_caches[index],
                     use_cache=use_cache,
                     use_reentrant=False
                 )
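Note: this hunk only renames the keyword passed through torch.utils.checkpoint.checkpoint. Keyword arguments reach the wrapped module when use_reentrant=False, which the call already sets; a minimal, self-contained illustration with a hypothetical stand-in block (ToyBlock is not the repo's GLMBlock):

    import torch
    from torch.utils.checkpoint import checkpoint

    # Hypothetical stand-in for a transformer block; not the repo's GLMBlock.
    class ToyBlock(torch.nn.Module):
        def forward(self, hidden_states, scale=1.0):
            return hidden_states * scale

    block = ToyBlock()
    x = torch.rand(2, 3, requires_grad=True)

    # With use_reentrant=False, checkpoint forwards keyword arguments (like the
    # kv cache kwarg in this hunk) straight to the wrapped module.
    out = checkpoint(block, x, scale=2.0, use_reentrant=False)
    out.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 3])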
@@ -619,7 +613,15 @@ class GLMTransformer(torch.nn.Module):
                 )
             hidden_states, kv_cache = layer_ret
             if use_cache:
-                presents = presents + (kv_cache,)
+                # token by token decoding, use tuple format
+                if kv_caches[0] is not None:
+                    presents = presents + (kv_cache,)
+                # prefilling in decoding, use tensor format to save cuda memory
+                else:
+                    if len(presents) == 0:
+                        presents = kv_cache
+                    else:
+                        presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)
 
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
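Note: per the added comments, prefill concatenates the per-layer packed caches along dim 0 (presumably ending up as [num_layers, 2, b, np, sk, hn]) instead of collecting a Python tuple, while token-by-token decoding keeps the tuple format. A sketch of the prefill branch with invented shapes:

    import torch

    # Prefill-style accumulation: each layer contributes one packed [1, 2, b, np, sk, hn]
    # cache, concatenated along dim 0 instead of appended to a tuple. All sizes invented.
    num_layers, b, np_, sk, hn = 3, 2, 4, 7, 16

    presents = ()
    for _ in range(num_layers):
        kv_cache = torch.rand(1, 2, b, np_, sk, hn)   # stand-in for one layer's packed cache
        if len(presents) == 0:
            presents = kv_cache
        else:
            presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)

    print(presents.shape)  # torch.Size([3, 2, 2, 4, 7, 16]) -- one tensor for all layers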
@@ -773,6 +775,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
         )
+        if presents is not None and type(presents) is torch.Tensor:
+            presents = presents.split(1, dim=0)
+            presents = list(presents)
+            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
+            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
+            presents = tuple(presents)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
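Note: the added block converts the stacked cache back into the conventional tuple-of-(key, value) past_key_values layout before the model returns. A compact round-trip sketch with illustrative dimensions (not the model's real sizes):

    import torch

    # Round trip: stacked cache [num_layers, 2, b, np, sk, hn] back to the usual
    # tuple-of-(key, value) past_key_values layout. Dimensions are illustrative.
    num_layers, b, np_, sk, hn = 3, 2, 4, 7, 16
    presents = torch.rand(num_layers, 2, b, np_, sk, hn)

    if presents is not None and type(presents) is torch.Tensor:
        layers = [x.squeeze(0) for x in presents.split(1, dim=0)]      # num_layers tensors of [2, b, np, sk, hn]
        presents = tuple(
            tuple(t.squeeze(0) for t in layer.split(1, dim=0))         # (key, value) per layer, each [b, np, sk, hn]
            for layer in layers
        )

    print(len(presents), presents[0][0].shape)  # 3 torch.Size([2, 4, 7, 16])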
 