Positional Interpolation

#14
Files changed (1)
  1. modeling_bert.py +2 -1
modeling_bert.py CHANGED
@@ -787,7 +787,8 @@ class JinaBertEncoder(nn.Module):
787
  # Device catch-up
788
  self.alibi = self.alibi.to(hidden_states.device)
789
 
790
- alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
 
791
  if self.gradient_checkpointing and self.training:
792
  if use_cache:
793
  logger.warning_once(
 
787
  # Device catch-up
788
  self.alibi = self.alibi.to(hidden_states.device)
789
 
790
+ unpadded_seqlens = torch.sum(attention_mask, dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
791
+ alibi_bias = self.alibi[:, :, :seqlen, :seqlen] * 512 / unpadded_seqlens
792
  if self.gradient_checkpointing and self.training:
793
  if use_cache:
794
  logger.warning_once(