bwang0911 committed on
Commit
d7eb81e
1 Parent(s): a0ba9b2

Update modeling_bert.py (#2)

Browse files

- Update modeling_bert.py (1d6878a472195decd84a14a16930ffb9d2d25969)

Files changed (1) hide show
  1. modeling_bert.py +23 -1
modeling_bert.py CHANGED
@@ -699,6 +699,11 @@ class JinaBertEncoder(nn.Module):
699
  )
700
  self.gradient_checkpointing = False
701
  self.num_attention_heads = config.num_attention_heads
 
 
 
 
 
702
 
703
  def rebuild_alibi_tensor(
704
  self, size: int, device: Optional[Union[torch.device, str]] = None
@@ -766,7 +771,24 @@ class JinaBertEncoder(nn.Module):
766
 
767
  # Add alibi matrix to extended_attention_mask
768
  _, seqlen, _ = hidden_states.size()
769
- alibi_bias = self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(hidden_states.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
770
  if self.gradient_checkpointing and self.training:
771
  if use_cache:
772
  logger.warning_once(
 
699
  )
700
  self.gradient_checkpointing = False
701
  self.num_attention_heads = config.num_attention_heads
702
+ self.register_buffer(
703
+ "alibi",
704
+ self.rebuild_alibi_tensor(size=config.max_position_embeddings),
705
+ persistent=False,
706
+ )
707
 
708
  def rebuild_alibi_tensor(
709
  self, size: int, device: Optional[Union[torch.device, str]] = None
 
771
 
772
  # Add alibi matrix to extended_attention_mask
773
  _, seqlen, _ = hidden_states.size()
774
+ if self._current_alibi_size < seqlen:
775
+ # Rebuild the alibi tensor when needed
776
+ warnings.warn(
777
+ f'Increasing alibi size from {self._current_alibi_size} to {seqlen}.'
778
+ )
779
+ self.register_buffer(
780
+ "alibi",
781
+ self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
782
+ hidden_states.dtype
783
+ ),
784
+ persistent=False,
785
+ )
786
+ elif self.alibi.device != hidden_states.device:
787
+ # Device catch-up
788
+ self.alibi = self.alibi.to(hidden_states.device)
789
+
790
+ alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
791
+
792
  if self.gradient_checkpointing and self.training:
793
  if use_cache:
794
  logger.warning_once(