linhphanff committed on
Commit
6e79123
1 Parent(s): 557050f

Delete modeling_hf_nomic_bert.py

Browse files
Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +0 -1227
modeling_hf_nomic_bert.py DELETED
@@ -1,1227 +0,0 @@
1
- # Copyright (c) 2022, Tri Dao.
2
- # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
- # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
- # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
-
6
- import logging
7
-
8
- # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
- import os
10
- import re
11
- from collections import OrderedDict
12
- from functools import partial
13
- from typing import List, Optional, Tuple, Union
14
-
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- from einops import rearrange, repeat
19
- from safetensors.torch import load_file as safe_load_file
20
- from transformers import GPT2Config, PreTrainedModel
21
- from transformers.models.bert.modeling_bert import (
22
- BaseModelOutputWithPoolingAndCrossAttentions,
23
- MaskedLMOutput,
24
- SequenceClassifierOutput,
25
- )
26
- from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
27
- from transformers.utils.hub import cached_file, get_checkpoint_shard_files
28
-
29
- from .configuration_hf_nomic_bert import NomicBertConfig
30
-
31
- logger = logging.getLogger(__name__)
32
-
33
-
34
- # adapted from flash attention, added safe serialization option for hf models
35
def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
    """Resolve and load the checkpoint stored under ``model_name``.

    ``model_name`` is first treated as a local directory; the plain weights
    file, its sharded index, the safetensors file and the safetensors index
    are probed in that order.  When none of them exists the name is resolved
    through ``cached_file`` instead, in which case ``safe_serialization``
    selects between the pickle and the safetensors file names.

    Args:
        model_name: local directory or hub identifier of the checkpoint.
        safe_serialization: only consulted for the hub fallback.
        device: device the returned tensors are moved to.
        dtype: optional dtype every tensor is converted to.

    Returns:
        A dict mapping parameter names to tensors on ``device``.

    Raises:
        EnvironmentError: if no weights file could be resolved.
    """
    # If not fp32, then we don't want to load directly to the GPU
    mapped_device = device if dtype in [torch.float32, None] else "cpu"

    # (file name, is a shard index, is safetensors), in lookup priority order.
    local_candidates = (
        (WEIGHTS_NAME, False, False),
        (WEIGHTS_INDEX_NAME, True, False),
        (SAFE_WEIGHTS_NAME, False, True),
        (SAFE_WEIGHTS_INDEX_NAME, True, True),
    )
    for filename, is_sharded, load_safe in local_candidates:
        if os.path.isfile(os.path.join(model_name, filename)):
            resolved_archive_file = cached_file(model_name, filename, _raise_exceptions_for_missing_entries=False)
            break
    else:
        # Try loading from HF hub instead of from local files
        is_sharded = False
        load_safe = safe_serialization
        weight_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
        resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
        if resolved_archive_file is None:
            weight_index = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
            resolved_archive_file = cached_file(model_name, weight_index, _raise_exceptions_for_missing_entries=False)
            if resolved_archive_file is not None:
                is_sharded = True

    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")

    if load_safe:
        loader = partial(safe_load_file, device=mapped_device)
    else:
        loader = partial(torch.load, map_location=mapped_device)

    if not is_sharded:
        state_dict = loader(resolved_archive_file)
    else:
        # The index file is expanded into the list of shard files it points to.
        shard_files, _ = get_checkpoint_shard_files(model_name, resolved_archive_file)
        state_dict = {}
        for shard_file in shard_files:
            state_dict.update(loader(shard_file))

    # Convert dtype before moving to GPU to save memory
    if dtype is not None:
        state_dict = {name: tensor.to(dtype=dtype) for name, tensor in state_dict.items()}
    return {name: tensor.to(device=device) for name, tensor in state_dict.items()}
94
-
95
-
96
def filter_shapes(state_dict, model):
    """
    Filters the state dict to match the current model shape.

    Args:
        state_dict: mapping of parameter name to tensor.
        model: module whose ``state_dict()`` supplies the reference shapes.

    Returns:
        A new dict holding only the entries whose key exists in the model and
        whose tensor shape equals the model's shape for that key.
    """
    # ``state_dict()`` rebuilds the whole mapping on every call, so take it
    # once instead of twice per key.
    model_state = model.state_dict()
    return {
        key: value
        for key, value in state_dict.items()
        if key in model_state and value.shape == model_state[key].shape
    }
106
-
107
-
108
def remap_bert_state_dict(
    state_dict,
    config,
    remove_bert=False,
    remove_cls_weights=False,
    add_pooling_layer=False,
):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.

    Args:
        state_dict: mapping of Huggingface BERT parameter names to tensors.
        config: object providing ``num_hidden_layers`` and ``vocab_size``;
            ``last_layer_subset`` and ``pad_vocab_size_multiple`` are read
            when present.
        remove_bert: strip the leading ``bert.`` from every key at the end.
        remove_cls_weights: drop the ``cls.predictions.*`` head weights.
        add_pooling_layer: when False the ``bert.pooler.dense.*`` weights are
            dropped.

    Returns:
        An ``OrderedDict`` with the renamed (and, for attention, fused) tensors.
    """

    def rename_before_fusion(key):
        # Every non-head key is addressed under the "bert." namespace below.
        if not key.startswith(("bert.", "cls.")):
            key = f"bert.{key}"
        # LayerNorm (legacy gamma/beta naming)
        key = re.sub(r"LayerNorm\.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm\.beta$", "LayerNorm.bias", key)
        # Layers
        key = re.sub(r"^bert\.encoder\.layer\.", "bert.encoder.layers.", key)
        # LayerNorm
        key = re.sub(r"^bert\.embeddings\.LayerNorm\.", "bert.emb_ln.", key)
        key = re.sub(
            r"^bert\.encoder\.layers\.(\d+)\.attention\.output\.LayerNorm\.(weight|bias)",
            r"bert.encoder.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^bert\.encoder\.layers\.(\d+)\.output\.LayerNorm\.(weight|bias)",
            r"bert.encoder.layers.\1.norm2.\2",
            key,
        )
        key = re.sub(
            r"^cls\.predictions\.transform\.LayerNorm\.(weight|bias)",
            r"cls.predictions.transform.layer_norm.\1",
            key,
        )
        # MLP
        key = re.sub(
            r"^bert\.encoder\.layers\.(\d+)\.intermediate\.dense\.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^bert\.encoder\.layers\.(\d+)\.output\.dense\.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((rename_before_fusion(k), v) for k, v in state_dict.items())

    # Attention: fuse the separate q/k/v projections into one matrix.
    last_layer_subset = getattr(config, "last_layer_subset", False)
    for d in range(config.num_hidden_layers):
        layer = f"bert.encoder.layers.{d}"
        if f"{layer}.attention.self.query.weight" not in state_dict:
            continue
        Wq = state_dict.pop(f"{layer}.attention.self.query.weight")
        Wk = state_dict.pop(f"{layer}.attention.self.key.weight")
        Wv = state_dict.pop(f"{layer}.attention.self.value.weight")
        bq = state_dict.pop(f"{layer}.attention.self.query.bias")
        bk = state_dict.pop(f"{layer}.attention.self.key.bias")
        bv = state_dict.pop(f"{layer}.attention.self.value.bias")
        if not (last_layer_subset and d == config.num_hidden_layers - 1):
            state_dict[f"{layer}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
            state_dict[f"{layer}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
        else:
            # With last_layer_subset the final layer keeps q apart from kv.
            state_dict[f"{layer}.attn.Wq.weight"] = Wq
            state_dict[f"{layer}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
            state_dict[f"{layer}.attn.Wq.bias"] = bq
            state_dict[f"{layer}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)

    # remove nsp weights, we don't use
    state_dict.pop("cls.seq_relationship.weight", None)
    state_dict.pop("cls.seq_relationship.bias", None)
    state_dict.pop("bert.embeddings.position_ids", None)

    def rename_after_fusion(key):
        key = re.sub(
            r"^bert\.encoder\.layers\.(\d+)\.attention\.output\.dense\.(weight|bias)",
            r"bert.encoder.layers.\1.attn.out_proj.\2",
            key,
        )
        return re.sub(r"^cls\.predictions\.bias", "cls.predictions.decoder.bias", key)

    state_dict = OrderedDict((rename_after_fusion(k), v) for k, v in state_dict.items())

    if remove_cls_weights:
        cls_weights = [
            "cls.predictions.decoder.bias",
            "cls.predictions.transform.dense.weight",
            "cls.predictions.transform.dense.bias",
            "cls.predictions.transform.layer_norm.weight",
            "cls.predictions.transform.layer_norm.bias",
            "cls.predictions.decoder.weight",
        ]
        for weight in cls_weights:
            state_dict.pop(weight, None)

    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
        if not remove_cls_weights:
            decoder_weight = state_dict["cls.predictions.decoder.weight"]
            state_dict["cls.predictions.decoder.weight"] = F.pad(
                decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
            )
        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
        # strongly negative (i.e. the decoder shouldn't predict those indices).
        # TD [2022-05-09]: I don't think it affects the MLPerf training.
        if "cls.predictions.decoder.bias" in state_dict:
            decoder_bias = state_dict["cls.predictions.decoder.bias"]
            state_dict["cls.predictions.decoder.bias"] = F.pad(
                decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
            )

    if add_pooling_layer is False:
        for key in ("bert.pooler.dense.weight", "bert.pooler.dense.bias"):
            state_dict.pop(key, None)

    if remove_bert:
        state_dict = OrderedDict((re.sub(r"^bert\.", "", k), v) for k, v in state_dict.items())

    return state_dict
268
-
269
-
270
class NomicBertPreTrainedModel(PreTrainedModel):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    config_class = NomicBertConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, config, *inputs, **kwargs):
        """Store ``config`` after checking that it is a ``GPT2Config`` instance.

        Raises:
            ValueError: if ``config`` is not a ``GPT2Config``.
        """
        super().__init__(config)
        if not isinstance(config, GPT2Config):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
        """
        Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            model_name: either
                - an existing local path containing `pytorch_model.bin` or
                  `model.safetensors`; the state dict is loaded as-is with
                  `strict=False`, or
                - any other name, which is resolved by `state_dict_from_pretrained`
                  and remapped from the Huggingface BERT layout before loading.
            config: optional configuration; loaded from `model_name` when None.
            *inputs: forwarded positionally to the class constructor.
            **kwargs: recognised keys are `ignore_mismatched_sizes`, `num_labels`,
                `rotary_scaling_factor`, `strict` and `add_pooling_layer`.

        Returns:
            The instantiated model with the weights loaded.

        Raises:
            ValueError: if `model_name` is a local path holding neither weights file.
        """
        # Instantiate model.
        if config is None:
            config = cls.config_class.from_pretrained(model_name)
        remove_cls = cls != NomicBertForPreTraining
        remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
        ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
        num_labels = kwargs.pop("num_labels", None)
        rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
        strict = kwargs.pop("strict", True)
        # NOTE: this overwrites whatever the loaded config carried, including with None.
        config.rotary_scaling_factor = rotary_scaling_factor
        if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
            config.n_positions = 2048
        if num_labels:
            config.num_labels = num_labels

        if "add_pooling_layer" in kwargs:
            model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
        else:
            model = cls(config, *inputs)
        # TODO: fix this
        # Assuming we know what we're doing when loading from disk
        # Prob a bad assumption but i'm tired and want to train this asap
        if os.path.exists(model_name):
            model_path = f"{model_name}/pytorch_model.bin"
            if os.path.exists(model_path):
                # map_location="cpu" lets a checkpoint saved from a GPU be read on a
                # host without one; load_state_dict copies into the model's own
                # parameters afterwards, so the final placement is unchanged.
                # NOTE(security): torch.load unpickles the file - only use trusted checkpoints.
                state_dict = torch.load(model_path, map_location="cpu")
            else:
                model_path = f"{model_name}/model.safetensors"
                if not os.path.exists(model_path):
                    raise ValueError(f"Model path {model_path} not found")
                state_dict = safe_load_file(model_path)

            if ignore_mismatched_shapes:
                state_dict = filter_shapes(state_dict, model)
            load_return = model.load_state_dict(state_dict, strict=False)
        else:
            # TODO: can probably check config class and see if we need to remap from a bert model
            state_dict = state_dict_from_pretrained(model_name)
            state_dict = remap_bert_state_dict(
                state_dict,
                config,
                remove_bert=remove_bert_prefix,
                remove_cls_weights=remove_cls,
                add_pooling_layer=getattr(config, "add_pooling_layer", False),
            )
            if ignore_mismatched_shapes:
                state_dict = filter_shapes(state_dict, model)

            load_return = model.load_state_dict(state_dict, strict=strict)
        # Reports the missing / unexpected keys returned by load_state_dict.
        logger.warning(load_return)
        return model

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on ``module`` when it is a ``NomicBertEncoder``."""
        if isinstance(module, NomicBertEncoder):
            module.gradient_checkpointing = value
365
-
366
-
367
- # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
368
- def _init_weights(module, initializer_range=0.02):
369
- if isinstance(module, nn.Linear):
370
- nn.init.normal_(module.weight, std=initializer_range)
371
- if module.bias is not None:
372
- nn.init.zeros_(module.bias)
373
- elif isinstance(module, nn.Embedding):
374
- nn.init.normal_(module.weight, std=initializer_range)
375
- if module.padding_idx is not None:
376
- nn.init.zeros_(module.weight[module.padding_idx])
377
-
378
-
379
class NomicBertEmbeddings(nn.Module):
    """Word embeddings plus optional token-type and learned position embeddings."""

    def __init__(self, config):
        """
        Learned position embeddings are created only when
        ``config.rotary_emb_fraction <= 0`` and ``config.max_position_embeddings > 0``.
        If type_vocab_size <= 0, there's no token type embeddings
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # A positive rotary fraction disables the learned position table.
        if config.rotary_emb_fraction <= 0:
            self.max_position_embeddings = config.max_position_embeddings
        else:
            self.max_position_embeddings = 0
        self.type_vocab_size = config.type_vocab_size

        if self.max_position_embeddings > 0:
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings,
                config.hidden_size,
            )
        if self.type_vocab_size > 0:
            self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

    def forward(self, input_ids, position_ids=None, token_type_ids=None):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        token_type_ids: (batch, seqlen)

        Missing ``token_type_ids`` default to zeros of length seqlen and
        missing ``position_ids`` to ``arange(seqlen)``; both broadcast over
        the batch dimension.
        """
        _, seq_len = input_ids.shape
        hidden = self.word_embeddings(input_ids)

        if self.type_vocab_size > 0:
            if token_type_ids is None:
                token_type_ids = torch.zeros(seq_len, dtype=torch.long, device=input_ids.device)
            hidden = hidden + self.token_type_embeddings(token_type_ids)

        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
            hidden = hidden + self.position_embeddings(position_ids)

        return hidden
418
-
419
-
420
class NomicBertMLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(activation(fc1(x)))``."""

    # GELU names accepted as strings, mapped to nn.GELU's ``approximate`` mode.
    _GELU_APPROXIMATE = {
        "gelu": "none",
        "gelu_new": "tanh",
        "gelu_fast": "tanh",
        "gelu_pytorch_tanh": "tanh",
    }

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        fused_bias_fc=False,
    ):
        """
        Args:
            in_features: size of the input features.
            hidden_features: size of the hidden layer; defaults to ``4 * in_features``.
            out_features: size of the output; defaults to ``in_features``.
            activation: a callable, or one of the GELU names ``"gelu"``,
                ``"gelu_new"``, ``"gelu_fast"``, ``"gelu_pytorch_tanh"``.
            bias1: whether ``fc1`` has a bias.
            bias2: whether ``fc2`` has a bias.
            return_residual: when True ``forward`` returns ``(output, input)``.
            fused_bias_fc: accepted for interface compatibility; not used here.
        """
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
        # Every recognised GELU name becomes a module.  Previously only "gelu" was
        # converted, so the tanh variants were stored as plain strings and failed
        # when called in forward().
        if isinstance(activation, str) and activation in self._GELU_APPROXIMATE:
            self.activation = nn.GELU(approximate=self._GELU_APPROXIMATE[activation])
        else:
            self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)

    def forward(self, x):
        """Return ``fc2(activation(fc1(x)))``, paired with ``x`` when ``return_residual`` is set."""
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)
446
-
447
-
448
class NomciBertGatedMLP(nn.Module):
    """Gated feed-forward block: ``fc2(fc11(x) * activation(fc12(x)))``."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        return_residual=False,
        fused_bias_fc=True,
        device=None,
        dtype=None,
    ):
        """
        ``hidden_features`` defaults to ``int(8 * in_features / 3)`` and is then
        rounded up to a multiple of ``multiple_of``; ``out_features`` defaults
        to ``in_features``.  ``fused_bias_fc``, ``device`` and ``dtype`` are
        accepted but not used by this implementation.
        """
        super().__init__()
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)
        # Round the hidden width up to the next multiple of ``multiple_of``.
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual

        self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
        self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)

    def forward(self, x):
        """Apply the gated MLP; returns ``(output, x)`` when ``return_residual`` is set."""
        value = self.fc11(x)
        gate = self.fc12(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            hidden = F.glu(torch.cat([value, gate], dim=-1), dim=-1)
        else:
            hidden = value * self.activation(gate)
        out = self.fc2(hidden)
        if self.return_residual:
            return out, x
        return out
483
-
484
-
485
def rotate_half(x, interleaved=False):
    """Rotate the last dimension of ``x`` by a quarter turn, pairwise.

    With ``interleaved=False`` the last dimension is split into a first and a
    second half and ``(-second, first)`` is returned.  With
    ``interleaved=True`` the pairs are the (even, odd) neighbours and each
    pair ``(a, b)`` becomes ``(-b, a)`` in place.
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        # Stack the rotated pairs on a new trailing axis, then merge it back.
        return torch.stack((-odds, evens), dim=-1).flatten(-2)

    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
492
-
493
-
494
def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Rotates the first ``rotary_dim`` features of ``x`` and passes the
    remaining features through unchanged.  ``cos`` and ``sin`` are sliced
    along their first dimension to ``[offset, offset + x.shape[1])``.
    """
    rotary_dim = cos.shape[-1] * 2
    assert rotary_dim <= x.shape[-1]

    window = slice(offset, offset + x.shape[1])
    # Insert a head axis and duplicate each frequency so it covers both
    # members of a rotated pair, laid out to match the pairing convention.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos[window], pattern)
    sin = repeat(sin[window], pattern)

    rotated = x[..., :rotary_dim] * cos + rotate_half(x[..., :rotary_dim], interleaved) * sin
    untouched = x[..., rotary_dim:]
    return torch.cat([rotated, untouched], dim=-1)
511
-
512
-
513
class NomicBertRotaryEmbedding(nn.Module):
    """Rotary position embedding with a lazily grown cos / sin cache."""

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable).
        # ``device`` is passed by keyword so that an override with extra leading
        # parameters still receives it in the right slot.
        inv_freq = self._compute_inv_freq(device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        """Return ``1 / base ** (2i / dim)`` for ``i`` in ``range(dim / 2)`` as fp32."""
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Rebuild the cos / sin tables for ``seqlen`` positions when they are stale."""
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def _ensure_cache(self, seqlen, seqlen_offset, max_seqlen, device, dtype):
        """Make sure the cos / sin cache covers the positions ``forward`` will read."""
        if seqlen > self._seq_len_cached:
            # The rows read are [offset, offset + seqlen), so an int offset has to
            # be included; sizing the cache to seqlen alone leaves the slice short
            # and it then broadcasts silently.
            required = seqlen + seqlen_offset if isinstance(seqlen_offset, int) else seqlen
            self._update_cos_sin_cache(required, device=device, dtype=dtype)
        elif max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> torch.Tensor:
        """
        qkv: (batch, seqlen, 3, nheads, headdim)
        kv: accepted for interface compatibility; not read by this implementation.
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.

        Returns a new tensor shaped like ``qkv`` in which the query and key
        slices are rotated and the value slice is passed through unchanged.
        """
        seqlen = qkv.shape[1]
        self._ensure_cache(seqlen, seqlen_offset, max_seqlen, qkv.device, qkv.dtype)

        q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
        k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
        return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
624
-
625
-
626
class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
    """Rotary embedding whose base is rescaled once the sequence exceeds the trained length."""

    def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
        """
        rotary_scaling_factor: factor used in the rescaled-base formula.
        max_position_embeddings: sequence length above which the base is rescaled.
        **kwargs: forwarded to ``NomicBertRotaryEmbedding``.
        """
        super().__init__(**kwargs)
        self.rotary_scaling_factor = rotary_scaling_factor
        self.max_position_embeddings = max_position_embeddings

    def _compute_inv_freq(self, base=None, device=None):
        """Return the fp32 inverse frequencies for ``base`` (``self.base`` when None)."""
        # The parent constructor calls ``_compute_inv_freq(device)`` positionally,
        # which lands in the ``base`` slot of this override; route it back.
        if isinstance(base, (str, torch.device)):
            base, device = None, base
        if base is None:
            base = self.base
        return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _ntk_base(self, seqlen):
        """Return the rotary base for ``seqlen``: rescaled only beyond ``max_position_embeddings``."""
        if seqlen <= self.max_position_embeddings:
            return self.base
        factor = self.rotary_scaling_factor
        return self.base * ((factor * seqlen / self.max_position_embeddings) - (factor - 1)) ** (
            self.dim / (self.dim - 2)
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        """Refresh ``inv_freq`` for long sequences, then rebuild stale cos / sin tables."""
        # Beyond the trained length the inverse frequencies are recomputed from
        # the rescaled base on every call.
        if seqlen > self.max_position_embeddings:
            inv_freq = self._compute_inv_freq(base=self._ntk_base(seqlen), device=device)
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    # This path used to read the non-existent ``self.scaling_factor``.
                    inv_freq = self._compute_inv_freq(device=device, base=self._ntk_base(seqlen))
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
694
-
695
-
696
- class NomicBertAttention(nn.Module):
697
- """Multi-head self-attention and cross-attention"""
698
-
699
def __init__(
    self,
    config,
) -> None:
    """Create the fused QKV projection, optional rotary embedding and output projection.

    Args:
        config: model configuration object. Attributes read here:
            ``n_embd``, ``n_head``, ``use_flash_attn``, ``fused_bias_fc``,
            ``rotary_emb_fraction``, ``qkv_proj_bias``, ``causal`` and
            ``attn_pdrop``. ``num_heads_kv``, ``rotary_scaling_factor`` and
            ``rotary_head_dim`` are optional. When the rotary dimension is
            positive, ``rotary_emb_base``, ``rotary_emb_scale_base`` and
            ``rotary_emb_interleaved`` are read as well, plus
            ``max_trained_positions`` for the dynamic-NTK variant.
    """
    super().__init__()
    embed_dim = config.n_embd
    self.embed_dim = embed_dim
    self.use_flash_attn = config.use_flash_attn
    self.fused_bias_fc = config.fused_bias_fc

    num_heads = config.n_head
    num_heads_kv = getattr(config, "num_heads_kv", None)
    if num_heads_kv is None:
        # No separate key/value head count configured: fall back to num_heads.
        num_heads_kv = num_heads
    self.num_heads = num_heads
    self.num_heads_kv = num_heads_kv
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    head_dim = embed_dim // num_heads
    self.head_dim = head_dim
    # we don't really support mqa / gqa for now
    qkv_dim = head_dim * (num_heads + 2 * num_heads_kv)

    # sqrt(head_dim), stored as a non-persistent buffer in the default dtype;
    # forward divides the attention scores by it.
    scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
    self.register_buffer(
        "norm_factor",
        scale.to(torch.get_default_dtype()),
        persistent=False,
    )

    rotary_dim = head_dim * config.rotary_emb_fraction
    self.rotary_emb_dim = rotary_dim
    if rotary_dim > 0:
        shared_rotary_kwargs = dict(
            dim=rotary_dim,
            base=config.rotary_emb_base,
            scale_base=config.rotary_emb_scale_base,
            interleaved=config.rotary_emb_interleaved,
        )
        if getattr(config, "rotary_scaling_factor", None):
            self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
                **shared_rotary_kwargs,
                rotary_scaling_factor=config.rotary_scaling_factor,
                max_position_embeddings=config.max_trained_positions,
            )
        else:
            self.rotary_emb = NomicBertRotaryEmbedding(**shared_rotary_kwargs)
        # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
        # uses the head dimension instead of the sequence dimension
        self.rotary_head_dim = getattr(config, "rotary_head_dim", False)

    use_proj_bias = config.qkv_proj_bias
    self.Wqkv = nn.Linear(embed_dim, qkv_dim, bias=use_proj_bias)

    # The output projection reuses the qkv_proj_bias setting.
    self.out_proj = nn.Linear(embed_dim, embed_dim, bias=use_proj_bias)
    self.causal = config.causal
    self.drop = nn.Dropout(config.attn_pdrop)
754
-
755
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    is_padded_inputs: Optional[bool] = True,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seq_len: Optional[int] = None,
) -> torch.Tensor:
    """Compute softmax self-attention over ``hidden_states``.

    Args:
        hidden_states: input to the fused QKV projection. The indexing below
            treats the projected tensor as (batch, seq, 3, heads, head_dim).
        attention_mask: optional tensor added to the raw attention scores
            before the softmax.
        past_key_value: optional pair whose second element is the length of
            the already-processed prefix; it is used only as the rotary
            ``seqlen_offset``. The first element is not consumed here.
        position_ids, output_attentions, use_cache, is_padded_inputs,
        cu_seqlens, max_seq_len: accepted for signature compatibility; this
            implementation does not use them.

    Returns:
        The output projection of the attention result (a single tensor).
    """
    if past_key_value is not None:
        # Read both fields in one step. Rebinding ``past_key_value`` to its
        # first element before reading index 1 would take the length from
        # the wrong object.
        past_key_value, past_len = past_key_value[0], past_key_value[1]
    else:
        past_len = 0

    qkv = self.Wqkv(hidden_states)
    # "... (three h d) -> ... three h d": split the fused projection.
    qkv = qkv.reshape(*qkv.shape[:-1], 3, -1, self.head_dim)

    # NOTE: no key/value cache is assembled because this method returns only
    # the attention output, so ``use_cache`` has no observable effect.

    if self.rotary_emb_dim > 0:
        if self.rotary_head_dim:
            # "b s three h d -> b h three s d"
            qkv = qkv.transpose(1, 3)
        qkv = self.rotary_emb(qkv, seqlen_offset=past_len)

        if self.rotary_head_dim:
            # "b h three s d -> b s three h d"
            qkv = qkv.transpose(1, 3)

    # (b, s, h, d) -> (b, h, s, d) for each of query / key / value.
    query = qkv[:, :, 0].permute(0, 2, 1, 3)
    key = qkv[:, :, 1].permute(0, 2, 1, 3)
    value = qkv[:, :, 2].permute(0, 2, 1, 3)

    attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    attention_probs = F.softmax(attention_scores, dim=-1)
    attention_probs = self.drop(attention_probs)

    attn_output = torch.matmul(attention_probs, value)
    # (b, h, s, d) -> (b, s, h * d)
    attn_output = attn_output.permute(0, 2, 1, 3).flatten(-2)

    return self.out_proj(attn_output)
808
-
809
-
810
class NomicBertBlock(nn.Module):
    """One encoder layer: self-attention followed by an MLP.

    ``config.prenorm`` selects between the pre-norm and post-norm arrangement
    of the two LayerNorms around the residual connections.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.prenorm = config.prenorm
        self.fused_dropout_add_ln = config.fused_dropout_add_ln

        self.attn = NomicBertAttention(config)

        # "glu" gates with sigmoid, "swiglu" with SiLU, everything else with GELU.
        if config.activation_function == "glu":
            activation = F.sigmoid
        elif config.activation_function == "swiglu":
            activation = F.silu
        else:
            activation = F.gelu

        # Gated activation names select the gated MLP; otherwise the plain MLP.
        is_gated = config.activation_function in ["glu", "swiglu", "geglu"]
        mlp_cls = NomciBertGatedMLP if is_gated else NomicBertMLP
        self.mlp = mlp_cls(
            config.n_embd,
            hidden_features=config.n_inner,
            bias1=config.mlp_fc1_bias,
            bias2=config.mlp_fc2_bias,
            activation=activation,
            fused_bias_fc=config.fused_bias_fc,
        )

        self.dropout1 = nn.Dropout(config.resid_pdrop)
        self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.dropout2 = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        hidden_states2: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        is_padded_inputs: Optional[bool] = True,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seq_len: Optional[int] = None,
    ):
        r"""Run the layer and return ``(hidden_states, None, residual)``.

        Args:
            hidden_states: input sequence for this layer (required).
            hidden_states2: accepted but not used by this implementation.
            residual: must be None in post-norm mode. In pre-norm mode it is the
                running residual stream, and None is allowed for the first layer.
            attention_mask, is_padded_inputs, cu_seqlens, max_seq_len: forwarded
                to the attention module.
            position_ids, past_key_value, output_attentions, use_cache: accepted
                but not forwarded to the attention module.

        Returns:
            A 3-tuple. The middle element is always None; the last element is the
            updated residual in pre-norm mode and None in post-norm mode.
        """
        attn_kwargs = dict(
            attention_mask=attention_mask,
            is_padded_inputs=is_padded_inputs,
            cu_seqlens=cu_seqlens,
            max_seq_len=max_seq_len,
        )

        if not self.prenorm:
            # Post-norm: sublayer -> dropout -> add -> LayerNorm, twice.
            assert residual is None
            attn_out = self.attn(hidden_states, **attn_kwargs)
            summed = self.dropout1(attn_out) + hidden_states
            hidden_states = self.norm1(summed.to(dtype=self.norm1.weight.dtype))

            mlp_out = self.mlp(hidden_states)
            summed = self.dropout2(mlp_out) + hidden_states
            hidden_states = self.norm2(summed.to(dtype=self.norm2.weight.dtype))
            return hidden_states, None, None

        # Pre-norm: dropout -> add to residual -> LayerNorm -> sublayer, twice.
        dropped = self.dropout1(hidden_states)
        residual = dropped if residual is None else (dropped + residual)
        hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
        hidden_states = self.attn(hidden_states, **attn_kwargs)

        dropped = self.dropout2(hidden_states)
        # ``residual`` was assigned above, so it is never None at this point.
        residual = dropped + residual
        hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
        hidden_states = self.mlp(hidden_states)

        return hidden_states, None, residual
904
-
905
-
906
class NomicBertEncoder(nn.Module):
    """A stack of ``config.n_layer`` ``NomicBertBlock`` layers applied in order."""

    def __init__(self, config: "GPT2Config"):
        super().__init__()
        self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_padded_inputs: Optional[bool] = True,
    ):
        """Feed ``hidden_states`` through every layer and return the final hidden states.

        Each layer receives and returns the triple
        ``(hidden_states, hidden_states2, residual)``; the last two start as None.
        Only ``hidden_states`` is returned. ``past_key_values``, ``inputs_embeds``,
        ``output_hidden_states`` and ``return_dict`` are accepted but not used here.
        """
        hidden_states2 = None
        residual = None
        checkpointing = self.gradient_checkpointing and self.training

        for layer in self.layers:
            if checkpointing:
                # Positional order follows NomicBertBlock.forward:
                # (..., attention_mask, position_ids=None, past_key_value=None,
                #  is_padded_inputs).
                hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    hidden_states2,
                    residual,
                    attention_mask,
                    None,
                    None,
                    is_padded_inputs,
                    # if you freeze ANY layers, you need `use_reentrant=False`
                    # https://github.com/huggingface/transformers/issues/21381
                    # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
                    use_reentrant=False,
                )
                continue

            hidden_states, hidden_states2, residual = layer(
                hidden_states,
                hidden_states2,
                residual,
                attention_mask,
                position_ids,
                None,  # past_key_value
                is_padded_inputs,
                output_attentions,
                use_cache,
            )
        return hidden_states
971
-
972
-
973
class NomicBertPooler(nn.Module):
    """Dense layer plus tanh applied to the first token (or to every position)."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        """Return ``tanh(dense(x))``.

        With ``pool`` truthy, ``x`` is ``hidden_states[:, 0]``, i.e. the hidden
        state of the first token; otherwise ``x`` is ``hidden_states`` unchanged.
        """
        if pool:
            hidden_states = hidden_states[:, 0]
        return self.activation(self.dense(hidden_states))
986
-
987
-
988
class NomicBertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform used by the LM prediction head."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)

        act_name = config.activation_function
        if act_name == "swiglu":
            self.transform_act_fn = F.silu
        else:
            # The tanh-approximated GELU names map to approximate="tanh";
            # every other name gets the exact GELU.
            uses_tanh = act_name in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
            self.transform_act_fn = nn.GELU(approximate="tanh" if uses_tanh else "none")

        self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the dense layer, the activation and the LayerNorm, in that order."""
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.layer_norm(activated)
1006
-
1007
-
1008
class NomicBertLMPredictionHead(nn.Module):
    """Head transform followed by a linear decoder onto the vocabulary."""

    def __init__(self, config):
        super().__init__()

        self.transform = NomicBertPredictionHeadTransform(config)

        # The decoder bias is controlled by the same flag as the MLP's first layer.
        self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)

    def forward(self, hidden_states):
        """Return the decoder output for the transformed ``hidden_states``."""
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
1020
-
1021
-
1022
class NomicBertPreTrainingHeads(nn.Module):
    """Container for the pre-training heads; it holds only the LM prediction head."""

    def __init__(self, config):
        super().__init__()
        self.predictions = NomicBertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores computed from ``sequence_output``."""
        return self.predictions(sequence_output)
1030
-
1031
-
1032
class NomicBertModel(NomicBertPreTrainedModel):
    """Embeddings -> LayerNorm -> dropout -> encoder stack -> optional pooler."""

    def __init__(self, config: GPT2Config, add_pooling_layer=True):
        """Build the model.

        Args:
            config: model configuration. NOTE: ``config.vocab_size`` is rounded up
                in place to a multiple of ``pad_vocab_size_multiple`` (default 1).
            add_pooling_layer: when False, no pooler is created and
                ``pooler_output`` is None.
        """
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)

        assert config.activation_function in [
            "gelu",
            "gelu_new",
            "gelu_fast",
            "gelu_pytorch_tanh",
            "swiglu",
            "geglu",
            "glu",
        ]

        self.embeddings = NomicBertEmbeddings(config)
        self.emb_drop = nn.Dropout(config.resid_pdrop)
        self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.encoder = NomicBertEncoder(config)
        self.pooler = NomicBertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        return_dict=None,
        matryoshka_dim=None,
    ):
        """Encode ``input_ids``.

        Args:
            input_ids: token ids; ``input_ids.shape`` is used as the mask shape.
            position_ids: forwarded to the embedding layer.
            token_type_ids: defaults to zeros shaped like ``input_ids``.
            attention_mask: optional mask shaped like ``input_ids``. When None,
                every position is attended to.
            return_dict: forwarded to the encoder, which does not use it; a
                model-output object is always returned.
            matryoshka_dim: when truthy, ``last_hidden_state`` is truncated to
                its first ``matryoshka_dim`` features along the last axis.
                ``pooler_output`` is computed before truncation and is not
                truncated.

        Returns:
            ``BaseModelOutputWithPoolingAndCrossAttentions`` with
            ``last_hidden_state`` and ``pooler_output`` set.
        """
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        hidden_states = self.emb_ln(hidden_states)
        hidden_states = self.emb_drop(hidden_states)

        if attention_mask is None:
            # ``attention_mask`` is optional in the signature, so substitute an
            # all-ones mask instead of handing None to the mask-expansion helper.
            attention_mask = torch.ones(input_ids.shape, dtype=torch.bool, device=input_ids.device)
        attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
        sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if matryoshka_dim:
            # Truncate the feature (last) axis. Slicing axis 1 would drop
            # tokens, since the pooler reads the first token via [:, 0].
            sequence_output = sequence_output[..., :matryoshka_dim]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
1084
-
1085
-
1086
class NomicBertForPreTraining(NomicBertPreTrainedModel):
    """NomicBert encoder with a masked-language-modeling head."""

    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config: GPT2Config):
        super().__init__(config)

        # The pooler is created only when the config explicitly asks for it.
        wants_pooler = getattr(config, "add_pooling_layer", False)
        self.bert = NomicBertModel(config, add_pooling_layer=wants_pooler)
        self.cls = NomicBertPreTrainingHeads(config)
        self.mlm_loss = nn.CrossEntropyLoss()

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        """Make the decoder weight the same parameter as the word-embedding weight."""
        embedding_weight = self.bert.embeddings.word_embeddings.weight
        self.cls.predictions.decoder.weight = embedding_weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
        """Compute masked-LM logits and, when ``labels`` is given, the loss.

        ``attention_mask`` is converted to bool before being passed to the
        encoder. The loss is ``nn.CrossEntropyLoss`` over the flattened logits
        and labels; use -100 in ``labels`` for positions that should not
        contribute.

        Returns:
            ``MaskedLMOutput`` with ``loss`` (None when ``labels`` is None),
            ``logits``, the encoder output's ``hidden_states`` and
            ``attentions=None``.
        """
        bool_mask = attention_mask.bool() if attention_mask is not None else None
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=bool_mask,
        )
        prediction_scores = self.cls(outputs.last_hidden_state)

        total_loss = None
        if labels is not None:
            flat_scores = rearrange(prediction_scores, "... v -> (...) v")
            flat_labels = rearrange(labels, "... -> (...)")
            total_loss = self.mlm_loss(flat_scores, flat_labels).float()

        return MaskedLMOutput(
            loss=total_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )
1148
-
1149
-
1150
class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
    """NomicBert encoder with dropout and a linear classifier on the pooled output."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = NomicBertModel(config)
        # Fall back to ``embd_pdrop`` both when ``classifier_dropout`` is absent
        # and when it is present but None; nn.Dropout cannot take None.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.embd_pdrop
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.n_embd, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        `head_mask`, `inputs_embeds`, `output_attentions` and `output_hidden_states` are accepted but not used by
        this implementation. When `config.problem_type` is None it is inferred from `num_labels` and the dtype of
        `labels` on the first call that provides labels, and stored on the config.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
        )

        # Index 1 of the encoder output is the pooled output.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )