Markus28 committed on
Commit
5bc2987
1 Parent(s): c41d17d

fix: use attention dropout with torch SDPA implementation

Files changed (1):
  modeling_bert.py  +2 -1
modeling_bert.py CHANGED
@@ -356,7 +356,8 @@ class JinaBertSelfAttention(nn.Module):
         if self.attn_implementation == 'torch' and scaled_dot_product_attention is not None:
             b, _, s, _ = query_layer.shape
             new_bias = attention_mask + bias
-            attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias)
+            dropout_p = self.dropout.p if self.training else 0.0
+            attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias, dropout_p=dropout_p)
             attn = attn.permute(0, 2, 1, 3).contiguous()
             return (attn.view(b, s, self.all_head_size),)
 
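The idea behind the fix: torch.nn.functional.scaled_dot_product_attention applies dropout internally through its dropout_p argument, so the module's nn.Dropout probability has to be forwarded explicitly and set to 0.0 at eval time, because SDPA does not consult the module's training flag. Below is a minimal, self-contained sketch of that pattern; the class name SelfAttentionSketch and its layer layout are illustrative assumptions, not the model's actual JinaBertSelfAttention implementation.

import torch
from torch.nn.functional import scaled_dot_product_attention


class SelfAttentionSketch(torch.nn.Module):
    # Hypothetical module, illustrative only.
    def __init__(self, hidden_size: int, num_heads: int, dropout_prob: float = 0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = torch.nn.Linear(hidden_size, 3 * hidden_size)
        self.dropout = torch.nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
        # bias is an additive attention mask that broadcasts to (b, num_heads, s, s)
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        q, k, v = (t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # SDPA performs dropout itself: pass the module's probability while
        # training and 0.0 at eval time, since SDPA ignores self.training.
        dropout_p = self.dropout.p if self.training else 0.0
        attn = scaled_dot_product_attention(q, k, v, attn_mask=bias, dropout_p=dropout_p)
        return attn.transpose(1, 2).contiguous().view(b, s, -1)

With this pattern, a module in eval mode passes dropout_p=0.0, which matches the previous behavior where nn.Dropout is a no-op at evaluation time, while training runs keep the configured attention dropout.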