Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn commited on
Commit
3e386a9
·
1 Parent(s): e509b16

Upload NomicBertForPreTraining

Browse files
config.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "swiglu",
3
+ "architectures": [
4
+ "NomicBertForPreTraining"
5
+ ],
6
+ "attn_pdrop": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_nomic_bert.NomicBertConfig",
9
+ "AutoModelForMaskedLM": "modeling_hf_nomic_bert.NomicBertForPreTraining"
10
+ },
11
+ "bos_token_id": null,
12
+ "causal": false,
13
+ "dense_seq_output": true,
14
+ "embd_pdrop": 0.1,
15
+ "eos_token_id": null,
16
+ "fused_bias_fc": true,
17
+ "fused_dropout_add_ln": true,
18
+ "initializer_range": 0.02,
19
+ "layer_norm_epsilon": 1e-12,
20
+ "mlp_fc1_bias": false,
21
+ "mlp_fc2_bias": false,
22
+ "model_type": "nomic_bert",
23
+ "n_embd": 768,
24
+ "n_head": 12,
25
+ "n_inner": 3072,
26
+ "n_layer": 12,
27
+ "n_positions": 0,
28
+ "pad_vocab_size_multiple": 64,
29
+ "parallel_block": false,
30
+ "parallel_block_tied_norm": false,
31
+ "prenorm": false,
32
+ "qkv_proj_bias": false,
33
+ "reorder_and_upcast_attn": false,
34
+ "resid_pdrop": 0.1,
35
+ "rotary_emb_base": 1000,
36
+ "rotary_emb_fraction": 1.0,
37
+ "rotary_emb_interleaved": false,
38
+ "rotary_emb_scale_base": null,
39
+ "scale_attn_by_inverse_layer_idx": false,
40
+ "scale_attn_weights": true,
41
+ "summary_activation": null,
42
+ "summary_first_dropout": 0.1,
43
+ "summary_proj_to_labels": true,
44
+ "summary_type": "cls_index",
45
+ "summary_use_proj": true,
46
+ "torch_dtype": "float32",
47
+ "transformers_version": "4.34.0",
48
+ "type_vocab_size": 2,
49
+ "use_cache": true,
50
+ "use_flash_attn": true,
51
+ "use_rms_norm": false,
52
+ "use_xentropy": true,
53
+ "vocab_size": 30528
54
+ }
configuration_nomic_bert.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Config
2
+
3
+
4
+ class NomicBertConfig(GPT2Config):
5
+ model_type = "nomic_bert"
6
+
7
+ def __init__(self,
8
+ prenorm=False,
9
+ parallel_block=False,
10
+ parallel_block_tied_norm=False,
11
+ rotary_emb_fraction=0.0,
12
+ fused_dropout_add_ln=False,
13
+ fused_bias_fc=False,
14
+ use_flash_attn=False,
15
+ use_xentropy=False,
16
+ qkv_proj_bias=True,
17
+ rotary_emb_base=1000,
18
+ rotary_emb_scale_base=None,
19
+ rotary_emb_interleaved=False,
20
+ mlp_fc1_bias=True,
21
+ mlp_fc2_bias=True,
22
+ use_rms_norm=False,
23
+ causal=False,
24
+ type_vocab_size=2,
25
+ dense_seq_output=True,
26
+ pad_vocab_size_multiple=1,
27
+ tie_word_embeddings=True,
28
+ **kwargs,
29
+ ):
30
+ self.prenorm = prenorm
31
+ self.parallel_block = parallel_block
32
+ self.parallel_block_tied_norm = parallel_block_tied_norm
33
+ self.rotary_emb_fraction = rotary_emb_fraction
34
+ self.tie_word_embeddings = tie_word_embeddings
35
+ self.fused_dropout_add_ln = fused_dropout_add_ln
36
+ self.fused_bias_fc = fused_bias_fc
37
+ self.use_flash_attn = use_flash_attn
38
+ self.use_xentropy = use_xentropy
39
+ self.qkv_proj_bias = qkv_proj_bias
40
+ self.rotary_emb_base = rotary_emb_base
41
+ self.rotary_emb_scale_base = rotary_emb_scale_base
42
+ self.rotary_emb_interleaved = rotary_emb_interleaved
43
+ self.mlp_fc1_bias = mlp_fc1_bias
44
+ self.mlp_fc2_bias = mlp_fc2_bias
45
+ self.use_rms_norm = use_rms_norm
46
+ self.causal = causal
47
+ self.type_vocab_size = type_vocab_size
48
+ self.dense_seq_output = dense_seq_output
49
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
50
+
51
+ super().__init__(**kwargs)
modeling_hf_nomic_bert.py ADDED
@@ -0,0 +1,881 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
+
6
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
7
+
8
+ import os
9
+ import logging
10
+ from functools import partial
11
+ from typing import Optional, List, Tuple, Union
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from einops import rearrange, repeat
17
+ from transformers import GPT2Config, PreTrainedModel
18
+ from transformers.models.bert.modeling_bert import (
19
+ BaseModelOutputWithPoolingAndCrossAttentions,
20
+ BertForPreTrainingOutput,
21
+ SequenceClassifierOutput
22
+ )
23
+
24
+ from contrastors.models.encoder.configuration_nomic_bert import NomicBertConfig
25
+ from contrastors.models.model_utils import state_dict_from_pretrained, filter_shapes
26
+ from contrastors.models.encoder.bert import remap_bert_state_dict
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class NomicBertPreTrainedModel(PreTrainedModel):
32
+ """An abstract class to handle weights initialization and
33
+ a simple interface for dowloading and loading pretrained models.
34
+ """
35
+ config_class = NomicBertConfig
36
+ base_model_prefix = "model"
37
+ supports_gradient_checkpointing = True
38
+ _no_split_modules = ["Block"]
39
+ _skip_keys_device_placement = "past_key_values"
40
+
41
+ def __init__(self, config, *inputs, **kwargs):
42
+ super().__init__(config)
43
+ if not isinstance(config, GPT2Config):
44
+ raise ValueError(
45
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
46
+ "To create a model from a Google pretrained model use "
47
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
48
+ self.__class__.__name__, self.__class__.__name__
49
+ )
50
+ )
51
+ self.config = config
52
+
53
+ @classmethod
54
+ def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
55
+ """
56
+ Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
57
+ Download and cache the pre-trained model file if needed.
58
+
59
+ Params:
60
+ pretrained_model_name_or_path: either:
61
+ - a path or url to a pretrained model archive containing:
62
+ . `bert_config.json` a configuration file for the model
63
+ . `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance
64
+ - a path or url to a pretrained model archive containing:
65
+ . `bert_config.json` a configuration file for the model
66
+ . `model.chkpt` a TensorFlow checkpoint
67
+ *inputs, **kwargs: additional input for the specific NomicBert class
68
+ (ex: num_labels for NomicBertForSequenceClassification)
69
+ """
70
+ # Instantiate model.
71
+ if config is None:
72
+ config = cls.config_class.from_pretrained(model_name)
73
+ remove_cls = cls != NomicBertForPreTraining
74
+ remove_bert_prefix = cls != NomicBertForPreTraining
75
+ ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
76
+ model = cls(config, *inputs, **kwargs)
77
+ # TODO: fix this
78
+ # Assuming we know what we're doing when loading from disk
79
+ # Prob a bad assumption but i'm tired and want to train this asap
80
+ if os.path.exists(model_name):
81
+ state_dict = torch.load(f"{model_name}/pytorch_model.bin")
82
+ if ignore_mismatched_shapes:
83
+ state_dict = filter_shapes(state_dict, model)
84
+ load_return = model.load_state_dict(state_dict, strict=False)
85
+ else:
86
+ # TODO: can probably check config class and see if we need to remap from a bert model
87
+ state_dict = state_dict_from_pretrained(model_name)
88
+ state_dict = remap_bert_state_dict(state_dict,
89
+ config,
90
+ remove_bert=remove_bert_prefix,
91
+ remove_cls_weights=remove_cls,
92
+ add_pooling_layer=getattr(config, "add_pooling_layer", False)
93
+ )
94
+ if ignore_mismatched_shapes:
95
+ state_dict = filter_shapes(state_dict, model)
96
+
97
+ load_return = model.load_state_dict(
98
+ state_dict,
99
+ strict=True
100
+ )
101
+ logger.info(load_return)
102
+ return model
103
+
104
+ def _set_gradient_checkpointing(self, module, value=False):
105
+ if isinstance(module, NomicBertEncoder):
106
+ module.gradient_checkpointing = value
107
+
108
+
109
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
110
+ def _init_weights(module, initializer_range=0.02):
111
+ if isinstance(module, nn.Linear):
112
+ nn.init.normal_(module.weight, std=initializer_range)
113
+ if module.bias is not None:
114
+ nn.init.zeros_(module.bias)
115
+ elif isinstance(module, nn.Embedding):
116
+ nn.init.normal_(module.weight, std=initializer_range)
117
+ if module.padding_idx is not None:
118
+ nn.init.zeros_(module.weight[module.padding_idx])
119
+
120
+
121
+ class NomicBertEmbeddings(nn.Module):
122
+ def __init__(
123
+ self,
124
+ config
125
+ ):
126
+ """
127
+ If max_position_embeddings <= 0, there's no position embeddings
128
+ If type_vocab_size <= 0, there's no token type embeddings
129
+ """
130
+ super().__init__()
131
+ self.word_embeddings = nn.Embedding(
132
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
133
+ )
134
+ self.max_position_embeddings = config.max_position_embeddings
135
+ self.type_vocab_size = config.type_vocab_size
136
+ if self.max_position_embeddings > 0:
137
+ self.position_embeddings = nn.Embedding(
138
+ config.max_position_embeddings, config.hidden_size,
139
+ )
140
+ if self.type_vocab_size > 0:
141
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
142
+
143
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
144
+ """
145
+ input_ids: (batch, seqlen)
146
+ position_ids: (batch, seqlen)
147
+ token_type_ids: (batch, seqlen)
148
+ """
149
+ batch_size, seqlen = input_ids.shape
150
+ embeddings = self.word_embeddings(input_ids)
151
+
152
+ if self.type_vocab_size > 0:
153
+ if token_type_ids is None:
154
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
155
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
156
+ embeddings = embeddings + token_type_embeddings
157
+
158
+ if self.max_position_embeddings > 0:
159
+ if position_ids is None:
160
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
161
+ position_embeddings = self.position_embeddings(position_ids)
162
+ embeddings = embeddings + position_embeddings
163
+ return embeddings
164
+
165
+ class NomicBertMLP(nn.Module):
166
+ def __init__(
167
+ self,
168
+ in_features,
169
+ hidden_features=None,
170
+ out_features=None,
171
+ activation=F.gelu,
172
+ bias1=True,
173
+ bias2=True,
174
+ return_residual=False,
175
+ fused_bias_fc=False,
176
+ ):
177
+ super().__init__()
178
+ out_features = out_features if out_features is not None else in_features
179
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
180
+ self.return_residual = return_residual
181
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
182
+ approximate = (
183
+ "tanh"
184
+ if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
185
+ else "none"
186
+ )
187
+ self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
188
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
189
+
190
+ def forward(self, x):
191
+ y = self.fc1(x)
192
+ y = self.activation(y)
193
+ y = self.fc2(y)
194
+ return y if not self.return_residual else (y, x)
195
+
196
+
197
+ class NomciBertGatedMLP(nn.Module):
198
+ def __init__(
199
+ self,
200
+ in_features,
201
+ hidden_features=None,
202
+ out_features=None,
203
+ activation=F.sigmoid,
204
+ bias1=True,
205
+ bias2=True,
206
+ multiple_of=256,
207
+ return_residual=False,
208
+ fused_bias_fc=True,
209
+ device=None,
210
+ dtype=None,
211
+ ):
212
+ super().__init__()
213
+ out_features = out_features if out_features is not None else in_features
214
+ hidden_features = (
215
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
216
+ )
217
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
218
+ self.return_residual = return_residual
219
+
220
+ self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
221
+ self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
222
+ self.activation = activation
223
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
224
+
225
+ def forward(self, x):
226
+ y = self.fc11(x)
227
+ gate = self.fc12(x)
228
+ if self.activation == F.sigmoid: # Special case for GLU
229
+ y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
230
+ else:
231
+ y = y * self.activation(gate)
232
+ y = self.fc2(y)
233
+ return y if not self.return_residual else (y, x)
234
+
235
+
236
+ def rotate_half(x, interleaved=False):
237
+ if not interleaved:
238
+ x1, x2 = x.chunk(2, dim=-1)
239
+ return torch.cat((-x2, x1), dim=-1)
240
+ else:
241
+ x1, x2 = x[..., ::2], x[..., 1::2]
242
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
243
+
244
+
245
+ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
246
+ """
247
+ x: (batch_size, seqlen, nheads, headdim)
248
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
249
+ """
250
+ ro_dim = cos.shape[-1] * 2
251
+ assert ro_dim <= x.shape[-1]
252
+ cos, sin = (
253
+ cos[offset: offset + x.shape[1]],
254
+ sin[offset: offset + x.shape[1]],
255
+ )
256
+ cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
257
+ sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
258
+ return torch.cat(
259
+ [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
260
+ dim=-1,
261
+ )
262
+
263
+
264
+ class NomicBertRotaryEmbedding(nn.Module):
265
+ def __init__(
266
+ self,
267
+ dim: int,
268
+ base=10000.0,
269
+ interleaved=False,
270
+ scale_base=None,
271
+ pos_idx_in_fp32=True,
272
+ device=None,
273
+ ):
274
+ """
275
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
276
+ of 1st half and 2nd half (GPT-NeoX style).
277
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
278
+ otherwise they might be in lower precision.
279
+ This option was added because previously (before 2023-07-02), when we construct
280
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
281
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
282
+ self.inv_freq would be bf16, and the position indices are also in bf16.
283
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
284
+ embeddings for some positions will coincide.
285
+ To maintain compatibility with models previously trained in pure bf16,
286
+ we add this option.
287
+ """
288
+ super().__init__()
289
+ self.dim = dim
290
+ self.base = float(base)
291
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
292
+ # Generate and save the inverse frequency buffer (non trainable)
293
+ inv_freq = self._compute_inv_freq(device)
294
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
295
+ self.interleaved = interleaved
296
+ self.scale_base = scale_base
297
+
298
+ self._seq_len_cached = 0
299
+ self._cos_cached = None
300
+ self._sin_cached = None
301
+ self._cos_k_cached = None
302
+ self._sin_k_cached = None
303
+
304
+ def _compute_inv_freq(self, device=None):
305
+ return 1.0 / (
306
+ self.base
307
+ ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
308
+ )
309
+
310
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
311
+ # Reset the tables if the sequence length has changed,
312
+ # if we're on a new device (possibly due to tracing for instance),
313
+ # or if we're switching from inference mode to training
314
+ if (
315
+ seqlen > self._seq_len_cached
316
+ or self._cos_cached is None
317
+ or self._cos_cached.device != device
318
+ or self._cos_cached.dtype != dtype
319
+ or (self.training and self._cos_cached.is_inference())
320
+ ):
321
+ self._seq_len_cached = seqlen
322
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
323
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
324
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
325
+ if self.pos_idx_in_fp32:
326
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
327
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
328
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
329
+ # cos & sin output to change significantly.
330
+ # We want to recompute self.inv_freq if it was not loaded in fp32
331
+ if self.inv_freq.dtype != torch.float32:
332
+ inv_freq = self._compute_inv_freq(device=device)
333
+ else:
334
+ inv_freq = self.inv_freq
335
+ else:
336
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
337
+ inv_freq = self.inv_freq
338
+ # Don't do einsum, it converts fp32 to fp16 under AMP
339
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
340
+ freqs = torch.outer(t, inv_freq)
341
+ self._cos_cached = torch.cos(freqs).to(dtype)
342
+ self._sin_cached = torch.sin(freqs).to(dtype)
343
+
344
+ def forward(
345
+ self,
346
+ qkv: torch.Tensor,
347
+ kv: Optional[torch.Tensor] = None,
348
+ seqlen_offset: Union[int, torch.Tensor] = 0,
349
+ max_seqlen: Optional[int] = None,
350
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
351
+ """
352
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
353
+ else it's just q of shape (batch, seqlen, nheads, headdim)
354
+ kv: (batch, seqlen, 2, nheads, headdim)
355
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
356
+ Most commonly used in inference when we have KV cache.
357
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
358
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
359
+ Apply rotary embedding *inplace* to qkv and / or kv.
360
+ """
361
+ seqlen = qkv.shape[1]
362
+ if max_seqlen is not None:
363
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
364
+ elif isinstance(seqlen_offset, int):
365
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
366
+
367
+ q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
368
+ k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
369
+ return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
370
+
371
+
372
+
373
+ class NomicBertAttention(nn.Module):
374
+ """Multi-head self-attention and cross-attention"""
375
+
376
+ def __init__(
377
+ self,
378
+ config,
379
+ ) -> None:
380
+ """
381
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
382
+ return_residual: whether to return the input x along with the output. This is for
383
+ performance reason: for post-norm architecture, returning the input allows us
384
+ to fuse the backward of nn.Linear with the residual connection.
385
+ """
386
+ super().__init__()
387
+ self.embed_dim = config.n_embd
388
+ self.use_flash_attn = config.use_flash_attn
389
+ self.fused_bias_fc = config.fused_bias_fc
390
+
391
+ self.num_heads = config.n_head
392
+ self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
393
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
394
+ self.head_dim = self.embed_dim // self.num_heads
395
+ # we don't really support mqa / gqa for now
396
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
397
+
398
+ self.register_buffer(
399
+ "norm_factor",
400
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
401
+ persistent=False,
402
+ )
403
+
404
+ self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
405
+ if self.rotary_emb_dim > 0:
406
+ self.rotary_emb = NomicBertRotaryEmbedding(
407
+ self.rotary_emb_dim,
408
+ base=config.rotary_emb_base,
409
+ scale_base=config.rotary_emb_scale_base,
410
+ interleaved=config.rotary_emb_interleaved,
411
+ )
412
+ # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
413
+ # uses the head dimension instead of the sequence dimension
414
+ self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
415
+
416
+ self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
417
+
418
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
419
+ self.causal = config.causal
420
+ self.drop = nn.Dropout(config.attn_pdrop)
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.Tensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
428
+ output_attentions: bool = False,
429
+ use_cache: bool = False,
430
+ is_padded_inputs: Optional[bool] = True,
431
+ cu_seqlens: Optional[torch.Tensor] = None,
432
+ max_seq_len: Optional[int] = None,
433
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
434
+
435
+ has_layer_past = past_key_value is not None
436
+
437
+ if has_layer_past:
438
+ past_key_value = past_key_value[0]
439
+ past_len = past_key_value[1]
440
+ else:
441
+ past_len = 0
442
+
443
+ qkv = self.Wqkv(hidden_states)
444
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
445
+
446
+ past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
447
+
448
+ if self.rotary_emb_dim > 0:
449
+ if self.rotary_head_dim:
450
+ qkv = rearrange(qkv, "b s three h d -> b h three s d")
451
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
452
+
453
+ if self.rotary_head_dim:
454
+ qkv = rearrange(qkv, "b h three s d -> b s three h d")
455
+
456
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
457
+
458
+ query = query.permute(0, 2, 1, 3)
459
+ key = key.permute(0, 2, 1, 3)
460
+ value = value.permute(0, 2, 1, 3)
461
+
462
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
463
+ if attention_mask is not None:
464
+ attention_scores = attention_scores + attention_mask
465
+
466
+ attentions_probs = F.softmax(attention_scores, dim=-1)
467
+ attentions_probs = self.drop(attentions_probs)
468
+
469
+ attn_output = torch.matmul(attentions_probs, value)
470
+ attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
471
+
472
+ attn_output = self.out_proj(attn_output)
473
+
474
+ return attn_output
475
+
476
+
477
+ class NomicBertBlock(nn.Module):
478
+ def __init__(
479
+ self,
480
+ config,
481
+ ):
482
+ super().__init__()
483
+ self.prenorm = config.prenorm
484
+ self.fused_dropout_add_ln = config.fused_dropout_add_ln
485
+
486
+ self.attn = NomicBertAttention(config)
487
+ activation = (
488
+ F.sigmoid
489
+ if config.activation_function == "glu"
490
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
491
+ )
492
+ if config.activation_function in ["glu", "swiglu", "geglu"]:
493
+ self.mlp = NomciBertGatedMLP(config.n_embd, hidden_features=config.n_inner, bias1=config.mlp_fc1_bias, bias2=config.mlp_fc2_bias, activation=activation, fused_bias_fc=config.fused_bias_fc)
494
+ else:
495
+ self.mlp = NomicBertMLP(config.n_embd, hidden_features=config.n_inner, bias1=config.mlp_fc1_bias, bias2=config.mlp_fc2_bias, activation=activation, fused_bias_fc=config.fused_bias_fc)
496
+
497
+ self.dropout1 = nn.Dropout(config.resid_pdrop)
498
+ self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
499
+ self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
500
+ self.dropout2 = nn.Dropout(config.resid_pdrop)
501
+
502
+ def forward(
503
+ self,
504
+ hidden_states: torch.Tensor,
505
+ hidden_states2: torch.Tensor,
506
+ residual: Optional[torch.Tensor] = None,
507
+ attention_mask: Optional[torch.Tensor] = None,
508
+ position_ids: Optional[torch.LongTensor] = None,
509
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
510
+ is_padded_inputs: Optional[bool] = True,
511
+ output_attentions: Optional[bool] = False,
512
+ use_cache: Optional[bool] = False,
513
+ cu_seqlens: Optional[torch.Tensor] = None,
514
+ max_seq_len: Optional[int] = None,
515
+ ):
516
+ r"""Pass the input through the encoder layer.
517
+
518
+ Args:
519
+ hidden_states: the sequence to the encoder layer (required).
520
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
521
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
522
+ before applying the query projection. Useful for e.g., ViT where we only care
523
+ about the CLS token in the last layer.
524
+ """
525
+ if self.prenorm:
526
+ dropped = self.dropout1(hidden_states)
527
+ residual = (dropped + residual) if residual is not None else dropped
528
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
529
+ hidden_states = self.attn(hidden_states, attention_mask=attention_mask, is_padded_inputs=is_padded_inputs, cu_seqlens=cu_seqlens, max_seq_len=max_seq_len)
530
+
531
+ dropped = self.dropout2(hidden_states)
532
+ residual = (dropped + residual) if residual is not None else dropped
533
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
534
+ hidden_states = self.mlp(hidden_states)
535
+
536
+ return hidden_states, None, residual
537
+ else:
538
+ assert residual is None
539
+ attn_outputs = self.attn(hidden_states,
540
+ attention_mask=attention_mask,
541
+ is_padded_inputs=is_padded_inputs,
542
+ cu_seqlens=cu_seqlens,
543
+ max_seq_len=max_seq_len)
544
+ hidden_states = self.norm1(
545
+ (self.dropout1(attn_outputs) + hidden_states).to(
546
+ dtype=self.norm1.weight.dtype
547
+ )
548
+ )
549
+ mlp_out = self.mlp(hidden_states)
550
+
551
+ hidden_states = self.norm2(
552
+ (self.dropout2(mlp_out) + hidden_states).to(
553
+ dtype=self.norm2.weight.dtype
554
+ )
555
+ )
556
+ return hidden_states, None, None
557
+
558
+
559
+ class NomicBertEncoder(nn.Module):
560
+ def __init__(self, config: GPT2Config):
561
+ super().__init__()
562
+ self.layers = nn.ModuleList(
563
+ [NomicBertBlock(config) for _ in range(config.n_layer)]
564
+ )
565
+ self.gradient_checkpointing = False
566
+ self.config = config
567
+
568
+ def forward(self,
569
+ hidden_states: torch.LongTensor = None,
570
+ attention_mask: Optional[torch.Tensor] = None,
571
+ position_ids: Optional[torch.LongTensor] = None,
572
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
573
+ inputs_embeds: Optional[torch.FloatTensor] = None,
574
+ use_cache: Optional[bool] = None,
575
+ output_attentions: Optional[bool] = None,
576
+ output_hidden_states: Optional[bool] = None,
577
+ return_dict: Optional[bool] = None,
578
+ is_padded_inputs: Optional[bool] = True,):
579
+
580
+ """If subset_mask is not None, we only want output for the subset of the sequence.
581
+ This means that we only compute the last layer output for these tokens.
582
+ subset_mask: (batch, seqlen), dtype=torch.bool
583
+ """
584
+ hidden_states2 = None
585
+ residual = None
586
+
587
+
588
+ for _, layer in enumerate(self.layers):
589
+ if self.gradient_checkpointing and self.training:
590
+
591
+ def create_custom_forward(module):
592
+ def custom_forward(*inputs):
593
+ # None for past_key_value
594
+ return module(*inputs)
595
+
596
+ return custom_forward
597
+
598
+ hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
599
+ create_custom_forward(layer),
600
+ hidden_states,
601
+ hidden_states2,
602
+ residual,
603
+ attention_mask,
604
+ None,
605
+ None,
606
+ is_padded_inputs,
607
+ # if you freeze ANY layers, you need `use_reentrant=False`
608
+ # https://github.com/huggingface/transformers/issues/21381
609
+ # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
610
+ use_reentrant=False,
611
+ )
612
+
613
+ else:
614
+ hidden_states, hidden_states2, residual = layer(
615
+ hidden_states,
616
+ hidden_states2,
617
+ residual,
618
+ attention_mask,
619
+ position_ids,
620
+ None,
621
+ is_padded_inputs,
622
+ output_attentions,
623
+ use_cache,
624
+ )
625
+ return hidden_states
626
+
627
+
628
+ class NomicBertPooler(nn.Module):
629
+ def __init__(self, config):
630
+ super().__init__()
631
+ self.dense = nn.Linear(config.n_embd, config.n_embd)
632
+ self.activation = nn.Tanh()
633
+
634
+ def forward(self, hidden_states, pool=True):
635
+ # We "pool" the model by simply taking the hidden state corresponding
636
+ # to the first token.
637
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
638
+ pooled_output = self.dense(first_token_tensor)
639
+ pooled_output = self.activation(pooled_output)
640
+ return pooled_output
641
+
642
+
643
+ class NomicBertPredictionHeadTransform(nn.Module):
644
+ def __init__(self, config):
645
+ super().__init__()
646
+ self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
647
+ approximate = (
648
+ "tanh"
649
+ if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
650
+ else "none"
651
+ )
652
+ if config.activation_function == "swiglu":
653
+ self.transform_act_fn = F.silu
654
+ else:
655
+ self.transform_act_fn = nn.GELU(approximate=approximate)
656
+
657
+ self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
658
+
659
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
660
+ hidden_states = self.dense(hidden_states)
661
+ hidden_states = self.transform_act_fn(hidden_states)
662
+ hidden_states = self.layer_norm(hidden_states)
663
+
664
+ return hidden_states
665
+
666
+
667
+ class NomicBertLMPredictionHead(nn.Module):
668
+ def __init__(self, config):
669
+ super().__init__()
670
+
671
+ self.transform = NomicBertPredictionHeadTransform(config)
672
+
673
+ self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)
674
+
675
+ def forward(self, hidden_states):
676
+ hidden_states = self.transform(hidden_states)
677
+ hidden_states = self.decoder(hidden_states)
678
+ return hidden_states
679
+
680
+
681
+ class NomicBertPreTrainingHeads(nn.Module):
682
+ def __init__(self, config):
683
+ super().__init__()
684
+ self.predictions = NomicBertLMPredictionHead(config)
685
+
686
+ def forward(self, sequence_output):
687
+ prediction_scores = self.predictions(sequence_output)
688
+ return prediction_scores
689
+
690
+
691
+ class NomicBertModel(NomicBertPreTrainedModel):
692
+ def __init__(self, config: GPT2Config, add_pooling_layer=True):
693
+ super().__init__(config)
694
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
695
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
696
+ config.vocab_size += self.pad_vocab_size_multiple - (
697
+ config.vocab_size % self.pad_vocab_size_multiple
698
+ )
699
+
700
+ assert config.activation_function in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu", "geglu", "glu"]
701
+
702
+ self.embeddings = NomicBertEmbeddings(
703
+ config
704
+ )
705
+ self.emb_drop = nn.Dropout(config.resid_pdrop)
706
+ self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
707
+ self.encoder = NomicBertEncoder(config)
708
+ self.pooler = NomicBertPooler(config) if add_pooling_layer else None
709
+
710
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
711
+
712
+ def forward(
713
+ self,
714
+ input_ids,
715
+ position_ids=None,
716
+ token_type_ids=None,
717
+ attention_mask=None,
718
+ ):
719
+ if token_type_ids is None:
720
+ token_type_ids = torch.zeros_like(input_ids)
721
+ hidden_states = self.embeddings(
722
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
723
+ )
724
+ hidden_states = self.emb_ln(hidden_states)
725
+ hidden_states = self.emb_drop(hidden_states)
726
+
727
+ attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
728
+ sequence_output = self.encoder(
729
+ hidden_states, attention_mask=attention_mask
730
+ )
731
+
732
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
733
+
734
+ return BaseModelOutputWithPoolingAndCrossAttentions(
735
+ last_hidden_state=sequence_output,
736
+ pooler_output=pooled_output,
737
+ )
738
+
739
+
740
+ class NomicBertForPreTraining(NomicBertPreTrainedModel):
741
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
742
+
743
+ def __init__(self, config: GPT2Config):
744
+ super().__init__(config)
745
+
746
+ self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
747
+ self.cls = NomicBertPreTrainingHeads(config)
748
+ self.mlm_loss = nn.CrossEntropyLoss()
749
+
750
+ # Initialize weights and apply final processing
751
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
752
+ self.tie_weights()
753
+
754
+ def tie_weights(self):
755
+ self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
756
+
757
+ def forward(
758
+ self,
759
+ input_ids,
760
+ position_ids=None,
761
+ token_type_ids=None,
762
+ attention_mask=None,
763
+ labels=None,
764
+ ):
765
+ """
766
+ If labels are provided, they must be -100 for masked out tokens (as specified in the attention
767
+ mask).
768
+ Outputs:
769
+ if `labels` and `next_sentence_label` are not `None`:
770
+ Outputs the total_loss which is the sum of the masked language modeling loss and the next
771
+ sentence classification loss.
772
+ if `labels` or `next_sentence_label` is `None`:
773
+ Outputs a tuple comprising
774
+ - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
775
+ - the next sentence classification logits of shape [batch_size, 2].
776
+
777
+ """
778
+ outputs = self.bert(
779
+ input_ids,
780
+ position_ids=position_ids,
781
+ token_type_ids=token_type_ids,
782
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
783
+ )
784
+ sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output
785
+
786
+ prediction_scores = self.cls(sequence_output)
787
+
788
+ total_loss = None
789
+ if labels is not None:
790
+ masked_lm_loss = self.mlm_loss(
791
+ rearrange(prediction_scores, "... v -> (...) v"),
792
+ rearrange(labels, "... -> (...)"),
793
+ )
794
+ total_loss = masked_lm_loss.float()
795
+
796
+ return BertForPreTrainingOutput(
797
+ loss=total_loss,
798
+ prediction_logits=prediction_scores,
799
+ )
800
+
801
+
802
+ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
803
+ def __init__(self, config):
804
+ super().__init__(config)
805
+ self.num_labels = config.num_labels
806
+ self.config = config
807
+
808
+ self.bert = NomicBertModel(config)
809
+ classifier_dropout = (
810
+ getattr(config, "classifier_dropout", config.embd_pdrop)
811
+ )
812
+ self.dropout = nn.Dropout(classifier_dropout)
813
+ self.classifier = nn.Linear(config.n_embd, config.num_labels)
814
+
815
+ # Initialize weights and apply final processing
816
+ self.post_init()
817
+
818
+ def forward(
819
+ self,
820
+ input_ids: Optional[torch.Tensor] = None,
821
+ attention_mask: Optional[torch.Tensor] = None,
822
+ token_type_ids: Optional[torch.Tensor] = None,
823
+ position_ids: Optional[torch.Tensor] = None,
824
+ head_mask: Optional[torch.Tensor] = None,
825
+ inputs_embeds: Optional[torch.Tensor] = None,
826
+ labels: Optional[torch.Tensor] = None,
827
+ output_attentions: Optional[bool] = None,
828
+ output_hidden_states: Optional[bool] = None,
829
+ return_dict: Optional[bool] = None,
830
+ ):
831
+ r"""
832
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
833
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
834
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
835
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
836
+ """
837
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
838
+ outputs = self.bert(
839
+ input_ids,
840
+ position_ids=position_ids,
841
+ token_type_ids=token_type_ids,
842
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
843
+ )
844
+
845
+ pooled_output = outputs[1]
846
+
847
+ pooled_output = self.dropout(pooled_output)
848
+ logits = self.classifier(pooled_output)
849
+
850
+ loss = None
851
+ if labels is not None:
852
+ if self.config.problem_type is None:
853
+ if self.num_labels == 1:
854
+ self.config.problem_type = "regression"
855
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
856
+ self.config.problem_type = "single_label_classification"
857
+ else:
858
+ self.config.problem_type = "multi_label_classification"
859
+
860
+ if self.config.problem_type == "regression":
861
+ loss_fct = nn.MSELoss()
862
+ if self.num_labels == 1:
863
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
864
+ else:
865
+ loss = loss_fct(logits, labels)
866
+ elif self.config.problem_type == "single_label_classification":
867
+ loss_fct = nn.CrossEntropyLoss()
868
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
869
+ elif self.config.problem_type == "multi_label_classification":
870
+ loss_fct = nn.BCEWithLogitsLoss()
871
+ loss = loss_fct(logits, labels)
872
+ if not return_dict:
873
+ output = (logits,) + outputs[2:]
874
+ return ((loss,) + output) if loss is not None else output
875
+
876
+ return SequenceClassifierOutput(
877
+ loss=loss,
878
+ logits=logits,
879
+ hidden_states=outputs.hidden_states,
880
+ attentions=outputs.attentions,
881
+ )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0b1c7db843aeae0c744716acf3c7b7da60d142e7ca31edd11853f5b163c8776
3
+ size 549328982