Files changed (3)
  1. README.md +0 -6
  2. modeling_chatglm.py +8 -21
  3. tokenization_chatglm.py +1 -1
README.md CHANGED
@@ -16,13 +16,7 @@ tags:
 👋 Join our <a href="https://join.slack.com/t/chatglm/shared_invite/zt-1y7pqoloy-9b1g6T6JjA8J0KxvUjbwJw" target="_blank">Slack</a> and <a href="https://github.com/THUDM/ChatGLM-6B/blob/main/resources/WECHAT.md" target="_blank">WeChat</a>
 </p>
 
-## 更新/Update
-
-- 我们优化了KV Cache的存储方式,减少了显存碎片的产生。基于优化后的代码,模型可以在约**20G显存**的情况下处理32K长度的上下文(FP/BF16格式)。
-- We have optimized the storage method of the KV Cache, reducing the generation of memory fragmentation. Based on the optimized code, the model can process a context length of 32K under approximately **20G** of memory (FP/BF16 format).
-
 ## 介绍
-
 ChatGLM**2**-6B-32K在[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b)的基础上进一步强化了对于长文本的理解能力,能够更好的处理最多32K长度的上下文。具体地,我们基于[位置插值](https://arxiv.org/abs/2306.15595)(Positional Interpolation)的方法对位置编码进行了更新,并在对话阶段使用 32K 的上下文长度训练。在实际的使用中,如果您面临的上下文长度基本在 **8K 以内**,我们推荐使用[ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b);如果您需要处理**超过 8K** 的上下文长度,我们推荐使用ChatGLM2-6B-32K。
 
 ChatGLM**2**-6B-32K是开源中英双语对话模型 [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) 的加长版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B-32k 引入了如下新特性:
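In short, the README text kept by this hunk says that ChatGLM2-6B-32K strengthens ChatGLM2-6B's long-text understanding: position encodings are updated with Positional Interpolation and the dialogue stage is trained at a 32K context length, so ChatGLM2-6B is recommended for contexts within roughly 8K and the 32K variant beyond that. A minimal long-context usage sketch, assuming the usual ChatGLM2 `chat()` helper exposed by the remote code and the published `THUDM/chatglm2-6b-32k` checkpoint; dtype and device choices are illustrative only:

```python
from transformers import AutoTokenizer, AutoModel

# Illustrative only: half precision on a single GPU; adjust to your hardware.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True).half().cuda()
model = model.eval()

long_document = "..."  # up to ~32K tokens of context
response, history = model.chat(tokenizer, "Summarize the following document:\n" + long_document, history=[])
print(response)
```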
modeling_chatglm.py CHANGED
@@ -413,10 +413,7 @@ class SelfAttention(torch.nn.Module):
             key_layer = torch.cat((cache_k, key_layer), dim=0)
             value_layer = torch.cat((cache_v, value_layer), dim=0)
         if use_cache:
-            if kv_cache is None:
-                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
-            else:
-                kv_cache = (key_layer, value_layer)
+            kv_cache = (key_layer, value_layer)
         else:
             kv_cache = None
 
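The hunk above replaces the prefill-only stacked-tensor cache with a single per-layer `(key_layer, value_layer)` tuple. A minimal sketch of how such a per-layer cache grows, assuming the `[seq_len, batch, heads, head_dim]` layout implied by the `dim=0` concatenation; all sizes are made up:

```python
import torch

# Sketch of the per-layer tuple cache with made-up shapes; the cache grows
# along dim=0 (the sequence dimension), matching the torch.cat calls above.
seq, batch, heads, head_dim = 8, 1, 2, 4
kv_cache = None

for step in range(3):
    # New K/V for the current token(s): prefill produces `seq` positions,
    # each later decoding step produces one.
    new_len = seq if kv_cache is None else 1
    key_layer = torch.randn(new_len, batch, heads, head_dim)
    value_layer = torch.randn(new_len, batch, heads, head_dim)

    if kv_cache is not None:
        cache_k, cache_v = kv_cache
        key_layer = torch.cat((cache_k, key_layer), dim=0)
        value_layer = torch.cat((cache_v, value_layer), dim=0)

    kv_cache = (key_layer, value_layer)  # tuple format, one pair per layer

print(kv_cache[0].shape)  # torch.Size([10, 1, 2, 4]) after prefill + 2 steps
```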
 
@@ -615,8 +612,12 @@ class GLMTransformer(torch.nn.Module):
         if not kv_caches:
             kv_caches = [None for _ in range(self.num_layers)]
         presents = () if use_cache else None
-        if self.training:
-            use_cache = False
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
 
         all_self_attentions = None
         all_hidden_states = () if output_hidden_states else None
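This hunk restores the stock guard that turns off KV caching only while gradient checkpointing is active during training: checkpointed blocks are re-executed in the backward pass, so any cache they returned could not be reused. A sketch of the guard in isolation, assuming the module-level `logger` that modeling_chatglm.py already creates from transformers' logging utilities (which provide `warning_once`):

```python
from transformers.utils import logging

logger = logging.get_logger(__name__)

def resolve_use_cache(use_cache: bool, gradient_checkpointing: bool, training: bool) -> bool:
    """Caching keeps K/V from the forward pass, but checkpointed blocks are
    recomputed during backward, so the two features cannot be combined."""
    if gradient_checkpointing and training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )
        return False
    return use_cache

print(resolve_use_cache(True, gradient_checkpointing=True, training=True))  # False, with a one-time warning
```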
@@ -644,15 +645,7 @@ class GLMTransformer(torch.nn.Module):
                 )
             hidden_states, kv_cache = layer_ret
             if use_cache:
-                # token by token decoding, use tuple format
-                if kv_caches[0] is not None:
-                    presents = presents + (kv_cache,)
-                # prefilling in decoding, use tensor format to save cuda memory
-                else:
-                    if len(presents) == 0:
-                        presents = kv_cache
-                    else:
-                        presents = torch.cat((presents, kv_cache), dim=0)
+                presents = presents + (kv_cache,)
 
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
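With the tensor-format branch removed, `presents` is always a tuple holding one `(key, value)` pair per layer, i.e. the usual `past_key_values` structure a generation loop feeds back in. A sketch with made-up layer count and shapes:

```python
import torch

# Hypothetical dimensions purely for illustration.
num_layers, seq, batch, heads, head_dim = 3, 5, 1, 2, 4

# Prefill: every layer appends its (key, value) pair, so `presents` ends up
# as a tuple of num_layers tuples.
presents = ()
for _ in range(num_layers):
    kv_cache = (torch.randn(seq, batch, heads, head_dim),
                torch.randn(seq, batch, heads, head_dim))
    presents = presents + (kv_cache,)

# Next decoding step: each layer receives its own pair back and extends it
# along the sequence dimension, as in the SelfAttention hunk above.
past_key_values = presents
next_presents = ()
for cache_k, cache_v in past_key_values:
    key_layer = torch.cat((cache_k, torch.randn(1, batch, heads, head_dim)), dim=0)
    value_layer = torch.cat((cache_v, torch.randn(1, batch, heads, head_dim)), dim=0)
    next_presents = next_presents + ((key_layer, value_layer),)

print(len(next_presents), next_presents[0][0].shape)  # 3 torch.Size([6, 1, 2, 4])
```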
@@ -837,12 +830,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
             kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
         )
-        if presents is not None and type(presents) is torch.Tensor:
-            presents = presents.split(1, dim=0)
-            presents = list(presents)
-            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
-            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
-            presents = tuple(presents)
 
         if not return_dict:
             return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
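The last hunk drops the unpacking step in `ChatGLMModel.forward` that converted the stacked prefill cache back into per-layer tuples; with the tuple format used throughout, there is nothing left to convert. For reference, a sketch of what that removed split/squeeze chain did, using made-up sizes and the `[num_layers, 2, seq, batch, heads, head_dim]` layout implied by the removed concatenation code:

```python
import torch

# Made-up sizes; keys and values are stacked per layer in one tensor.
num_layers, seq, batch, heads, head_dim = 3, 5, 1, 2, 4
stacked = torch.randn(num_layers, 2, seq, batch, heads, head_dim)

# Equivalent of the removed split/squeeze chain: back to a tuple of
# per-layer (key, value) pairs.
presents = stacked.split(1, dim=0)                                  # num_layers chunks
presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]   # split keys from values
presents = tuple(tuple(x.squeeze(0) for x in y) for y in presents)

print(len(presents), presents[0][0].shape)  # 3 torch.Size([5, 1, 2, 4])
```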
 
tokenization_chatglm.py CHANGED
@@ -66,6 +66,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
     model_input_names = ["input_ids", "attention_mask", "position_ids"]
 
     def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
+        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
         self.name = "GLMTokenizer"
 
         self.vocab_file = vocab_file
@@ -75,7 +76,6 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             "<eos>": self.tokenizer.eos_id,
             "<pad>": self.tokenizer.pad_id
         }
-        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
 
     def get_command(self, token):
         if token in self.special_tokens:
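The tokenizer change moves `super().__init__` to the top of `__init__`, so `padding_side` and `clean_up_tokenization_spaces` are handled by the Hugging Face base class before the GLM-specific attributes are set. A hedged check of the visible effect, relying only on generic `PreTrainedTokenizer` behaviour and the published checkpoint id:

```python
from transformers import AutoTokenizer

# The values passed to super().__init__ become attributes on the loaded tokenizer.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b-32k", trust_remote_code=True)
print(tokenizer.padding_side)                  # "left", the __init__ default above
print(tokenizer.clean_up_tokenization_spaces)  # False
print(tokenizer("Hello, ChatGLM2-6B-32K")["input_ids"][:4])
```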
 