katuni4ka committed
Commit 71090c6
Parent: f86aa57

fix pkv update for new transformers compatibility

Files changed (1): modeling_chatglm.py (+13 -4)
modeling_chatglm.py CHANGED
@@ -14,6 +14,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm
 from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
 from torch.nn.utils import skip_init
 from typing import Optional, Tuple, Union, List, Callable, Dict, Any
+import transformers
 
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -45,6 +46,8 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
 # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
 ]
 
+is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
+
 
 def default_init(cls, *args, **kwargs):
     return cls(*args, **kwargs)
@@ -867,10 +870,16 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         is_encoder_decoder: bool = False,
         standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
-        # update past_key_values
-        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        if is_transformers_4_42_or_higher:
+            # update past_key_values
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs, standardize_cache_format=standardize_cache_format
+            )[1]
+        else:
+            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+                outputs, standardize_cache_format=standardize_cache_format
+            )
+
 
         # update attention mask
         if "attention_mask" in model_kwargs: