vaibhavad committed
Commit 48b3d3b
1 parent: 813a65a

Adjust for latest transformers version

Files changed (1):
  1. attn_mask_utils.py (+29 −7)
attn_mask_utils.py CHANGED
@@ -1,7 +1,19 @@
 from typing import List, Optional, Tuple, Union
 import torch
+from packaging import version
+import importlib.metadata
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 
+from transformers.utils.import_utils import _is_package_available
+
+def is_transformers_attn_greater_or_equal_4_39():
+    if not _is_package_available("transformers"):
+        return False
+
+    return version.parse(importlib.metadata.version("transformers")) >= version.parse(
+        "4.39.0"
+    )
+
 def _prepare_4d_attention_mask_for_sdpa(
     attention_mask: Optional[torch.Tensor],
     input_shape: Union[torch.Size, Tuple, List],
@@ -59,9 +71,14 @@ def _prepare_4d_attention_mask_for_sdpa(
     # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
     # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
     if query_length > 1:
-        expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
-            expanded_4d_mask, attention_mask, unmasked_value=0.0
-        )
+        if is_transformers_attn_greater_or_equal_4_39():
+            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+            )
+        else:
+            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                expanded_4d_mask, attention_mask, unmasked_value=0.0
+            )
 
     return expanded_4d_mask
 
@@ -195,8 +212,13 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
         # controlflow that can not be captured properly.
         # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
        if query_length > 1 and not is_tracing:
-            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
-                expanded_4d_mask, attention_mask, unmasked_value=0.0
-            )
+            if is_transformers_attn_greater_or_equal_4_39():
+                expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                    expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+                )
+            else:
+                expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                    expanded_4d_mask, attention_mask, unmasked_value=0.0
+                )
 
-    return expanded_4d_mask
+    return expanded_4d_mask
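
Note: as the hunks above show, the _unmask_unattended call changed in transformers 4.39 (it now takes a min_dtype value instead of the earlier attention_mask and unmasked_value arguments), so the patched helpers branch on the installed library version. Below is a minimal standalone sketch, not part of the commit, that mirrors the version check added in the diff and reports which branch a given environment would take; it assumes transformers and packaging are installed.

# Hypothetical check for illustration only; mirrors is_transformers_attn_greater_or_equal_4_39()
# from the diff and prints which _unmask_unattended call signature would be used.
import importlib.metadata
from packaging import version

installed = version.parse(importlib.metadata.version("transformers"))
if installed >= version.parse("4.39.0"):
    print(f"transformers {installed}: _unmask_unattended(expanded_mask, min_dtype=...)")
else:
    print(f"transformers {installed}: _unmask_unattended(expanded_mask, attention_mask, unmasked_value=0.0)")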