davidlvxin committed
Commit: a93f22e
Parent: acb9849

Update modeling_chatglm.py

Files changed (1):
  modeling_chatglm.py (+23, -10)
modeling_chatglm.py CHANGED
@@ -416,7 +416,10 @@ class SelfAttention(torch.nn.Module):
             key_layer = torch.cat((cache_k, key_layer), dim=0)
             value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            kv_cache = (key_layer, value_layer)
+            if kv_cache is None:
+                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
+            else:
+                kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
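Note on the hunk above: during prefill (no incoming kv_cache), the per-layer cache is now packed into a single tensor instead of a (key, value) tuple, so that the layers' caches can later be concatenated along a layer axis. A minimal shape sketch, assuming key/value layers of shape [seq_len, batch, num_heads, head_dim] as used elsewhere in this file; the concrete sizes are illustrative only:

import torch

seq_len, batch, num_heads, head_dim = 8, 1, 2, 4  # illustrative sizes only
key_layer = torch.randn(seq_len, batch, num_heads, head_dim)
value_layer = torch.randn(seq_len, batch, num_heads, head_dim)

# Prefill path: pack key and value into one tensor of shape
# [1, 2, seq_len, batch, num_heads, head_dim]; dim 0 becomes the layer
# axis once the per-layer caches are concatenated in GLMTransformer.
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0),
                      value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
print(kv_cache.shape)  # torch.Size([1, 2, 8, 1, 2, 4])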
@@ -627,12 +630,8 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
+        if self.training:
+            use_cache = False
 
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
@@ -660,7 +659,15 @@
             )
             hidden_states, kv_cache = layer_ret
             if use_cache:
-                presents = presents + (kv_cache,)
+                # token by token decoding, use tuple format
+                if kv_caches[0] is not None:
+                    presents = presents + (kv_cache,)
+                # prefilling in decoding, use tensor format to save cuda memory
+                else:
+                    if len(presents) == 0:
+                        presents = kv_cache
+                    else:
+                        presents = torch.cat((presents, kv_cache), dim=0)
 
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
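Note on the accumulation logic above: during token-by-token decoding (kv_caches[0] is not None) the per-layer caches still accumulate as a tuple, while during prefill the packed per-layer tensors are concatenated along dim 0, avoiding a long tuple of small per-layer tensors. A rough sketch of the prefill path, with the same illustrative shapes as before:

import torch

num_layers = 3
# Per-layer packed caches as produced by SelfAttention during prefill.
layer_caches = [torch.randn(1, 2, 8, 1, 2, 4) for _ in range(num_layers)]

presents = ()
for kv_cache in layer_caches:
    if len(presents) == 0:          # first layer
        presents = kv_cache
    else:                           # later layers: concatenate along the layer axis
        presents = torch.cat((presents, kv_cache), dim=0)
print(presents.shape)  # torch.Size([3, 2, 8, 1, 2, 4]) -> [layers, k/v, seq, batch, heads, head_dim]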
@@ -845,6 +852,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
         )
+        if presents is not None and type(presents) is torch.Tensor:
+            presents = presents.split(1, dim=0)
+            presents = list(presents)
+            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
+            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
+            presents = tuple(presents)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
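Note on the new block in ChatGLMModel.forward above: when the encoder returns the packed tensor (prefill case), it is split back into the usual per-layer tuple of (key, value) tensors before being returned to the generation loop. A standalone sketch of the same unpacking, with illustrative shapes:

import torch

num_layers, seq_len, batch, num_heads, head_dim = 3, 8, 1, 2, 4
presents = torch.randn(num_layers, 2, seq_len, batch, num_heads, head_dim)

# Split the layer axis, then the key/value axis, dropping the singleton dims.
per_layer = [x.squeeze(0) for x in presents.split(1, dim=0)]
past_key_values = tuple(
    tuple(t.squeeze(0) for t in layer.split(1, dim=0)) for layer in per_layer
)
print(len(past_key_values), past_key_values[0][0].shape)  # 3 torch.Size([8, 1, 2, 4])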
@@ -1036,7 +1049,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
     @torch.inference_mode()
     def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-             max_length: int = 131072, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
+             max_length: int = 131072, num_beams=1, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None,
              **kwargs):
         if history is None:
             history = []
@@ -1058,7 +1071,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
     @torch.inference_mode()
     def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
-                    past_key_values=None,max_length: int = 131072, do_sample=True, top_p=0.8, temperature=0.8,
+                    past_key_values=None,max_length: int = 131072, do_sample=True, top_p=0.7, temperature=0.95,
                     logits_processor=None, return_past_key_values=False, **kwargs):
         if history is None:
             history = []
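The last two hunks only change the default sampling parameters of chat and stream_chat from top_p=0.8, temperature=0.8 to top_p=0.7, temperature=0.95; explicit arguments still override them. A hypothetical usage sketch (the checkpoint name is a placeholder, not part of this commit):

from transformers import AutoModel, AutoTokenizer

# Placeholder checkpoint name; substitute the repository this file belongs to.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b-128k", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm3-6b-128k", trust_remote_code=True).half().cuda().eval()

# Uses the new defaults (top_p=0.7, temperature=0.95).
response, history = model.chat(tokenizer, "Hello", history=[])

# Explicit values take precedence over the defaults.
response, history = model.chat(tokenizer, "Hello again", history=history, top_p=0.8, temperature=0.8)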
 