Update modeling_bert.py
modeling_bert.py CHANGED (+23 -1)
@@ -699,6 +699,11 @@ class JinaBertEncoder(nn.Module):
         )
         self.gradient_checkpointing = False
         self.num_attention_heads = config.num_attention_heads
+        self.register_buffer(
+            "alibi",
+            self.rebuild_alibi_tensor(size=config.max_position_embeddings),
+            persistent=False,
+        )
 
     def rebuild_alibi_tensor(
         self, size: int, device: Optional[Union[torch.device, str]] = None
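This hunk precomputes the ALiBi attention bias once at construction time, sized to config.max_position_embeddings; persistent=False keeps the buffer out of state_dict, so checkpoints do not grow. The body of rebuild_alibi_tensor is not part of this diff (it presumably also updates self._current_alibi_size, which the forward pass checks in the next hunk). As a hypothetical sketch only, assuming the standard bidirectional ALiBi formulation (per-head slopes times a |i - j| distance matrix) and the (1, num_heads, size, size) shape implied by the later self.alibi[:, :, :seqlen, :seqlen] slice:

# Hypothetical sketch of an ALiBi builder; the real rebuild_alibi_tensor body
# is not shown in this diff. Standard ALiBi: head h gets slope 2**(-8*h/n),
# and the bias at positions (i, j) is -slope * |i - j|.
import torch


def build_alibi_sketch(size: int, num_attention_heads: int) -> torch.Tensor:
    n = num_attention_heads
    # Geometric slope schedule from the ALiBi paper (exact when n is a power of 2)
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n) for h in range(n)])
    positions = torch.arange(size)
    # |i - j| distance matrix; symmetric, which suits a bidirectional encoder
    distances = (positions[None, :] - positions[:, None]).abs()
    # Scale per head, then add a leading batch dim so the result broadcasts
    # like the later slice self.alibi[:, :, :seqlen, :seqlen]
    bias = -slopes[:, None, None] * distances[None, :, :]
    return bias.unsqueeze(0)

Since the forward pass slices the buffer down to the actual sequence length, building it once at the maximum size trades a little memory for rebuild-free steps on typical inputs.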
@@ -766,7 +771,24 @@
 
         # Add alibi matrix to extended_attention_mask
         _, seqlen, _ = hidden_states.size()
-
+        if self._current_alibi_size < seqlen:
+            # Rebuild the alibi tensor when needed
+            warnings.warn(
+                f'Increasing alibi size from {self._current_alibi_size} to {seqlen}.'
+            )
+            self.register_buffer(
+                "alibi",
+                self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
+                    hidden_states.dtype
+                ),
+                persistent=False,
+            )
+        elif self.alibi.device != hidden_states.device:
+            # Device catch-up
+            self.alibi = self.alibi.to(hidden_states.device)
+
+        alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
+
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
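Taken together, the second hunk makes the buffer grow on demand: if a batch arrives that is longer than _current_alibi_size, the buffer is re-registered at the larger size (with a warning), and when only the device differs a cheap .to() moves the existing tensor instead of recomputing it. Below is a minimal, self-contained sketch of the same pattern; the bias math is a simplified stand-in rather than the model's own code, while names like _current_alibi_size mirror the diff:

# Self-contained sketch of the grow-on-demand, non-persistent buffer pattern
# added by the hunk above. Re-registering an existing buffer name simply
# replaces the stored tensor in PyTorch.
import warnings

import torch
from torch import nn


class AlibiCache(nn.Module):
    def __init__(self, num_heads: int, max_positions: int):
        super().__init__()
        self.num_heads = num_heads
        self._current_alibi_size = 0
        # persistent=False keeps the (potentially large) bias out of state_dict
        self.register_buffer("alibi", self._build(max_positions), persistent=False)

    def _build(self, size: int, device=None) -> torch.Tensor:
        # Simplified stand-in for rebuild_alibi_tensor
        heads = torch.arange(1, self.num_heads + 1, device=device)
        slopes = 2.0 ** (-8.0 * heads / self.num_heads)
        positions = torch.arange(size, device=device)
        distances = (positions[None, :] - positions[:, None]).abs()
        self._current_alibi_size = size
        return (-slopes[:, None, None] * distances[None, :, :]).unsqueeze(0)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        _, seqlen, _ = hidden_states.size()
        if self._current_alibi_size < seqlen:
            # Longer input than ever seen: rebuild at the new size
            warnings.warn(
                f'Increasing alibi size from {self._current_alibi_size} to {seqlen}.'
            )
            self.register_buffer(
                "alibi",
                self._build(seqlen, device=hidden_states.device).to(hidden_states.dtype),
                persistent=False,
            )
        elif self.alibi.device != hidden_states.device:
            # Device catch-up only; no recompute needed
            self.alibi = self.alibi.to(hidden_states.device)
        return self.alibi[:, :, :seqlen, :seqlen]


cache = AlibiCache(num_heads=12, max_positions=512)
bias = cache(torch.randn(2, 1024, 768))  # warns, rebuilds; shape (1, 12, 1024, 1024)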