Markus28 committed on
Commit 02ebe52
1 Parent(s): 5bc2987

feat: set self.dropout_p in constructor

Files changed (1)
  1. modeling_bert.py +3 -2
modeling_bert.py CHANGED

@@ -281,7 +281,8 @@ class JinaBertSelfAttention(nn.Module):
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
 
-        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        self.dropout_p = config.attention_probs_dropout_prob
+        self.dropout = nn.Dropout(self.dropout_p)
         self.position_embedding_type = position_embedding_type or getattr(
             config, "position_embedding_type", "absolute"
         )
@@ -356,7 +357,7 @@ class JinaBertSelfAttention(nn.Module):
         if self.attn_implementation == 'torch' and scaled_dot_product_attention is not None:
             b, _, s, _ = query_layer.shape
             new_bias = attention_mask + bias
-            dropout_p = self.dropout.p if self.training else 0.0
+            dropout_p = self.dropout_p if self.training else 0.0
             attn = scaled_dot_product_attention(query_layer, key_layer, value_layer, new_bias, dropout_p=dropout_p)
             attn = attn.permute(0, 2, 1, 3).contiguous()
             return (attn.view(b, s, self.all_head_size),)
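For context, below is a minimal, self-contained sketch of the pattern this commit introduces: the dropout probability is stored as a plain self.dropout_p attribute in the constructor, so the torch scaled_dot_product_attention path reads the probability directly instead of reaching into the nn.Dropout module via self.dropout.p. This is not the actual JinaBertSelfAttention implementation; the class name, constructor arguments, and mask handling are assumptions made for illustration.

# A minimal sketch of the pattern above, not the actual JinaBertSelfAttention
# implementation; class name, constructor arguments, and mask handling are
# illustrative assumptions.
import torch
import torch.nn as nn
from torch.nn.functional import scaled_dot_product_attention


class SelfAttentionSketch(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, attention_probs_dropout_prob: float):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.all_head_size = self.num_heads * self.head_dim
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        # As in the commit: keep the raw probability as an attribute, then
        # build the nn.Dropout module from it for any non-SDPA code path.
        self.dropout_p = attention_probs_dropout_prob
        self.dropout = nn.Dropout(self.dropout_p)

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor = None):
        b, s, _ = hidden_states.shape

        def split_heads(x: torch.Tensor) -> torch.Tensor:
            # (b, s, hidden) -> (b, heads, s, head_dim)
            return x.view(b, s, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        q = split_heads(self.query(hidden_states))
        k = split_heads(self.key(hidden_states))
        v = split_heads(self.value(hidden_states))

        # Attention dropout is only active during training; SDPA takes the
        # probability as a float, which is why the commit stores self.dropout_p.
        dropout_p = self.dropout_p if self.training else 0.0
        attn = scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=dropout_p)
        attn = attn.permute(0, 2, 1, 3).contiguous()
        return attn.view(b, s, self.all_head_size)

For example, SelfAttentionSketch(hidden_size=768, num_heads=12, attention_probs_dropout_prob=0.1) applied to a (2, 16, 768) tensor returns a (2, 16, 768) tensor; in eval() mode dropout_p resolves to 0.0, so inference stays deterministic. Keeping the probability as a float alongside the module also means the SDPA branch does not depend on the nn.Dropout submodule at all.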