isacat committed
Commit f669876
1 Parent(s): 8584efc

Update modeling_bert.py

Files changed (1)
  1. modeling_bert.py +1 -22
modeling_bert.py CHANGED
@@ -697,11 +697,6 @@ class JinaBertEncoder(nn.Module):
         )
         self.gradient_checkpointing = False
         self.num_attention_heads = config.num_attention_heads
-        self.register_buffer(
-            "alibi",
-            self.rebuild_alibi_tensor(size=config.max_position_embeddings),
-            persistent=False,
-        )
 
     def rebuild_alibi_tensor(
         self, size: int, device: Optional[Union[torch.device, str]] = None
@@ -769,23 +764,7 @@ class JinaBertEncoder(nn.Module):
 
         # Add alibi matrix to extended_attention_mask
         _, seqlen, _ = hidden_states.size()
-        if self._current_alibi_size < seqlen:
-            # Rebuild the alibi tensor when needed
-            warnings.warn(
-                f'Increasing alibi size from {self._current_alibi_size} to {seqlen}.'
-            )
-            self.register_buffer(
-                "alibi",
-                self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
-                    hidden_states.dtype
-                ),
-                persistent=False,
-            )
-        elif self.alibi.device != hidden_states.device:
-            # Device catch-up
-            self.alibi = self.alibi.to(hidden_states.device)
-
-        alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
+        alibi_bias = self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(hidden_states.dtype)
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
 
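For context: the removed buffer cached a full (1, num_heads, max_position_embeddings, max_position_embeddings) ALiBi bias on the module, along with resize warnings and device catch-up logic; after this commit the bias is rebuilt at exactly seqlen on every forward pass. Below is a minimal sketch of how such a bias tensor is typically constructed, assuming the standard ALiBi slope schedule and the symmetric |i - j| distance used by encoder-style models; build_alibi_bias and _alibi_slopes are hypothetical names for illustration, not this repo's API.

import torch

def _alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric slope schedule from the ALiBi paper; this simple form is
    # exact when num_heads is a power of two.
    start = 2.0 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def build_alibi_bias(num_heads: int, size: int, device=None) -> torch.Tensor:
    # Symmetric distance matrix |i - j|, suitable for a bidirectional encoder.
    pos = torch.arange(size, device=device)
    distance = (pos[None, :] - pos[:, None]).abs()        # (size, size)
    slopes = _alibi_slopes(num_heads).to(device)          # (num_heads,)
    bias = -slopes[:, None, None] * distance[None, :, :]  # (num_heads, size, size)
    # Leading batch dim so the bias broadcasts against attention scores of
    # shape (batch, num_heads, seqlen, seqlen).
    return bias.unsqueeze(0)

The trade-off matches the diff: recomputing a seqlen-sized bias on each forward costs a little extra compute, but avoids keeping a max_position_embeddings-sized buffer (and its resize/device bookkeeping) resident on the encoder.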