vaibhavad committed
Commit 4f499ad
1 Parent(s): e3adaf9

Create attn_mask_utils.py

Files changed (1)
  1. attn_mask_utils.py +224 -0
attn_mask_utils.py ADDED
@@ -0,0 +1,224 @@
+from typing import List, Optional, Tuple, Union
+import torch
+from packaging import version
+import importlib.metadata
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+from transformers.utils.import_utils import _is_package_available
+
+def is_transformers_attn_greater_or_equal_4_39():
+    if not _is_package_available("transformers"):
+        return False
+
+    return version.parse(importlib.metadata.version("transformers")) >= version.parse(
+        "4.39.0"
+    )
+
+def _prepare_4d_attention_mask_for_sdpa(
+    attention_mask: Optional[torch.Tensor],
+    input_shape: Union[torch.Size, Tuple, List],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: Optional[int] = None,
+):
+    attn_mask_converter = AttentionMaskConverter(is_causal=False, sliding_window=sliding_window)
+
+    key_value_length = input_shape[-1] + past_key_values_length
+    batch_size, query_length = input_shape
+
+    # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
+    is_tracing = torch.jit.is_tracing()
+
+    if attention_mask is not None:
+        if torch.all(attention_mask == 1):
+            if is_tracing:
+                pass
+            elif query_length == 1:
+                # For query_length == 1, causal attention and bi-directional attention are the same.
+                attention_mask = None
+            elif key_value_length == query_length:
+                attention_mask = None
+            else:
+                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+                # Reference: https://github.com/pytorch/pytorch/issues/108108
+                pass
+    elif query_length > 1 and key_value_length != query_length:
+        # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
+        # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
+        attention_mask = True
+    elif is_tracing:
+        raise ValueError(
+            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
+        )
+
+    if attention_mask is None:
+        expanded_4d_mask = None
+    elif attention_mask is True:
+        expanded_4d_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+        )
+    else:
+        expanded_4d_mask = attn_mask_converter.to_4d(
+            attention_mask,
+            input_shape[-1],
+            dtype=inputs_embeds.dtype,
+            key_value_length=key_value_length,
+        )
+
+        # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+        # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+        if query_length > 1:
+            if is_transformers_attn_greater_or_equal_4_39():
+                expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                    expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+                )
+            else:
+                expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                    expanded_4d_mask, attention_mask, unmasked_value=0.0
+                )
+
+    return expanded_4d_mask
+
+
+def _prepare_4d_attention_mask(
+    attention_mask: Optional[torch.Tensor],
+    input_shape: Union[torch.Size, Tuple, List],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: Optional[int] = None,
+):
+    attn_mask_converter = AttentionMaskConverter(is_causal=False, sliding_window=sliding_window)
+
+    key_value_length = input_shape[-1] + past_key_values_length
+
+    # 4d mask is passed through the layers
+    if attention_mask is not None:
+        attention_mask = attn_mask_converter.to_4d(
+            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
+        )
+    else:
+        attention_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+        )
+
+    return attention_mask
+
+
+def _prepare_4d_causal_attention_mask(
+    attention_mask: Optional[torch.Tensor],
+    input_shape: Union[torch.Size, Tuple, List],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: Optional[int] = None,
+):
+    attn_mask_converter = AttentionMaskConverter(is_causal=False, sliding_window=sliding_window)
+
+    key_value_length = input_shape[-1] + past_key_values_length
+
+    # 4d mask is passed through the layers
+    if attention_mask is not None:
+        attention_mask = attn_mask_converter.to_4d(
+            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
+        )
+    else:
+        attention_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+        )
+
+    return attention_mask
+
+
+def _prepare_4d_causal_attention_mask_for_sdpa(
+    attention_mask: Optional[torch.Tensor],
+    input_shape: Union[torch.Size, Tuple, List],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: Optional[int] = None,
+):
+    """
+    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
+
+    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
+    `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
+    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+    """
+    attn_mask_converter = AttentionMaskConverter(is_causal=False, sliding_window=sliding_window)
+
+    key_value_length = input_shape[-1] + past_key_values_length
+    batch_size, query_length = input_shape
+
+    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
+    is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)
+
+    if attention_mask is not None:
+        # 4d mask is passed through
+        if len(attention_mask.shape) == 4:
+            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+            if tuple(attention_mask.shape) != expected_shape:
+                raise ValueError(
+                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                )
+            else:
+                # if the 4D mask has correct shape - invert it and fill with negative infinity
+                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+                attention_mask = inverted_mask.masked_fill(
+                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+                )
+                return attention_mask
+
+        elif not is_tracing and torch.all(attention_mask == 1):
+            if query_length == 1:
+                # For query_length == 1, causal attention and bi-directional attention are the same.
+                attention_mask = None
+            elif key_value_length == query_length:
+                attention_mask = None
+            else:
+                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
+                # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+                # Reference: https://github.com/pytorch/pytorch/issues/108108
+                pass
+    elif query_length > 1 and key_value_length != query_length:
+        # See the comment above (https://github.com/pytorch/pytorch/issues/108108).
+        # Ugly: we set it to True here to dispatch in the following controlflow to `to_causal_4d`.
+        attention_mask = True
+    elif is_tracing:
+        raise ValueError(
+            'Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument `attn_implementation="eager"` or pass an attention_mask input when tracing the model.'
+        )
+
+    if attention_mask is None:
+        expanded_4d_mask = None
+    elif attention_mask is True:
+        expanded_4d_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+        )
+    else:
+        expanded_4d_mask = attn_mask_converter.to_4d(
+            attention_mask,
+            input_shape[-1],
+            dtype=inputs_embeds.dtype,
+            key_value_length=key_value_length,
+        )
+
+        # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
+        # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
+        #
+        # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
+        # controlflow that can not be captured properly.
+        # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
+        if query_length > 1 and not is_tracing:
+            if is_transformers_attn_greater_or_equal_4_39():
+                expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                    expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+                )
+            else:
+                expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                    expanded_4d_mask, attention_mask, unmasked_value=0.0
+                )
+
+    return expanded_4d_mask
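
For context, a minimal usage sketch of one helper added in this commit. The sketch is not part of the commit: the example tensors, the single-head SDPA call, and the assumption that attn_mask_utils.py sits on the import path are all illustrative.

# Hypothetical usage sketch: build a 4D additive mask for bidirectional SDPA attention.
import torch
import torch.nn.functional as F

from attn_mask_utils import _prepare_4d_attention_mask_for_sdpa

batch_size, seq_len, hidden = 2, 5, 16
inputs_embeds = torch.randn(batch_size, seq_len, hidden)
# 1 = real token, 0 = padding (the second sequence is padded).
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

mask_4d = _prepare_4d_attention_mask_for_sdpa(
    attention_mask,
    input_shape=(batch_size, seq_len),
    inputs_embeds=inputs_embeds,
    past_key_values_length=0,
)

# mask_4d is None when nothing needs masking (e.g. an all-ones attention_mask),
# otherwise a (batch, 1, q_len, kv_len) float mask with large negative values on
# masked positions; either form can be passed straight to SDPA.
q = k = v = torch.randn(batch_size, 1, seq_len, hidden)  # one attention head
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_4d)

Note that all four helpers instantiate AttentionMaskConverter with is_causal=False, so the masks expanded from a 2D attention_mask encode padding only, not causality, which is what a bidirectional encoder would need.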