Markus28 committed
Commit 87b642a
0 Parent(s):

initial commit

Files changed (2):
  1. configuration_bert.py +95 -0
  2. modeling_bert.py +760 -0
configuration_bert.py ADDED
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model configuration"""
from collections import OrderedDict
from typing import Mapping

from transformers import PretrainedConfig


class JinaBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BertModel`] or a [`TFBertModel`]. It is used to
    instantiate a BERT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the BERT
    [google-bert/bert-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`BertModel`] or [`TFBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or [`TFBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        window_size (`tuple`, *optional*, defaults to `(-1, -1)`):
            If not the default `(-1, -1)`, use local (sliding-window) attention with this window size.
    """

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        window_size=(-1, -1),
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.window_size = window_size
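A minimal usage sketch for the configuration above (the local-attention window values below are illustrative examples, not defaults from this commit):

from configuration_bert import JinaBertConfig

config = JinaBertConfig()  # BERT-base sizes; window_size=(-1, -1) means no local attention
local_config = JinaBertConfig(window_size=(256, 256))  # hypothetical sliding-window variant
print(local_config.window_size)  # (256, 256)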
modeling_bert.py ADDED
# Copyright (c) 2022, Tri Dao.
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py

# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py

import logging
import re
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PretrainedConfig
from transformers.models.bert.modeling_bert import (
    BaseModelOutputWithPoolingAndCrossAttentions,
    BertForPreTrainingOutput,
)

from configuration_bert import JinaBertConfig

from flash_attn.bert_padding import (
    index_first_axis,
    index_first_axis_residual,
    pad_input,
    unpad_input,
)
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from flash_attn.utils.pretrained import state_dict_from_pretrained

try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
    FusedDense = None

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn
except ImportError:
    layer_norm_fn = None

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss
except ImportError:
    CrossEntropyLoss = None


logger = logging.getLogger(__name__)

def create_mixer_cls(config, cross_attn=False, return_residual=False):
    fused_bias_fc = getattr(config, "fused_bias_fc", False)
    window_size = getattr(config, "window_size", (-1, -1))
    mixer_cls = partial(
        MHA,
        num_heads=config.num_attention_heads,
        cross_attn=cross_attn,
        dropout=config.attention_probs_dropout_prob,
        causal=False,
        fused_bias_fc=fused_bias_fc,
        use_flash_attn=True,
        return_residual=return_residual,
        use_alibi=True,
        window_size=window_size,
    )
    return mixer_cls


def create_mlp_cls(config, layer_idx=None, return_residual=False):
    inner_dim = config.intermediate_size
    fused_mlp = getattr(config, "fused_mlp", False)
    if fused_mlp:
        assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
            "fused_mlp only supports approximate gelu"
        )
    if not fused_mlp:
        approximate = (
            "tanh"
            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
            else "none"
        )
        mlp_cls = partial(
            Mlp,
            hidden_features=inner_dim,
            activation=partial(F.gelu, approximate=approximate),
            return_residual=return_residual,
        )
    else:
        if FusedMLP is None:
            raise ImportError("fused_dense is not installed")
        mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        mlp_cls = partial(
            FusedMLP,
            hidden_features=inner_dim,
            checkpoint_lvl=mlp_checkpoint_lvl,
            return_residual=return_residual,
        )
    return mlp_cls


def create_block(config, layer_idx=None):
    last_layer_subset = getattr(config, "last_layer_subset", False)
    cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
    # one layer) so we just choose not to return residual in this case.
    return_residual = not cross_attn
    mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
    mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
    block = Block(
        config.hidden_size,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        prenorm=False,
        resid_dropout1=config.hidden_dropout_prob,
        resid_dropout2=config.hidden_dropout_prob,
        fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
        return_residual=return_residual,
    )
    return block


# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
def _init_weights(module, initializer_range=0.02):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.padding_idx is not None:
            nn.init.zeros_(module.weight[module.padding_idx])

class BertEncoder(nn.Module):
    def __init__(self, config: JinaBertConfig):
        super().__init__()
        self.use_flash_attn = getattr(config, "use_flash_attn", False)
        self.layers = nn.ModuleList(
            [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
        """If subset_mask is not None, we only want output for the subset of the sequence.
        This means that we only compute the last layer output for these tokens.
        subset_mask: (batch, seqlen), dtype=torch.bool
        """
        if key_padding_mask is None or not self.use_flash_attn:
            mixer_kwargs = (
                {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
            )
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            if subset_mask is not None:
                hidden_states = hidden_states[subset_mask]
        else:
            batch, seqlen = hidden_states.shape[:2]
            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                hidden_states, key_padding_mask
            )
            mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
            if subset_mask is None:
                for layer in self.layers:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
            else:
                for layer in self.layers[:-1]:
                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
                if key_padding_mask is not None:
                    subset_idx = torch.nonzero(
                        subset_mask[key_padding_mask], as_tuple=False
                    ).flatten()
                    subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(
                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                    )
                else:
                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
                    subset_cu_seqlens = F.pad(
                        torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0)
                    )
                hidden_states_subset, hidden_states = index_first_axis_residual(
                    hidden_states, subset_idx
                )
                # It's ok to set max_seqlen_q to be much larger
                mixer_kwargs = {
                    "x_kv": hidden_states,
                    "cu_seqlens": subset_cu_seqlens,
                    "max_seqlen": max_seqlen_in_batch,
                    "cu_seqlens_k": cu_seqlens,
                    "max_seqlen_k": max_seqlen_in_batch,
                }
                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
        return hidden_states

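# The flash-attention path above packs all non-padded tokens into one long sequence
# ("varlen" layout) and recovers sequence boundaries from cumulative lengths. The helper
# below is an illustrative sketch only (it is not used by this module); it mirrors the
# dtype and padding conventions of the encoder code above.
def _example_cu_seqlens(key_padding_mask: torch.Tensor) -> torch.Tensor:
    # key_padding_mask: (batch, seqlen) bool, True for real tokens, e.g.
    # [[1, 1, 1, 0], [1, 1, 0, 0]] -> seqlens [3, 2] -> cu_seqlens [0, 3, 5]
    seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)
    return F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
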
class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0] if pool else hidden_states
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        approximate = (
            "tanh"
            if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
            else "none"
        )
        self.transform_act_fn = nn.GELU(approximate=approximate)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        if not self.fused_dropout_add_ln:
            hidden_states = self.layer_norm(hidden_states)
        else:
            hidden_states = layer_norm_fn(
                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
            )
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, "fused_bias_fc", False)
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense

        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertPreTrainingHeads(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score

class BertPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, JinaBertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `JinaBertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            model_name: either:
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `model.chkpt` a TensorFlow checkpoint
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        load_return = model.load_state_dict(
            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
        )
        logger.info(load_return)
        return model

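# Loading sketch: from_pretrained above builds the model from a config and then loads a
# Hugging Face checkpoint through remap_state_dict (defined at the bottom of this file),
# with strict=False so keys that have no counterpart here (e.g. position embeddings) are
# simply skipped. Illustrative only; "bert-base-uncased" is an example model id, not one
# prescribed by this commit, and this helper is not used by the module itself.
def _example_load_remapped(model_name: str = "bert-base-uncased"):
    config = JinaBertConfig()
    return BertModel.from_pretrained(model_name, config)
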
class BertModel(BertPreTrainedModel):
    def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += self.pad_vocab_size_multiple - (
                config.vocab_size % self.pad_vocab_size_multiple
            )
        self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
        if self.fused_dropout_add_ln and layer_norm_fn is None:
            raise ImportError("Triton is not installed")
        assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]

        self.embeddings = BertEmbeddings(
            config.hidden_size,
            config.vocab_size,
            -1,  # No position embeddings
            config.type_vocab_size,
            padding_idx=config.pad_token_id,
        )
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        masked_tokens_mask=None,
    ):
        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
        we only want the output for the masked tokens. This means that we only compute the last
        layer output for these tokens.
        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
        """
        hidden_states = self.embeddings(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids
        )
        # TD [2022-12-18]: Don't need to force residual in fp32
        # BERT puts embedding LayerNorm before embedding dropout.
        if not self.fused_dropout_add_ln:
            hidden_states = self.emb_ln(hidden_states)
        else:
            hidden_states = layer_norm_fn(
                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
            )
        hidden_states = self.emb_drop(hidden_states)

        if masked_tokens_mask is not None:
            batch_size, seqlen = input_ids.shape[:2]
            # We also need the first column for the CLS token
            first_col_mask = torch.zeros(
                batch_size, seqlen, dtype=torch.bool, device=input_ids.device
            )
            first_col_mask[:, 0] = True
            subset_mask = masked_tokens_mask | first_col_mask
        else:
            subset_mask = None

        sequence_output = self.encoder(
            hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
        )

        if masked_tokens_mask is None:
            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        else:
            # TD [2022-03-01]: the indexing here is very tricky.
            if attention_mask is not None:
                subset_idx = subset_mask[attention_mask]
                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
            else:
                pool_input = sequence_output[first_col_mask[subset_mask]]
                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
            pooled_output = self.pooler(pool_input, pool=False) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )

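# The masked_tokens_mask handling above always adds the first (CLS) column back in, so the
# pooler still receives an input even when only masked positions are kept for the last
# layer. A standalone illustration of that mask construction (not used by the model itself;
# `masked_tokens_mask` is assumed to be a bool tensor of shape (batch, seqlen)):
def _example_subset_mask(masked_tokens_mask: torch.Tensor) -> torch.Tensor:
    first_col_mask = torch.zeros_like(masked_tokens_mask)
    first_col_mask[:, 0] = True
    return masked_tokens_mask | first_col_mask
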
class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config: JinaBertConfig):
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
        # (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, "dense_seq_output", False)
        # If last_layer_subset, we only need to compute the last layer for a subset of tokens
        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
        self.last_layer_subset = getattr(config, "last_layer_subset", False)
        if self.last_layer_subset:
            assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
        use_xentropy = getattr(config, "use_xentropy", False)
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError("xentropy_cuda is not installed")
        loss_cls = (
            nn.CrossEntropyLoss
            if not use_xentropy
            else partial(CrossEntropyLoss, inplace_backward=True)
        )

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        self.mlm_loss = loss_cls(ignore_index=0)
        self.nsp_loss = loss_cls(ignore_index=-1)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        next_sentence_label=None,
    ):
        """
        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
        mask).
        Outputs:
            if `labels` and `next_sentence_label` are not `None`:
                Outputs the total_loss which is the sum of the masked language modeling loss and the next
                sentence classification loss.
            if `labels` or `next_sentence_label` is `None`:
                Outputs a tuple comprising
                - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
                - the next sentence classification logits of shape [batch_size, 2].
        """
        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            masked_tokens_mask=masked_tokens_mask,
        )
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
        if self.dense_seq_output and labels is not None:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            if not self.last_layer_subset:
                sequence_output = index_first_axis(
                    rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
                )
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            if (
                self.dense_seq_output and labels is not None
            ):  # prediction_scores are already flattened
                masked_lm_loss = self.mlm_loss(
                    prediction_scores, labels.flatten()[masked_token_idx]
                )
            else:
                masked_lm_loss = self.mlm_loss(
                    rearrange(prediction_scores, "... v -> (...) v"),
                    rearrange(labels, "... -> (...)"),
                )
            next_sentence_loss = self.nsp_loss(
                rearrange(seq_relationship_score, "... t -> (...) t"),
                rearrange(next_sentence_label, "... -> (...)"),
            )
            total_loss = masked_lm_loss.float() + next_sentence_loss.float()

        return BertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
        )

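# The MLM loss above uses ignore_index=0, so `labels` are expected to carry the original
# token id at masked positions and 0 everywhere else (not the -100 convention common in
# transformers). An illustrative way to build such labels; `mlm_mask` (a bool tensor marking
# the positions that were masked in `input_ids`) is an assumed input, not something this
# module produces:
def _example_mlm_labels(original_ids: torch.Tensor, mlm_mask: torch.Tensor) -> torch.Tensor:
    labels = torch.zeros_like(original_ids)
    labels[mlm_mask] = original_ids[mlm_mask]  # only these positions contribute to the loss
    return labels
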
def remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
    """

    # LayerNorm
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())

    # Layers
    def key_mapping_layers(key):
        return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
        key = re.sub(
            r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
            r"bert.encoder.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
            r"bert.encoder.layers.\1.norm2.\2",
            key,
        )
        key = re.sub(
            r"^cls.predictions.transform.LayerNorm.(weight|bias)",
            r"cls.predictions.transform.layer_norm.\1",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    last_layer_subset = getattr(config, "last_layer_subset", False)
    for d in range(config.num_hidden_layers):
        Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
        Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
        Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
        bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
        bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
        bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
        if not (last_layer_subset and d == config.num_hidden_layers - 1):
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                [Wq, Wk, Wv], dim=0
            )
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
        else:
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mixer.out_proj.\2",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    def key_mapping_decoder_bias(key):
        return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)

    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        state_dict["cls.predictions.decoder.weight"] = F.pad(
            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
        )
        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
        # strongly negative (i.e. the decoder shouldn't predict those indices).
        # TD [2022-05-09]: I don't think it affects the MLPerf training.
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        state_dict["cls.predictions.decoder.bias"] = F.pad(
            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
        )

    return state_dict

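# remap_state_dict fuses the separate Q/K/V projections of a Hugging Face checkpoint into
# the single Wqkv matrix that flash-attn's MHA expects. A shape-only illustration with
# random tensors (hidden size 4 and a single layer are arbitrary choices for the sketch;
# this helper is not used by the module):
def _example_remap_shapes():
    hf_style = OrderedDict(
        {
            "bert.encoder.layer.0.attention.self.query.weight": torch.randn(4, 4),
            "bert.encoder.layer.0.attention.self.key.weight": torch.randn(4, 4),
            "bert.encoder.layer.0.attention.self.value.weight": torch.randn(4, 4),
            "bert.encoder.layer.0.attention.self.query.bias": torch.randn(4),
            "bert.encoder.layer.0.attention.self.key.bias": torch.randn(4),
            "bert.encoder.layer.0.attention.self.value.bias": torch.randn(4),
        }
    )
    remapped = remap_state_dict(hf_style, JinaBertConfig(hidden_size=4, num_hidden_layers=1))
    return remapped["bert.encoder.layers.0.mixer.Wqkv.weight"].shape  # torch.Size([12, 4])
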
def inv_remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.

    This function is meant to be the inverse of remap_state_dict.
    """
    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        # unpad embeddings
        state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
            : config.orig_vocab_size, :
        ]
        state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
        state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]

    for d in range(config.num_hidden_layers):
        last_layer_subset = getattr(config, "last_layer_subset", False)
        if not last_layer_subset or d != (config.num_hidden_layers - 1):
            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
                : Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
                Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
                2 * Wqkv_weights.shape[0] // 3 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
                : Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
                2 * Wqkv_biases.shape[0] // 3 :
            ]
        else:
            Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
            Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
            Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
            Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
                : Wkv_weights.shape[0] // 2, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
                Wkv_weights.shape[0] // 2 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                : Wkv_biases.shape[0] // 2
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
                Wkv_biases.shape[0] // 2 :
            ]

    def inv_key_mapping_ln(key):
        key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
            r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
            r"bert.encoder.layers.\1.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"cls.predictions.transform.layer_norm.(weight|bias)",
            r"cls.predictions.transform.LayerNorm.\1",
            key,
        )
        return key

    def inv_key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
        key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
        return key

    def inv_key_mapping_layers(key):
        return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)

    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
            r"bert.encoder.layer.\1.intermediate.dense.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
            r"bert.encoder.layer.\1.output.dense.\2",
            key,
        )
        return key

    def inv_key_mapping_attn(key):
        return re.sub(
            r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
            r"bert.encoder.layer.\1.attention.output.dense.\2",
            key,
        )

    def inv_key_mapping_decoder_bias(key):
        return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)

    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
    )

    return state_dict
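

# inv_remap_state_dict is intended as the exact inverse of remap_state_dict, e.g. for
# exporting a trained model back into the Hugging Face key layout. A usage sketch
# (illustrative; `model` is assumed to be a BertForPreTraining instance, since the key
# names handled above all carry the "bert." prefix):
def _example_export_hf_style(model: "BertForPreTraining"):
    hf_state_dict = inv_remap_state_dict(model.state_dict(), model.config)
    # Keys now follow the Hugging Face BERT naming again, e.g.
    # "bert.encoder.layer.0.attention.self.query.weight".
    return hf_state_dict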