Update modeling_bert.py
Browse files- modeling_bert.py +1 -22
modeling_bert.py
CHANGED
@@ -697,11 +697,6 @@ class JinaBertEncoder(nn.Module):
|
|
697 |
)
|
698 |
self.gradient_checkpointing = False
|
699 |
self.num_attention_heads = config.num_attention_heads
|
700 |
-
self.register_buffer(
|
701 |
-
"alibi",
|
702 |
-
self.rebuild_alibi_tensor(size=config.max_position_embeddings),
|
703 |
-
persistent=False,
|
704 |
-
)
|
705 |
|
706 |
def rebuild_alibi_tensor(
|
707 |
self, size: int, device: Optional[Union[torch.device, str]] = None
|
@@ -769,23 +764,7 @@ class JinaBertEncoder(nn.Module):
|
|
769 |
|
770 |
# Add alibi matrix to extended_attention_mask
|
771 |
_, seqlen, _ = hidden_states.size()
|
772 |
-
|
773 |
-
# Rebuild the alibi tensor when needed
|
774 |
-
warnings.warn(
|
775 |
-
f'Increasing alibi size from {self._current_alibi_size} to {seqlen}.'
|
776 |
-
)
|
777 |
-
self.register_buffer(
|
778 |
-
"alibi",
|
779 |
-
self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
|
780 |
-
hidden_states.dtype
|
781 |
-
),
|
782 |
-
persistent=False,
|
783 |
-
)
|
784 |
-
elif self.alibi.device != hidden_states.device:
|
785 |
-
# Device catch-up
|
786 |
-
self.alibi = self.alibi.to(hidden_states.device)
|
787 |
-
|
788 |
-
alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
|
789 |
if self.gradient_checkpointing and self.training:
|
790 |
if use_cache:
|
791 |
logger.warning_once(
|
|
|
697 |
)
|
698 |
self.gradient_checkpointing = False
|
699 |
self.num_attention_heads = config.num_attention_heads
|
|
|
|
|
|
|
|
|
|
|
700 |
|
701 |
def rebuild_alibi_tensor(
|
702 |
self, size: int, device: Optional[Union[torch.device, str]] = None
|
|
|
764 |
|
765 |
# Add alibi matrix to extended_attention_mask
|
766 |
_, seqlen, _ = hidden_states.size()
|
767 |
+
alibi_bias = self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(hidden_states.dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
768 |
if self.gradient_checkpointing and self.training:
|
769 |
if use_cache:
|
770 |
logger.warning_once(
|