hlarcher (HF staff) committed
Commit 67f33d0
1 parent: 9fbc122

Fix attention for compatibility with Nvidia V100s (no FlashAttention). Based on the work of puru22 for Falcon-40B.

Files changed (2)
  1. .gitignore +211 -0
  2. modelling_RW.py +139 -111
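
The change is easiest to see end to end from the user side. As a hedged illustration only (the repository id, dtype, and device settings below are assumptions, not part of this commit), loading a Falcon-style checkpoint that ships this modelling_RW.py with trust_remote_code=True picks up the patched attention path, which is what lets the model run on V100s without FlashAttention:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical example checkpoint; substitute the repository this commit belongs to.
model_id = "tiiuae/falcon-40b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # V100s lack bfloat16 support, so fp16 is the usual choice here
    trust_remote_code=True,     # ensures the repo's modelling_RW.py (this file) is used
    device_map="auto",
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))
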
.gitignore ADDED
@@ -0,0 +1,211 @@
+# Created by https://www.toptal.com/developers/gitignore/api/python,macos
+# Edit at https://www.toptal.com/developers/gitignore?templates=python,macos
+
+### macOS ###
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### macOS Patch ###
+# iCloud generated files
+*.icloud
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+# End of https://www.toptal.com/developers/gitignore/api/python,macos
+
+.idea
modelling_RW.py CHANGED
@@ -25,6 +25,7 @@ from .configuration_RW import RWConfig
 
 logger = logging.get_logger(__name__)
 
+
 # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
 # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
 class Linear(nn.Linear):
@@ -38,9 +39,10 @@ class Linear(nn.Linear):
 
 from einops import rearrange
 
+
 # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
 def rotate_half(x):
-    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
 
 
@@ -51,9 +53,9 @@ class RotaryEmbedding(torch.nn.Module):
     """
 
     def __init__(
-        self,
-        head_dim: int,
-        base=10000,
+            self,
+            head_dim: int,
+            base=10000,
     ):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
@@ -65,10 +67,10 @@ class RotaryEmbedding(torch.nn.Module):
         self.sin_cached: torch.Tensor | None = None
 
     def cos_sin(
-        self,
-        seq_len: int,
-        device="cuda",
-        dtype=torch.bfloat16,
+            self,
+            seq_len: int,
+            device="cuda",
+            dtype=torch.bfloat16,
     ) -> torch.Tensor:
         if seq_len != self.seq_len_cached:
             self.seq_len_cached = seq_len
@@ -87,23 +89,31 @@ class RotaryEmbedding(torch.nn.Module):
 
         return self.cos_cached, self.sin_cached
 
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length is None:
+            batch, seq_len, head_dim = q.shape
+        else:
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = input_seq_len + past_seq_length
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        if past_seq_length is not None:
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (
+                k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else:
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 def _make_causal_mask(
-    input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
+        input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 ) -> torch.BoolTensor:
     batch_size, target_length = input_ids_shape
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
     seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
 
     if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
+        mask[:, :past_key_values_length] = True
 
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
@@ -230,14 +240,14 @@ class Attention(nn.Module):
         return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        alibi: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        use_cache: bool = False,
-        output_attentions: bool = False,
+            self,
+            hidden_states: torch.Tensor,
+            alibi: torch.Tensor,
+            attention_mask: torch.Tensor,
+            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+            head_mask: Optional[torch.Tensor] = None,
+            use_cache: bool = False,
+            output_attentions: bool = False,
     ):
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
 
@@ -256,18 +266,27 @@ class Attention(nn.Module):
 
         query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
            # - key: [batch_size * self.num_heads, head_dim, kv_length]
            # - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None
 
@@ -276,9 +295,14 @@ class Attention(nn.Module):
         key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
         value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
-        attn_output = F.scaled_dot_product_attention(
-            query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-        )
+        if attention_mask is not None:
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+            )
+        else:
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+            )
 
         x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
         x = x.permute(0, 2, 1, 3)
@@ -303,7 +327,8 @@ class Attention(nn.Module):
         attention_scores = attention_scores.to(torch.float32)
         # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
         attention_probs = F.softmax(
-            (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
+            (attention_scores + alibi.view(batch_size, self.num_heads, 1,
+                                           -1)) * self.inv_norm_factor + attention_mask_float,
             dim=-1,
             dtype=hidden_states.dtype,
         )
@@ -368,14 +393,14 @@ class DecoderLayer(nn.Module):
         self.config = config
 
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        alibi: torch.Tensor,
-        attention_mask: torch.Tensor,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        use_cache: bool = False,
-        output_attentions: bool = False,
+            self,
+            hidden_states: torch.Tensor,
+            alibi: torch.Tensor,
+            attention_mask: torch.Tensor,
+            layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+            head_mask: Optional[torch.Tensor] = None,
+            use_cache: bool = False,
+            output_attentions: bool = False,
     ):
 
         layernorm_output = self.input_layernorm(hidden_states)
@@ -453,7 +478,7 @@ class RWPreTrainedModel(PreTrainedModel):
 
     @staticmethod
     def _convert_to_standard_cache(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+            past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
         """
         Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
@@ -473,7 +498,7 @@ class RWPreTrainedModel(PreTrainedModel):
 
     @staticmethod
     def _convert_to_rw_cache(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+            past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
         batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
         batch_size_times_num_heads = batch_size * num_heads
@@ -514,7 +539,7 @@ class RWModel(RWPreTrainedModel):
         return self.word_embeddings
 
     def _prepare_attn_mask(
-        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+            self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
     ) -> torch.BoolTensor:
         # create causal mask
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
@@ -522,10 +547,10 @@ class RWModel(RWPreTrainedModel):
         device = attention_mask.device
         _, src_length = input_shape
 
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+        #if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )
 
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
@@ -539,17 +564,17 @@ class RWModel(RWPreTrainedModel):
         self.word_embeddings = new_embeddings
 
     def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **deprecated_arguments,
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            head_mask: Optional[torch.LongTensor] = None,
+            inputs_embeds: Optional[torch.LongTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **deprecated_arguments,
     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
         if deprecated_arguments.pop("position_ids", False) is not False:
             # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
@@ -697,40 +722,43 @@ class RWForCausalLM(RWPreTrainedModel):
         self.lm_head = new_embeddings
 
     def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        past: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+            self,
+            input_ids: torch.LongTensor,
+            past: Optional[torch.Tensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            **kwargs,
    ) -> dict:
        # only last token for input_ids if past is not None
-        if past:
+        if kwargs.get("past_key_values", None) :
            input_ids = input_ids[:, -1].unsqueeze(-1)
-
+            past_key_values = kwargs["past_key_values"]
        # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
-        if past[0][0].shape[0] == input_ids.shape[0]:
-            past = self._convert_to_rw_cache(past)
+        # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+        #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+        # past_key_values = kwargs["past_key_values"]
+        else :
+            past_key_values = None
 
        return {
            "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
 
    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **deprecated_arguments,
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            head_mask: Optional[torch.Tensor] = None,
+            inputs_embeds: Optional[torch.Tensor] = None,
+            labels: Optional[torch.Tensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -790,7 +818,7 @@ class RWForCausalLM(RWPreTrainedModel):
         )
 
     def _reorder_cache(
-        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+            self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
@@ -828,18 +856,18 @@ class RWForSequenceClassification(RWPreTrainedModel):
         self.post_init()
 
     def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **deprecated_arguments,
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            head_mask: Optional[torch.Tensor] = None,
+            inputs_embeds: Optional[torch.Tensor] = None,
+            labels: Optional[torch.Tensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **deprecated_arguments,
     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -951,18 +979,18 @@ class RWForTokenClassification(RWPreTrainedModel):
         self.post_init()
 
     def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        head_mask: Optional[torch.Tensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        labels: Optional[torch.Tensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-        **deprecated_arguments,
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            head_mask: Optional[torch.Tensor] = None,
+            inputs_embeds: Optional[torch.Tensor] = None,
+            labels: Optional[torch.Tensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            **deprecated_arguments,
     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1028,17 +1056,17 @@ class RWForQuestionAnswering(RWPreTrainedModel):
         self.post_init()
 
     def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        start_positions: Optional[torch.LongTensor] = None,
-        end_positions: Optional[torch.LongTensor] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+            self,
+            input_ids: Optional[torch.LongTensor] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            head_mask: Optional[torch.FloatTensor] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            start_positions: Optional[torch.LongTensor] = None,
+            end_positions: Optional[torch.LongTensor] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
     ) -> Union[Tuple, QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
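
Taken together, the modelling_RW.py hunks above replace the FlashAttention-only path with torch.nn.functional.scaled_dot_product_attention: when an attention mask is present it is passed explicitly with is_causal=False, the causal mask is now built on every call (offset by past_key_values_length), and the rotary embeddings skip the cached prefix during incremental decoding. A minimal, self-contained sketch of that masking pattern (shapes and variable names here are illustrative, not taken from the file):

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 4, 3, 8, 64
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

# Boolean mask: True marks key positions each query may attend to
# (causal attention over 5 cached tokens plus the 3 new ones).
past_length = kv_len - q_len
keep = torch.arange(kv_len)[None, :] <= (torch.arange(q_len)[:, None] + past_length)

# Convert to the additive float mask accepted by scaled_dot_product_attention.
attn_mask = torch.zeros(q_len, kv_len).masked_fill(~keep, float("-inf"))

# Explicit mask, hence is_causal=False: the same call pattern the patched Attention.forward uses.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 4, 3, 64])
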