Jackmin108 committed
Commit 84441b5
1 Parent(s): 0007991
Files changed (1)
  1. modeling_bert.py +6 -4
modeling_bert.py CHANGED
@@ -280,6 +280,8 @@ class JinaBertSelfAttention(nn.Module):
         self.query = nn.Linear(config.hidden_size, self.all_head_size)
         self.key = nn.Linear(config.hidden_size, self.all_head_size)
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+        self.layer_norm_q = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layer_norm_k = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
         self.position_embedding_type = position_embedding_type or getattr(
@@ -315,7 +317,7 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
-        mixed_query_layer = self.query(hidden_states)
+        mixed_query_layer = self.layer_norm_q(self.query(hidden_states))
 
         # If this is instantiated as a cross-attention module, the keys
         # and values come from an encoder; the attention mask needs to be
@@ -328,16 +330,16 @@ class JinaBertSelfAttention(nn.Module):
             value_layer = past_key_value[1]
             attention_mask = encoder_attention_mask
         elif is_cross_attention:
-            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+            key_layer = self.transpose_for_scores(self.layer_norm_k(self.key(encoder_hidden_states)))
             value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
             attention_mask = encoder_attention_mask
         elif past_key_value is not None:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            key_layer = self.transpose_for_scores(self.layer_norm_k(self.key(hidden_states)))
             value_layer = self.transpose_for_scores(self.value(hidden_states))
             key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
         else:
-            key_layer = self.transpose_for_scores(self.key(hidden_states))
+            key_layer = self.transpose_for_scores(self.layer_norm_k(self.key(hidden_states)))
             value_layer = self.transpose_for_scores(self.value(hidden_states))
 
         query_layer = self.transpose_for_scores(mixed_query_layer)
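
In plain terms, the commit inserts a LayerNorm after the query and key projections (a QK-normalization style change) while leaving the value path untouched. Below is a minimal, self-contained sketch of the resulting projection path, assuming a standard BERT-style head split; the class name, constructor arguments, and the simplified forward pass (no cross-attention, past key/values, ALiBi bias, or attention mask) are illustrative assumptions, not the actual JinaBert implementation.

    import math
    import torch
    import torch.nn as nn

    class QKNormSelfAttentionSketch(nn.Module):
        """Sketch only: self-attention with LayerNorm applied to the
        projected queries and keys, mirroring where the diff inserts
        layer_norm_q / layer_norm_k."""

        def __init__(self, hidden_size: int, num_heads: int,
                     layer_norm_eps: float = 1e-12, dropout: float = 0.1):
            super().__init__()
            assert hidden_size % num_heads == 0
            self.num_heads = num_heads
            self.head_size = hidden_size // num_heads

            self.query = nn.Linear(hidden_size, hidden_size)
            self.key = nn.Linear(hidden_size, hidden_size)
            self.value = nn.Linear(hidden_size, hidden_size)
            # New in the diff: normalize the projected queries and keys.
            self.layer_norm_q = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
            self.layer_norm_k = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
            self.dropout = nn.Dropout(dropout)

        def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
            # (batch, seq, hidden) -> (batch, heads, seq, head_size)
            b, s, _ = x.shape
            return x.view(b, s, self.num_heads, self.head_size).permute(0, 2, 1, 3)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # LayerNorm sits right after the q/k linear projections and
            # before the head split; the value path is unchanged.
            q = self.transpose_for_scores(self.layer_norm_q(self.query(hidden_states)))
            k = self.transpose_for_scores(self.layer_norm_k(self.key(hidden_states)))
            v = self.transpose_for_scores(self.value(hidden_states))

            scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_size)
            probs = self.dropout(scores.softmax(dim=-1))
            context = torch.matmul(probs, v)  # (batch, heads, seq, head_size)
            b, h, s, d = context.shape
            return context.permute(0, 2, 1, 3).reshape(b, s, h * d)

The detail worth noting is the placement: normalization is applied over the full hidden dimension immediately after the q/k linear layers and before transpose_for_scores splits the heads, which is exactly where the diff wraps self.query(...) and self.key(...) with layer_norm_q and layer_norm_k in every branch that computes key_layer.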