alaeddine-13 committed on
Commit b4f2b16 · 1 parent: e36c994

rename to jina bert

Files changed (1)
  1. modeling_bert.py +75 -129
modeling_bert.py CHANGED
@@ -54,7 +54,7 @@ from transformers.utils import (
     logging,
     replace_return_docstrings,
 )
-from .configuration_bert import MyBertConfig
+from .configuration_bert import JinaBertConfig

 try:
     from tqdm.autonotebook import trange
@@ -66,7 +66,7 @@ except ImportError:
 logger = logging.get_logger(__name__)

 _CHECKPOINT_FOR_DOC = "bert-base-uncased"
-_CONFIG_FOR_DOC = "MyBertConfig"
+_CONFIG_FOR_DOC = "JinaBertConfig"

 # TokenClassification docstring
 _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = (
@@ -197,10 +197,10 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
     return model


-class MyBertEmbeddings(nn.Module):
+class JinaBertEmbeddings(nn.Module):
     """Construct the embeddings from word, position and token_type embeddings."""

-    def __init__(self, config: MyBertConfig):
+    def __init__(self, config: JinaBertConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(
             config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
@@ -280,7 +280,7 @@ class MyBertEmbeddings(nn.Module):
         return embeddings


-class MyBertSelfAttention(nn.Module):
+class JinaBertSelfAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
@@ -448,7 +448,7 @@ class MyBertSelfAttention(nn.Module):
         return outputs


-class MyBertSelfOutput(nn.Module):
+class JinaBertSelfOutput(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -464,13 +464,13 @@ class MyBertSelfOutput(nn.Module):
         return hidden_states


-class MyBertAttention(nn.Module):
+class JinaBertAttention(nn.Module):
     def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = MyBertSelfAttention(
+        self.self = JinaBertSelfAttention(
             config, position_embedding_type=position_embedding_type
         )
-        self.output = MyBertSelfOutput(config)
+        self.output = JinaBertSelfOutput(config)
         self.pruned_heads = set()

     def prune_heads(self, heads):
@@ -524,7 +524,7 @@ class MyBertAttention(nn.Module):
         return outputs


-class MyBertIntermediate(nn.Module):
+class JinaBertIntermediate(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -539,8 +539,8 @@ class MyBertIntermediate(nn.Module):
         return hidden_states


-class MyBertOutput(nn.Module):
-    def __init__(self, config: MyBertConfig):
+class JinaBertOutput(nn.Module):
+    def __init__(self, config: JinaBertConfig):
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -555,8 +555,8 @@ class MyBertOutput(nn.Module):
         return hidden_states


-class MyBertGLUMLP(nn.Module):
-    def __init__(self, config: MyBertConfig):
+class JinaBertGLUMLP(nn.Module):
+    def __init__(self, config: JinaBertConfig):
         super().__init__()
         self.config = config
         self.gated_layers = nn.Linear(
@@ -589,12 +589,12 @@ class MyBertGLUMLP(nn.Module):
         return hidden_states


-class MyBertLayer(nn.Module):
-    def __init__(self, config: MyBertConfig):
+class JinaBertLayer(nn.Module):
+    def __init__(self, config: JinaBertConfig):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.seq_len_dim = 1
-        self.attention = MyBertAttention(config)
+        self.attention = JinaBertAttention(config)
         self.is_decoder = config.is_decoder
         self.add_cross_attention = config.add_cross_attention
         self.feed_forward_type = config.feed_forward_type
@@ -603,14 +603,14 @@ class MyBertLayer(nn.Module):
                 raise ValueError(
                     f"{self} should be used as a decoder model if cross attention is added"
                 )
-            self.crossattention = MyBertAttention(
+            self.crossattention = JinaBertAttention(
                 config, position_embedding_type="absolute"
             )
         if self.feed_forward_type.endswith('glu'):
-            self.mlp = MyBertGLUMLP(config)
+            self.mlp = JinaBertGLUMLP(config)
         else:
-            self.intermediate = MyBertIntermediate(config)
-            self.output = MyBertOutput(config)
+            self.intermediate = JinaBertIntermediate(config)
+            self.output = JinaBertOutput(config)

     def forward(
         self,
@@ -699,12 +699,12 @@ class MyBertLayer(nn.Module):
         return layer_output


-class MyBertEncoder(nn.Module):
-    def __init__(self, config: MyBertConfig):
+class JinaBertEncoder(nn.Module):
+    def __init__(self, config: JinaBertConfig):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList(
-            [MyBertLayer(config) for _ in range(config.num_hidden_layers)]
+            [JinaBertLayer(config) for _ in range(config.num_hidden_layers)]
         )
         self.gradient_checkpointing = False
         self.num_attention_heads = config.num_attention_heads
@@ -724,26 +724,6 @@ class MyBertEncoder(nn.Module):
         # will be applied, it is necessary to construct the diagonal mask.
         n_heads = self.num_attention_heads

-        # Mosaics one
-        # def _get_alibi_head_slopes(n_heads: int) -> List[float]:
-        #     def get_slopes_power_of_2(n_heads: int) -> List[float]:
-        #         start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))
-        #         ratio = start
-        #         return [start * ratio**i for i in range(n_heads)]
-
-        #     # In the paper, they only train models that have 2^a heads for some a. This function
-        #     # has some good properties that only occur when the input is a power of 2. To
-        #     # maintain that even when the number of heads is not a power of 2, we use a
-        #     # workaround.
-        #     if math.log2(n_heads).is_integer():
-        #         return get_slopes_power_of_2(n_heads)
-
-        #     closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
-        #     slopes_a = get_slopes_power_of_2(closest_power_of_2)
-        #     slopes_b = _get_alibi_head_slopes(2 * closest_power_of_2)
-        #     slopes_b = slopes_b[0::2][: n_heads - closest_power_of_2]
-        #     return slopes_a + slopes_b
-
         def _get_alibi_head_slopes(n_heads: int) -> List[float]:
             def get_slopes_power_of_2(n):
                 start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -893,7 +873,7 @@ class MyBertEncoder(nn.Module):
         )


-class MyBertPooler(nn.Module):
+class JinaBertPooler(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -908,7 +888,7 @@ class MyBertPooler(nn.Module):
         return pooled_output


-class MyBertPredictionHeadTransform(nn.Module):
+class JinaBertPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -925,10 +905,10 @@ class MyBertPredictionHeadTransform(nn.Module):
         return hidden_states


-class MyBertLMPredictionHead(nn.Module):
+class JinaBertLMPredictionHead(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.transform = MyBertPredictionHeadTransform(config)
+        self.transform = JinaBertPredictionHeadTransform(config)

         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
@@ -945,17 +925,17 @@ class MyBertLMPredictionHead(nn.Module):
         return hidden_states


-class MyBertOnlyMLMHead(nn.Module):
+class JinaBertOnlyMLMHead(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.predictions = MyBertLMPredictionHead(config)
+        self.predictions = JinaBertLMPredictionHead(config)

     def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
         prediction_scores = self.predictions(sequence_output)
         return prediction_scores


-class MyBertOnlyNSPHead(nn.Module):
+class JinaBertOnlyNSPHead(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.seq_relationship = nn.Linear(config.hidden_size, 2)
@@ -965,10 +945,10 @@ class MyBertOnlyNSPHead(nn.Module):
         return seq_relationship_score


-class MyBertPreTrainingHeads(nn.Module):
+class JinaBertPreTrainingHeads(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.predictions = MyBertLMPredictionHead(config)
+        self.predictions = JinaBertLMPredictionHead(config)
         self.seq_relationship = nn.Linear(config.hidden_size, 2)

     def forward(self, sequence_output, pooled_output):
@@ -977,13 +957,13 @@ class MyBertPreTrainingHeads(nn.Module):
         return prediction_scores, seq_relationship_score


-class MyBertPreTrainedModel(PreTrainedModel):
+class JinaBertPreTrainedModel(PreTrainedModel):
     """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

-    config_class = MyBertConfig
+    config_class = JinaBertConfig
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
     supports_gradient_checkpointing = True
@@ -1005,12 +985,12 @@ class MyBertPreTrainedModel(PreTrainedModel):
             module.weight.data.fill_(1.0)

     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, MyBertEncoder):
+        if isinstance(module, JinaBertEncoder):
             module.gradient_checkpointing = value


 @dataclass
-class MyBertForPreTrainingOutput(ModelOutput):
+class JinaBertForPreTrainingOutput(ModelOutput):
     """
    Output type of [`BertForPreTraining`].

@@ -1113,7 +1093,7 @@ BERT_INPUTS_DOCSTRING = r"""
     "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
     BERT_START_DOCSTRING,
 )
-class MyBertModel(MyBertPreTrainedModel):
+class JinaBertModel(JinaBertPreTrainedModel):
     """

    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
@@ -1126,7 +1106,7 @@ class MyBertModel(MyBertPreTrainedModel):
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """

-    def __init__(self, config: MyBertConfig, add_pooling_layer=True):
+    def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
         super().__init__(config)
         self.config = config

@@ -1137,17 +1117,17 @@ class MyBertModel(MyBertPreTrainedModel):

         self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)

-        self.embeddings = MyBertEmbeddings(config)
-        self.encoder = MyBertEncoder(config)
+        self.embeddings = JinaBertEmbeddings(config)
+        self.encoder = JinaBertEncoder(config)

-        self.pooler = MyBertPooler(config) if add_pooling_layer else None
+        self.pooler = JinaBertPooler(config) if add_pooling_layer else None

         # Initialize weights and apply final processing
         self.post_init()

     @torch.inference_mode()
     def encode(
-        self: 'MyBertModel',
+        self: 'JinaBertModel',
         sentences: Union[str, List[str]],
         batch_size: int = 32,
         show_progress_bar: Optional[bool] = None,
@@ -1479,14 +1459,14 @@ class MyBertModel(MyBertPreTrainedModel):
    """,
     BERT_START_DOCSTRING,
 )
-class MyBertForPreTraining(MyBertPreTrainedModel):
+class JinaBertForPreTraining(JinaBertPreTrainedModel):
     _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
         super().__init__(config)

-        self.bert = MyBertModel(config)
-        self.cls = MyBertPreTrainingHeads(config)
+        self.bert = JinaBertModel(config)
+        self.cls = JinaBertPreTrainingHeads(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1501,7 +1481,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
         BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
     )
     @replace_return_docstrings(
-        output_type=MyBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
+        output_type=JinaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
     )
     def forward(
         self,
@@ -1516,7 +1496,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[Tuple[torch.Tensor], MyBertForPreTrainingOutput]:
+    ) -> Union[Tuple[torch.Tensor], JinaBertForPreTrainingOutput]:
         r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1532,22 +1512,6 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
            Used to hide legacy arguments that have been deprecated.

        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, MyBertForPreTraining
-        >>> import torch
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-        >>> model = MyBertForPreTraining.from_pretrained("bert-base-uncased")
-
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
-        >>> outputs = model(**inputs)
-
-        >>> prediction_logits = outputs.prediction_logits
-        >>> seq_relationship_logits = outputs.seq_relationship_logits
-        ```
        """
         return_dict = (
             return_dict if return_dict is not None else self.config.use_return_dict
@@ -1585,7 +1549,7 @@ class MyBertForPreTraining(MyBertPreTrainedModel):
             output = (prediction_scores, seq_relationship_score) + outputs[2:]
             return ((total_loss,) + output) if total_loss is not None else output

-        return MyBertForPreTrainingOutput(
+        return JinaBertForPreTrainingOutput(
             loss=total_loss,
             prediction_logits=prediction_scores,
             seq_relationship_logits=seq_relationship_score,
@@ -1595,10 +1559,10 @@ class MyBertForPreTraining(MyBertPreTrainedModel):


 @add_start_docstrings(
-    """MyBert Model with a `language modeling` head on top for CLM fine-tuning.""",
+    """JinaBert Model with a `language modeling` head on top for CLM fine-tuning.""",
     BERT_START_DOCSTRING,
 )
-class MyBertLMHeadModel(MyBertPreTrainedModel):
+class JinaBertLMHeadModel(JinaBertPreTrainedModel):
     _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
@@ -1606,11 +1570,11 @@ class MyBertLMHeadModel(MyBertPreTrainedModel):

         if not config.is_decoder:
             logger.warning(
-                "If you want to use `MyBertLMHeadModel` as a standalone, add `is_decoder=True.`"
+                "If you want to use `JinaBertLMHeadModel` as a standalone, add `is_decoder=True.`"
             )

-        self.bert = MyBertModel(config, add_pooling_layer=False)
-        self.cls = MyBertOnlyMLMHead(config)
+        self.bert = JinaBertModel(config, add_pooling_layer=False)
+        self.cls = JinaBertOnlyMLMHead(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1755,9 +1719,9 @@ class MyBertLMHeadModel(MyBertPreTrainedModel):


 @add_start_docstrings(
-    """MyBert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING
+    """JinaBert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING
 )
-class MyBertForMaskedLM(MyBertPreTrainedModel):
+class JinaBertForMaskedLM(JinaBertPreTrainedModel):
     _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

     def __init__(self, config):
@@ -1765,12 +1729,12 @@ class MyBertForMaskedLM(MyBertPreTrainedModel):

         if config.is_decoder:
             logger.warning(
-                "If you want to use `MyBertForMaskedLM` make sure `config.is_decoder=False` for "
+                "If you want to use `JinaBertForMaskedLM` make sure `config.is_decoder=False` for "
                 "bi-directional self-attention."
             )

-        self.bert = MyBertModel(config, add_pooling_layer=False)
-        self.cls = MyBertOnlyMLMHead(config)
+        self.bert = JinaBertModel(config, add_pooling_layer=False)
+        self.cls = JinaBertOnlyMLMHead(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1880,15 +1844,15 @@ class MyBertForMaskedLM(MyBertPreTrainedModel):


 @add_start_docstrings(
-    """MyBert Model with a `next sentence prediction (classification)` head on top.""",
+    """JinaBert Model with a `next sentence prediction (classification)` head on top.""",
     BERT_START_DOCSTRING,
 )
-class MyBertForNextSentencePrediction(MyBertPreTrainedModel):
+class JinaBertForNextSentencePrediction(JinaBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

-        self.bert = MyBertModel(config)
-        self.cls = MyBertOnlyNSPHead(config)
+        self.bert = JinaBertModel(config)
+        self.cls = JinaBertOnlyNSPHead(config)

         # Initialize weights and apply final processing
         self.post_init()
@@ -1922,24 +1886,6 @@ class MyBertForNextSentencePrediction(MyBertPreTrainedModel):
            - 1 indicates sequence B is a random sequence.

        Returns:
-
-        Example:
-
-        ```python
-        >>> from transformers import AutoTokenizer, MyBertForNextSentencePrediction
-        >>> import torch
-
-        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
-        >>> model = MyBertForNextSentencePrediction.from_pretrained("bert-base-uncased")
-
-        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
-        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
-        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
-
-        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
-        >>> logits = outputs.logits
-        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
-        ```
        """

         if "next_sentence_label" in kwargs:
@@ -1995,18 +1941,18 @@ class MyBertForNextSentencePrediction(MyBertPreTrainedModel):

 @add_start_docstrings(
     """
-    MyBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
+    JinaBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
     output) e.g. for GLUE tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class MyBertForSequenceClassification(MyBertPreTrainedModel):
+class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.config = config

-        self.bert = MyBertModel(config)
+        self.bert = JinaBertModel(config)
         classifier_dropout = (
             config.classifier_dropout
             if config.classifier_dropout is not None
@@ -2106,16 +2052,16 @@

 @add_start_docstrings(
     """
-    MyBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+    JinaBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
     softmax) e.g. for RocStories/SWAG tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class MyBertForMultipleChoice(MyBertPreTrainedModel):
+class JinaBertForMultipleChoice(JinaBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

-        self.bert = MyBertModel(config)
+        self.bert = JinaBertModel(config)
         classifier_dropout = (
             config.classifier_dropout
             if config.classifier_dropout is not None
@@ -2222,17 +2168,17 @@

 @add_start_docstrings(
     """
-    MyBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+    JinaBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
     Named-Entity-Recognition (NER) tasks.
     """,
     BERT_START_DOCSTRING,
 )
-class MyBertForTokenClassification(MyBertPreTrainedModel):
+class JinaBertForTokenClassification(JinaBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels

-        self.bert = MyBertModel(config, add_pooling_layer=False)
+        self.bert = JinaBertModel(config, add_pooling_layer=False)
         classifier_dropout = (
             config.classifier_dropout
             if config.classifier_dropout is not None
@@ -2311,17 +2257,17 @@

 @add_start_docstrings(
     """
-    MyBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    JinaBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
     layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
     """,
     BERT_START_DOCSTRING,
 )
-class MyBertForQuestionAnswering(MyBertPreTrainedModel):
+class JinaBertForQuestionAnswering(JinaBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels

-        self.bert = MyBertModel(config, add_pooling_layer=False)
+        self.bert = JinaBertModel(config, add_pooling_layer=False)
         self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

         # Initialize weights and apply final processing
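
Note on the `@@ -724,26 +724,6 @@` hunk: the commit deletes a commented-out "Mosaics" variant of the ALiBi slope helper while keeping `_get_alibi_head_slopes` itself. The standalone sketch below reproduces that slope computation as it appears in the deleted comment block; the top-level function name and the example head counts are illustrative, not part of the commit.

```python
# Sketch of the ALiBi head-slope computation from the deleted "Mosaics" comment block.
import math
from typing import List


def get_alibi_head_slopes(n_heads: int) -> List[float]:
    def get_slopes_power_of_2(n: int) -> List[float]:
        # start == 2 ** (-8 / n); each head multiplies by the same ratio,
        # so head i (1-indexed) gets slope 2 ** (-8 * i / n).
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    # ALiBi is defined for power-of-2 head counts; other counts reuse the closest
    # power of 2 and interleave slopes from the next power of 2 (the workaround
    # described in the deleted comment block).
    if math.log2(n_heads).is_integer():
        return get_slopes_power_of_2(n_heads)

    closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
    slopes_a = get_slopes_power_of_2(closest_power_of_2)
    slopes_b = get_alibi_head_slopes(2 * closest_power_of_2)[0::2][: n_heads - closest_power_of_2]
    return slopes_a + slopes_b


print(get_alibi_head_slopes(8))        # [0.5, 0.25, ..., 0.00390625], i.e. 2**-1 .. 2**-8
print(len(get_alibi_head_slopes(12)))  # 12 slopes for a non-power-of-2 head count
```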
 
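After this commit every public class in modeling_bert.py carries the JinaBert prefix (JinaBertModel, JinaBertForMaskedLM, and so on) and imports JinaBertConfig, so downstream code loads it as remote code through `transformers`. A minimal usage sketch, assuming a hypothetical repository id and that the repo's config maps AutoModel to JinaBertModel (neither is specified in this commit):

```python
# Usage sketch, not from the commit: "org/jina-bert-checkpoint" is a placeholder
# repository id, assumed to ship this modeling_bert.py, configuration_bert.py
# (JinaBertConfig), and an auto_map entry pointing AutoModel at JinaBertModel.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "org/jina-bert-checkpoint",  # hypothetical repo id
    trust_remote_code=True,      # lets transformers import JinaBertModel from the repo
)

# JinaBertModel.encode() is defined in the diff above; it tokenizes internally
# (the model holds its own AutoTokenizer) and runs under torch.inference_mode().
embeddings = model.encode(
    ["First sentence to embed.", "Second sentence to embed."],
    batch_size=32,
    show_progress_bar=False,
)
print(len(embeddings))
```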