vaibhavad committed on
Commit
5069f5e
1 Parent(s): 48b3d3b

Update attn_mask_utils.py

Files changed (1)
  1. attn_mask_utils.py +3 -2
attn_mask_utils.py CHANGED

@@ -175,8 +175,9 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     if query_length == 1:
         # For query_length == 1, causal attention and bi-directional attention are the same.
         attention_mask = None
-    elif key_value_length == query_length:
-        attention_mask = None
+    # Commented out to deal with batch size=1 cases
+    # elif key_value_length == query_length:
+    #     attention_mask = None
     else:
         # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation
         # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
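For context, a minimal sketch of what the removed branch changed. This is plain PyTorch, not the repository's code, and the shapes are hypothetical: it only illustrates the point made in the surviving comments, namely that passing attention_mask=None (letting SDPA generate its own causal mask) is not equivalent to passing an explicit bi-directional mask, even when key_value_length == query_length.

import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration only.
batch, heads, seq_len, head_dim = 1, 4, 5, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# attention_mask = None path: SDPA builds its own causal mask internally.
out_none = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Explicit-mask path: an additive all-zeros mask means "attend everywhere",
# i.e. bi-directional attention, with is_causal=False.
bidir_mask = torch.zeros(batch, 1, seq_len, seq_len)
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=bidir_mask, is_causal=False)

# For query_length > 1 the two disagree, so the mask cannot simply be
# dropped for a bi-directional model, even in the unpadded batch-size-1
# case where key_value_length == query_length.
print(torch.allclose(out_none, out_mask))  # False in general

With the elif branch commented out, such inputs now fall through to the else branch, which keeps the explicit attention_mask and runs SDPA with is_causal=False, matching the behavior described in the final two comment lines of the hunk.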