Allow PyTorch < 2 to be used without passing the attn_implementation flag

#4
by Jackmin108 - opened
Files changed (1)
  1. modeling_bert.py +1 -1
modeling_bert.py CHANGED
@@ -353,7 +353,7 @@ class JinaBertSelfAttention(nn.Module):
         # if encoder bi-directional self-attention `past_key_value` is always `None`
         past_key_value = (key_layer, value_layer)
 
-        if self.attn_implementation == 'torch':
+        if self.attn_implementation == 'torch' and scaled_dot_product_attention is not None:
            b, _, s, _ = query_layer.shape
            new_bias = attention_mask + bias
            attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias)
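
For context: torch.nn.functional.scaled_dot_product_attention only exists in PyTorch >= 2.0, so on older installs the default attn_implementation == 'torch' branch would try to call a function that is not available. The added `is not None` check lets PyTorch < 2 installs fall through to the eager attention path without the user having to override attn_implementation. A minimal sketch of the guarded import this check presumably pairs with (hypothetical; the exact import in modeling_bert.py may differ):

    try:
        # Available in PyTorch >= 2.0 only
        from torch.nn.functional import scaled_dot_product_attention
    except ImportError:
        # PyTorch < 2: leave the name bound to None so the runtime check
        # `scaled_dot_product_attention is not None` can skip the SDPA branch
        scaled_dot_product_attention = None

With a guard like this in place, the same module loads and runs under both PyTorch 1.x and 2.x, and the fused SDPA kernel is used only where it actually exists.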