Jackmin108 committed
Commit b5794c5 · Parent(s): 344bcbc

Allow pytorch<2 to use without passing attn_implementation flag (#4)

- Allow pytorch<2 to use without passing attn_implementation flag (969de83296b491de451557789e9770b9335612bb)
- modeling_bert.py +1 -1
modeling_bert.py CHANGED (+1 -1)

@@ -353,7 +353,7 @@ class JinaBertSelfAttention(nn.Module):
         # if encoder bi-directional self-attention `past_key_value` is always `None`
         past_key_value = (key_layer, value_layer)

-        if self.attn_implementation == 'torch':
+        if self.attn_implementation == 'torch' and scaled_dot_product_attention is not None:
             b, _, s, _ = query_layer.shape
             new_bias = attention_mask + bias
             attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias)
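For context, the added check only helps if `scaled_dot_product_attention` resolves to `None` on PyTorch releases older than 2.0, where `torch.nn.functional.scaled_dot_product_attention` does not exist. A minimal sketch of the kind of conditional import the guard presumably pairs with elsewhere in modeling_bert.py (assumed placement, not shown in this diff):

    # Sketch only: resolve SDPA availability once at import time.
    try:
        # torch >= 2.0 ships the fused scaled dot-product attention kernel
        from torch.nn.functional import scaled_dot_product_attention
    except ImportError:
        # torch < 2.0: leave the name as None so the `torch` branch above is
        # skipped and attention falls back to the eager matmul + softmax path.
        scaled_dot_product_attention = None

With that fallback, the default attn_implementation='torch' no longer errors on pytorch<2, which matches the commit's intent: users do not have to pass the attn_implementation flag to opt out of SDPA.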