zxdu20 commited on
Commit
aea6cef
1 Parent(s): 0564795

Implement gradient checkpointing

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +41 -20
modeling_chatglm.py CHANGED
@@ -244,7 +244,7 @@ def attention_fn(
244
  use_cache=False,
245
  ):
246
  if layer_past is not None:
247
- past_key, past_value = layer_past
248
  key_layer = torch.cat((past_key, key_layer), dim=0)
249
  value_layer = torch.cat((past_value, value_layer), dim=0)
250
 
@@ -644,7 +644,7 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
644
  """
645
 
646
  is_parallelizable = False
647
- supports_gradient_checkpointing = False
648
  config_class = ChatGLMConfig
649
  base_model_prefix = "transformer"
650
  _no_split_modules = ["GLM6BBlock"]
@@ -656,6 +656,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
656
  """Initialize the weights."""
657
  return
658
 
 
 
 
 
659
 
660
  CHATGLM_6B_START_DOCSTRING = r"""
661
  This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
@@ -760,6 +764,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
760
  num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
761
  dtype=self.params_dtype
762
  )
 
763
 
764
  def get_layer(layer_id):
765
  return GLMBlock(
@@ -812,9 +817,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
812
  #seq_len, b, nh, hidden_size
813
  past_key_values = self.dropout(past_key_values)
814
  past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
815
- past_key_values = [(v[0], v[1]) for v in past_key_values]
816
- # past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(self.num_layers)
817
- # past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])]
818
  return past_key_values
819
 
820
  def get_masks(self, input_ids, device):
@@ -877,6 +880,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
877
  use_cache = use_cache if use_cache is not None else self.config.use_cache
878
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
879
 
 
 
 
 
 
 
 
880
  if input_ids is not None and inputs_embeds is not None:
881
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
882
  elif input_ids is not None:
@@ -926,31 +936,42 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
926
  all_self_attentions = () if output_attentions else None
927
  all_hidden_states = () if output_hidden_states else None
928
 
929
- seq_length_with_past = seq_length
930
- past_key_values_length = 0
931
- if past_key_values[0] is not None:
932
- past_key_values_length = past_key_values[0][0].shape[0]
933
- seq_length_with_past = seq_length_with_past + past_key_values_length
934
  if attention_mask is None:
935
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
936
 
937
  else:
938
  attention_mask = attention_mask.to(input_ids.device)
939
 
 
 
 
940
  for i, layer in enumerate(self.layers):
941
 
942
  if output_hidden_states:
943
  all_hidden_states = all_hidden_states + (hidden_states,)
944
-
945
- layer_ret = layer(
946
- hidden_states,
947
- position_ids=position_ids,
948
- attention_mask=attention_mask,
949
- layer_id=torch.tensor(i),
950
- layer_past=past_key_values[i],
951
- use_cache=use_cache,
952
- output_attentions=output_attentions
953
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
954
 
955
  hidden_states = layer_ret[0]
956
 
 
244
  use_cache=False,
245
  ):
246
  if layer_past is not None:
247
+ past_key, past_value = layer_past[0], layer_past[1]
248
  key_layer = torch.cat((past_key, key_layer), dim=0)
249
  value_layer = torch.cat((past_value, value_layer), dim=0)
250
 
 
644
  """
645
 
646
  is_parallelizable = False
647
+ supports_gradient_checkpointing = True
648
  config_class = ChatGLMConfig
649
  base_model_prefix = "transformer"
650
  _no_split_modules = ["GLM6BBlock"]
 
656
  """Initialize the weights."""
657
  return
658
 
659
+ def _set_gradient_checkpointing(self, module, value=False):
660
+ if isinstance(module, ChatGLMModel):
661
+ module.gradient_checkpointing = value
662
+
663
 
664
  CHATGLM_6B_START_DOCSTRING = r"""
665
  This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
 
764
  num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
765
  dtype=self.params_dtype
766
  )
767
+ self.gradient_checkpointing = False
768
 
769
  def get_layer(layer_id):
770
  return GLMBlock(
 
817
  #seq_len, b, nh, hidden_size
818
  past_key_values = self.dropout(past_key_values)
819
  past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
820
+ # past_key_values = [(v[0], v[1]) for v in past_key_values]
 
 
821
  return past_key_values
822
 
823
  def get_masks(self, input_ids, device):
 
880
  use_cache = use_cache if use_cache is not None else self.config.use_cache
881
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
882
 
883
+ if self.gradient_checkpointing and self.training:
884
+ if use_cache:
885
+ logger.warning_once(
886
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
887
+ )
888
+ use_cache = False
889
+
890
  if input_ids is not None and inputs_embeds is not None:
891
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
892
  elif input_ids is not None:
 
936
  all_self_attentions = () if output_attentions else None
937
  all_hidden_states = () if output_hidden_states else None
938
 
 
 
 
 
 
939
  if attention_mask is None:
940
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
941
 
942
  else:
943
  attention_mask = attention_mask.to(input_ids.device)
944
 
945
+ if self.training:
946
+ hidden_states = hidden_states.requires_grad_(True)
947
+
948
  for i, layer in enumerate(self.layers):
949
 
950
  if output_hidden_states:
951
  all_hidden_states = all_hidden_states + (hidden_states,)
952
+ layer_past = past_key_values[i]
953
+
954
+ if self.gradient_checkpointing and self.training:
955
+ layer_ret = torch.utils.checkpoint.checkpoint(
956
+ layer,
957
+ hidden_states,
958
+ position_ids,
959
+ attention_mask,
960
+ torch.tensor(i),
961
+ layer_past,
962
+ use_cache,
963
+ output_attentions
964
+ )
965
+ else:
966
+ layer_ret = layer(
967
+ hidden_states,
968
+ position_ids=position_ids,
969
+ attention_mask=attention_mask,
970
+ layer_id=torch.tensor(i),
971
+ layer_past=layer_past,
972
+ use_cache=use_cache,
973
+ output_attentions=output_attentions
974
+ )
975
 
976
  hidden_states = layer_ret[0]
977