bwang0911 commited on
Commit
2122141
1 Parent(s): 099544e

Update modeling_bert.py (#1)

Browse files

- Update modeling_bert.py (49a7d66af50f42d4b73950229a7874f86c0da0eb)

Files changed (1) hide show
  1. modeling_bert.py +1 -22
modeling_bert.py CHANGED
@@ -675,11 +675,6 @@ class JinaBertEncoder(nn.Module):
675
  )
676
  self.gradient_checkpointing = False
677
  self.num_attention_heads = config.num_attention_heads
678
- self.register_buffer(
679
- "alibi",
680
- self.rebuild_alibi_tensor(size=config.max_position_embeddings),
681
- persistent=False,
682
- )
683
 
684
  def rebuild_alibi_tensor(
685
  self, size: int, device: Optional[Union[torch.device, str]] = None
@@ -747,23 +742,7 @@ class JinaBertEncoder(nn.Module):
747
 
748
  # Add alibi matrix to extended_attention_mask
749
  _, seqlen, _ = hidden_states.size()
750
- if self._current_alibi_size < seqlen:
751
- # Rebuild the alibi tensor when needed
752
- warnings.warn(
753
- f'Increasing alibi size from {self._current_alibi_size} to {seqlen}.'
754
- )
755
- self.register_buffer(
756
- "alibi",
757
- self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
758
- hidden_states.dtype
759
- ),
760
- persistent=False,
761
- )
762
- elif self.alibi.device != hidden_states.device:
763
- # Device catch-up
764
- self.alibi = self.alibi.to(hidden_states.device)
765
-
766
- alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
767
  if self.gradient_checkpointing and self.training:
768
  if use_cache:
769
  logger.warning_once(
 
675
  )
676
  self.gradient_checkpointing = False
677
  self.num_attention_heads = config.num_attention_heads
 
 
 
 
 
678
 
679
  def rebuild_alibi_tensor(
680
  self, size: int, device: Optional[Union[torch.device, str]] = None
 
742
 
743
  # Add alibi matrix to extended_attention_mask
744
  _, seqlen, _ = hidden_states.size()
745
+ alibi_bias = self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(hidden_states.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
746
  if self.gradient_checkpointing and self.training:
747
  if use_cache:
748
  logger.warning_once(