Feature Extraction
Transformers
Safetensors
English
bamboo
custom_code
Yixin Song committed on
Commit
ed33910
1 Parent(s): fb5d1a6

update modeling

Browse files
Files changed (1) hide show
  1. modeling_bamboo.py +473 -1
modeling_bamboo.py CHANGED
@@ -19,10 +19,14 @@
19
  # See the License for the specific language governing permissions and
20
  # limitations under the License.
21
  """ PyTorch Bamboo model."""
 
22
  import inspect
23
  import math
24
  import warnings
25
  from typing import List, Optional, Tuple, Union
 
 
 
26
 
27
  import torch
28
  import torch.nn.functional as F
@@ -34,7 +38,7 @@ from transformers.activations import ACT2FN
34
 
35
  from transformers.cache_utils import Cache, DynamicCache
36
  from transformers.activations import ACT2FN
37
- from .modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
38
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
39
  from transformers.modeling_utils import PreTrainedModel
40
  from transformers.utils import (
@@ -46,6 +50,474 @@ from transformers.utils import (
46
  replace_return_docstrings,
47
  )
48
  from .configuration_bamboo import BambooConfig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  if is_flash_attn_2_available():
 
19
  # See the License for the specific language governing permissions and
20
  # limitations under the License.
21
  """ PyTorch Bamboo model."""
22
+ import torch
23
  import inspect
24
  import math
25
  import warnings
26
  from typing import List, Optional, Tuple, Union
27
+ from dataclasses import dataclass
28
+ from typing import List, Optional, Tuple, Union
29
+
30
 
31
  import torch
32
  import torch.nn.functional as F
 
38
 
39
  from transformers.cache_utils import Cache, DynamicCache
40
  from transformers.activations import ACT2FN
41
+ # from .modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
42
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
43
  from transformers.modeling_utils import PreTrainedModel
44
  from transformers.utils import (
 
50
  replace_return_docstrings,
51
  )
52
  from .configuration_bamboo import BambooConfig
53
@dataclass
class AttentionMaskConverter:
    """
    A utility attention mask class that allows one to:
        - Create a causal 4d mask
        - Create a causal 4d mask with a sliding window
        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
          key_value_length) that can be added to attention scores

    Attended positions hold `0` and masked positions hold `torch.finfo(dtype).min`.

    Examples:

    ```python
    >>> import torch
    >>> converter = AttentionMaskConverter(True)
    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
              [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
    ```

    Parameters:
        is_causal (`bool`):
            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
        sliding_window (`int`, *optional*):
            Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.

    Raises:
        ValueError: If `sliding_window` is given but is not strictly positive.
    """

    is_causal: bool
    # `None` (the default) disables sliding-window masking.
    sliding_window: Optional[int]

    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
        self.is_causal = is_causal
        self.sliding_window = sliding_window

        if self.sliding_window is not None and self.sliding_window <= 0:
            raise ValueError(
                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
            )

    def to_causal_4d(
        self,
        batch_size: int,
        query_length: int,
        key_value_length: int,
        dtype: torch.dtype,
        device: Union[torch.device, "str"] = "cpu",
    ) -> Optional[torch.Tensor]:
        """
        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
        bias to upper right hand triangular matrix (causal mask).

        Returns `None` when `query_length` is 1 and no sliding window is configured, since no mask is built in
        that case.

        Raises:
            ValueError: If the converter was created with `is_causal=False`.
        """
        if not self.is_causal:
            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

        input_shape = (batch_size, query_length)
        past_key_values_length = key_value_length - query_length

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if input_shape[-1] > 1 or self.sliding_window is not None:
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )

        return causal_4d_mask

    def to_4d(
        self,
        attention_mask_2d: torch.Tensor,
        query_length: int,
        dtype: torch.dtype,
        key_value_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
        causal, a causal mask will be added.

        Raises:
            ValueError: If the converter is causal, a causal mask is needed and `key_value_length` is `None`.
            NotImplementedError: If a sliding window is configured on a non-causal converter.
        """
        input_shape = (attention_mask_2d.shape[0], query_length)

        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        causal_4d_mask = None
        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
            if key_value_length is None:
                raise ValueError(
                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
                )

            past_key_values_length = key_value_length - query_length
            causal_4d_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=attention_mask_2d.device,
                past_key_values_length=past_key_values_length,
                sliding_window=self.sliding_window,
            )
        elif self.sliding_window is not None:
            raise NotImplementedError("Sliding window is currently only implemented for causal masking")

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
            attention_mask_2d.device
        )

        if causal_4d_mask is not None:
            # Combine with masked_fill rather than addition: expanded_attn_mask + causal_4d_mask can overflow.
            expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)

        expanded_4d_mask = expanded_attn_mask

        return expanded_4d_mask

    @staticmethod
    def _make_causal_mask(
        input_ids_shape: torch.Size,
        dtype: torch.dtype,
        device: torch.device,
        past_key_values_length: int = 0,
        sliding_window: Optional[int] = None,
    ):
        """
        Make the causal mask used for uni-directional (causal) self-attention.

        Returns a `(bsz, 1, tgt_len, tgt_len + past_key_values_length)` view in which attended positions are 0 and
        masked positions are `torch.finfo(dtype).min`.
        """
        bsz, tgt_len = input_ids_shape
        # Build the mask directly in the target dtype: filling a default-dtype (float32) tensor with
        # `torch.finfo(dtype).min` overflows when `dtype` is wider than float32 (e.g. float64).
        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        # Zero out the lower triangle (including the diagonal): column index < row index + 1.
        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

        if past_key_values_length > 0:
            # Every query position may attend to all cached (past) positions.
            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window + 1

            context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
            mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)

        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

    @staticmethod
    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
        """
        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

        Positions where `mask` is 1 become 0 and all other positions become `torch.finfo(dtype).min`.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len

        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    @staticmethod
    def _unmask_unattended(
        expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
    ):
        # fmt: off
        """
        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        Details: https://github.com/pytorch/pytorch/issues/110213

        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
        `attention_mask` is [bsz, src_seq_len].

        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.

        `expanded_mask` is modified in place and also returned.

        For example, if `attention_mask` is
        ```
        [[0, 0, 1],
         [1, 1, 1],
         [0, 1, 1]]
        ```
        and `expanded_mask` is (e.g. here left-padding case)
        ```
        [[[[0, 0, 0],
           [0, 0, 0],
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[0, 0, 0],
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        then the modified `expanded_mask` will be
        ```
        [[[[1, 1, 1],   <-- modified
           [1, 1, 1],   <-- modified
           [0, 0, 1]]],
         [[[1, 0, 0],
           [1, 1, 0],
           [1, 1, 1]]],
         [[[1, 1, 1],   <-- modified
           [0, 1, 0],
           [0, 1, 1]]]]
        ```
        """
        # fmt: on

        # Get the index of the first non-zero value for every sample in the batch.
        # In the above example, indices = [[2], [0], [1]]
        # Weighting by a descending ramp makes argmax pick the leftmost attended position.
        tmp = torch.arange(attention_mask.shape[1], 0, -1)
        indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)

        # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
        # expanded mask will be completely unattended.
        left_masked_rows = torch.where(indices > 0)[0]

        if left_masked_rows.shape[0] == 0:
            return expanded_mask
        indices = indices[left_masked_rows]

        max_len = torch.max(indices)
        range_tensor = torch.arange(max_len).unsqueeze(0)
        range_tensor = range_tensor.repeat(indices.size(0), 1)

        # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
        range_tensor[range_tensor >= indices] = 0

        # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
        if expanded_mask.dim() == 4:
            num_masks = expanded_mask.shape[1]
            if num_masks == 1:
                # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
                mask_slice = (left_masked_rows[:, None], 0, range_tensor)
            else:
                # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
                mask_slice = (
                    left_masked_rows[:, None, None],
                    torch.arange(num_masks)[None, :, None],
                    range_tensor[:, None, :],
                )
        else:
            # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
            mask_slice = (left_masked_rows[:, None], range_tensor)

        expanded_mask[mask_slice] = unmasked_value

        return expanded_mask
301
+
302
+
303
def _prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        attention_mask (`torch.Tensor` or `None`):
            A 2D attention mask of shape `(batch_size, key_value_length)`, or an already expanded 4D mask of shape
            `(batch_size, 1, query_length, key_value_length)` holding 1 for attended positions.
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        inputs_embeds (`torch.Tensor`):
            The embedded inputs as a torch Tensor. Only its dtype and device are used.
        past_key_values_length (`int`):
            The length of the key value cache.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.

    Returns:
        The additive 4D mask, or `None` when no mask is provided and `to_causal_4d` builds none.

    Raises:
        ValueError: If a 4D `attention_mask` does not have the expected shape.
    """
    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length
    dtype = inputs_embeds.dtype

    mask_rank = None if attention_mask is None else len(attention_mask.shape)

    if mask_rank == 2:
        # 2D padding mask: expand and merge with the causal mask.
        return attn_mask_converter.to_4d(
            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=dtype
        )

    if mask_rank == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        # The 4D mask has the correct shape - invert it and fill with negative infinity.
        # Cast first so bool/int masks are accepted and the result matches the dtype of the
        # embeddings, consistent with `_prepare_4d_causal_attention_mask_for_sdpa`.
        inverted_mask = 1.0 - attention_mask.to(dtype)
        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

    # No mask, or a mask of any other rank (which is not interpreted): fall back to a purely causal mask.
    return attn_mask_converter.to_causal_4d(
        input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=inputs_embeds.device
    )
352
+
353
+
354
+ # Adapted from _prepare_4d_causal_attention_mask
355
def _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    input_shape: Union[torch.Size, Tuple, List],
    inputs_embeds: torch.Tensor,
    past_key_values_length: int,
    sliding_window: Optional[int] = None,
):
    """
    Build the `attn_mask` argument for `torch.nn.functional.scaled_dot_product_attention`.

    When nothing is masked in `attention_mask` and either `query_length == 1` or
    `key_value_length == query_length`, `None` is returned so that the caller can rely on SDPA's `is_causal`
    argument instead, which allows dispatching to the flash attention kernel (a custom `attn_mask` prevents that).

    Raises:
        ValueError: If a 4D `attention_mask` has an unexpected shape, or if the function is traced without an
            `attention_mask`.
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    key_value_length = input_shape[-1] + past_key_values_length
    batch_size, query_length = input_shape
    embeds_dtype = inputs_embeds.dtype

    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
    is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)

    # ---- No mask supplied ----------------------------------------------------------------------------------
    if attention_mask is None:
        if query_length > 1 and key_value_length != query_length:
            # SDPA's own causal mask generation may be wrong in this configuration
            # (https://github.com/pytorch/pytorch/issues/108108), so build an explicit causal mask.
            return converter.to_causal_4d(
                input_shape[0], input_shape[-1], key_value_length, dtype=embeds_dtype, device=inputs_embeds.device
            )
        if is_tracing:
            raise ValueError(
                'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
            )
        return None

    # ---- Already-expanded 4D mask: validate, invert, and pass through ----------------------------------------
    if len(attention_mask.shape) == 4:
        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
        if tuple(attention_mask.shape) != expected_shape:
            raise ValueError(
                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
            )
        inverted = 1.0 - attention_mask.to(embeds_dtype)
        return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(embeds_dtype).min)

    # ---- Mask that hides nothing ------------------------------------------------------------------------------
    if not is_tracing and torch.all(attention_mask == 1):
        # For query_length == 1, causal and bi-directional attention coincide; for
        # key_value_length == query_length SDPA's `is_causal` flag is sufficient.
        # Otherwise (query_length > 1 and key_value_length != query_length) the mask cannot be dropped:
        # `is_causal=False` is used in SDPA and the explicit mask below is relied upon instead.
        # Reference: https://github.com/pytorch/pytorch/issues/108108
        if query_length == 1 or key_value_length == query_length:
            return None

    # ---- General case: expand the provided mask ---------------------------------------------------------------
    expanded_4d_mask = converter.to_4d(
        attention_mask,
        input_shape[-1],
        dtype=embeds_dtype,
        key_value_length=key_value_length,
    )

    # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
    # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
    #
    # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
    # controlflow that can not be captured properly.
    # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
    if query_length > 1 and not is_tracing:
        expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
            expanded_4d_mask, attention_mask, unmasked_value=0.0
        )

    return expanded_4d_mask
440
+
441
+
442
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expand a 2D mask of shape `(batch_size, key_value_length)` into a non-causal 4D mask of shape
    `(batch_size, 1, query_length, key_value_length)`.

    This simply delegates to `AttentionMaskConverter._expand_mask`.

    Args:
        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`, *optional*):
            The target length or query length the created mask shall have.
    """
    expand = AttentionMaskConverter._expand_mask
    return expand(mask=mask, dtype=dtype, tgt_len=tgt_len)
455
+
456
+
457
def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`

    Args:
        mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)`
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        tgt_len (`int`, *optional*):
            The target length or query length the created mask shall have. Defaults to `key_value_length`.

    Returns:
        `None` when not tracing, nothing is masked, and either `tgt_len == 1` or `tgt_len == key_value_length`;
        otherwise the expanded 4D mask.
    """
    _, key_value_length = mask.shape
    if tgt_len is None:
        tgt_len = key_value_length

    # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
    is_tracing = torch.jit.is_tracing()

    # `is_tracing` is tested first so that the data-dependent `torch.all` check is skipped while tracing.
    # Previously, tracing with an all-ones mask fell through and implicitly returned `None`, contradicting
    # the intent above of always providing an explicit mask when tracing.
    if not is_tracing and torch.all(mask == 1):
        # For query_length == 1 the mask is irrelevant; for key_value_length == query_length it can be dropped.
        if tgt_len == 1 or key_value_length == tgt_len:
            return None
        # For query_length > 1 and key_value_length != query_length the mask is kept and expanded below.
        # Reference: https://github.com/pytorch/pytorch/issues/108108

    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
492
+
493
+
494
def _create_4d_causal_attention_mask(
    input_shape: Union[torch.Size, Tuple, List],
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    sliding_window: Optional[int] = None,
) -> Optional[torch.Tensor]:
    """
    Build a purely causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`.

    Args:
        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
            The input shape should be a tuple that defines `(batch_size, query_length)`.
        dtype (`torch.dtype`):
            The torch dtype the created mask shall have.
        device (`torch.device`):
            The torch device the created mask shall have.
        past_key_values_length (`int`, *optional*, defaults to 0):
            Added to the query length to obtain `key_value_length`.
        sliding_window (`int`, *optional*):
            If the model uses windowed attention, a sliding window should be passed.

    Returns:
        Whatever `AttentionMaskConverter.to_causal_4d` returns for these arguments (a tensor or `None`).
    """
    converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)

    batch_size = input_shape[0]
    query_length = input_shape[-1]
    total_length = past_key_values_length + query_length

    return converter.to_causal_4d(batch_size, query_length, total_length, dtype=dtype, device=device)
521
 
522
 
523
  if is_flash_attn_2_available():