Update modeling_chatglm.py
Browse files
- modeling_chatglm.py +6 -1
modeling_chatglm.py
CHANGED
@@ -47,6 +47,7 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|
47 |
]
|
48 |
|
49 |
is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
|
|
|
50 |
|
51 |
|
52 |
def default_init(cls, *args, **kwargs):
|
@@ -870,7 +871,11 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
870 |
is_encoder_decoder: bool = False,
|
871 |
standardize_cache_format: bool = False,
|
872 |
) -> Dict[str, Any]:
|
873 |
-
if is_transformers_4_42_or_higher:
|
|
|
|
|
|
|
|
|
874 |
# update past_key_values
|
875 |
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
876 |
outputs, standardize_cache_format=standardize_cache_format
|
|
|
47 |
]
|
48 |
|
49 |
is_transformers_4_42_or_higher = int(transformers.__version__.split(".")[1]) >= 42
|
50 |
+
is_transformers_4_44_or_higher = int(transformers.__version__.split(".")[1]) >= 44
|
51 |
|
52 |
|
53 |
def default_init(cls, *args, **kwargs):
|
|
|
871 |
is_encoder_decoder: bool = False,
|
872 |
standardize_cache_format: bool = False,
|
873 |
) -> Dict[str, Any]:
|
874 |
+
if is_transformers_4_44_or_higher:
|
875 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
876 |
+
outputs
|
877 |
+
)[1]
|
878 |
+
elif is_transformers_4_42_or_higher:
|
879 |
# update past_key_values
|
880 |
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
881 |
outputs, standardize_cache_format=standardize_cache_format
|