Qishuai commited on
Commit
e3c7deb
·
verified ·
1 Parent(s): 2bf7de6

Upload modeling_cogvlm.py

Browse files

## Compatible with Transformers > 4.41.2

### Details:
The current implementation has an error with the line:
```python
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
```
When the Transformers version is > 4.41.2.
The issue is caused by the change in the output of
`_extract_past_from_model_output` function defined in Transformers `src/transformers/generation/utils.py` since version v4.42.0.
![Screenshot 2024-08-15 at 11.45.29 AM.png](https://cdn-uploads.huggingface.co/production/uploads/618cba5eb5e8b9d34205d474/sc9jcZAPTx9BZcOg1qHIT.png)

Therefore, my pr includes checking the version of Transformers and modifying the process of the output of `_extract_past_from_model_output` to make sure cogvlm2 can work with newer version of Transformers

Files changed (1) hide show
  1. modeling_cogvlm.py +15 -4
modeling_cogvlm.py CHANGED
@@ -1,9 +1,11 @@
1
  """largely copy from llama and adapt for cogvlm"""
2
  import warnings
 
3
  from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
4
 
5
  import math
6
  import torch
 
7
  from torch import nn
8
  from torch.nn import CrossEntropyLoss
9
  from torchvision import transforms
@@ -26,7 +28,12 @@ logger = get_logger(__name__)
26
 
27
  LANGUAGE_TOKEN_TYPE = 0
28
  VISION_TOKEN_TYPE = 1
29
-
 
 
 
 
 
30
 
31
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
32
  def _make_causal_mask(
@@ -736,9 +743,13 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
736
  standardize_cache_format: bool = False,
737
  ) -> Dict[str, Any]:
738
  # update past_key_values
739
- model_kwargs["past_key_values"] = self._extract_past_from_model_output(
740
- outputs, standardize_cache_format=standardize_cache_format
741
- )
 
 
 
 
742
  if getattr(outputs, "state", None) is not None:
743
  model_kwargs["state"] = outputs.state
744
 
 
1
  """largely copy from llama and adapt for cogvlm"""
2
  import warnings
3
+ import packaging.version
4
  from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
5
 
6
  import math
7
  import torch
8
+ import transformers
9
  from torch import nn
10
  from torch.nn import CrossEntropyLoss
11
  from torchvision import transforms
 
28
 
29
  LANGUAGE_TOKEN_TYPE = 0
30
  VISION_TOKEN_TYPE = 1
31
+ TRANSFORMERS_ABOVE_441 = (
32
+ True
33
+ if packaging.version.parse(transformers.__version__)
34
+ >= packaging.version.parse("4.42.0")
35
+ else False
36
+ )
37
 
38
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
39
  def _make_causal_mask(
 
743
  standardize_cache_format: bool = False,
744
  ) -> Dict[str, Any]:
745
  # update past_key_values
746
+ if TRANSFORMERS_ABOVE_441:
747
+ cache_name, cache = self._extract_past_from_model_output(outputs)
748
+ model_kwargs[cache_name] = cache
749
+ else:
750
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
751
+ outputs, standardize_cache_format=standardize_cache_format
752
+ )
753
  if getattr(outputs, "state", None) is not None:
754
  model_kwargs["state"] = outputs.state
755