zpn commited on
Commit
a350440
1 Parent(s): 523e3af

Upload model

Browse files
config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "swiglu",
3
+ "architectures": [
4
+ "NomicBertModel"
5
+ ],
6
+ "attn_pdrop": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_hf_nomic_bert.NomicBertConfig",
9
+ "AutoModel": "modeling_hf_nomic_bert.NomicBertModel",
10
+ "AutoModelForMaskedLM": "nomic-ai/nomic-bert-2048--modeling_hf_nomic_bert.NomicBertForPreTraining"
11
+ },
12
+ "bos_token_id": null,
13
+ "causal": false,
14
+ "dense_seq_output": true,
15
+ "embd_pdrop": 0.1,
16
+ "eos_token_id": null,
17
+ "fused_bias_fc": true,
18
+ "fused_dropout_add_ln": true,
19
+ "initializer_range": 0.02,
20
+ "layer_norm_epsilon": 1e-12,
21
+ "mlp_fc1_bias": false,
22
+ "mlp_fc2_bias": false,
23
+ "model_type": "nomic_bert",
24
+ "n_embd": 768,
25
+ "n_head": 12,
26
+ "n_inner": 3072,
27
+ "n_layer": 12,
28
+ "n_positions": 2048,
29
+ "pad_vocab_size_multiple": 64,
30
+ "parallel_block": false,
31
+ "parallel_block_tied_norm": false,
32
+ "prenorm": false,
33
+ "qkv_proj_bias": false,
34
+ "reorder_and_upcast_attn": false,
35
+ "resid_pdrop": 0.1,
36
+ "rotary_emb_base": 1000,
37
+ "rotary_emb_fraction": 1.0,
38
+ "rotary_emb_interleaved": false,
39
+ "rotary_emb_scale_base": null,
40
+ "rotary_scaling_factor": null,
41
+ "scale_attn_by_inverse_layer_idx": false,
42
+ "scale_attn_weights": true,
43
+ "summary_activation": null,
44
+ "summary_first_dropout": 0.1,
45
+ "summary_proj_to_labels": true,
46
+ "summary_type": "cls_index",
47
+ "summary_use_proj": true,
48
+ "torch_dtype": "float32",
49
+ "transformers_version": "4.34.0",
50
+ "type_vocab_size": 2,
51
+ "use_cache": true,
52
+ "use_flash_attn": true,
53
+ "use_rms_norm": false,
54
+ "use_xentropy": true,
55
+ "vocab_size": 30528
56
+ }
configuration_hf_nomic_bert.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import GPT2Config
2
+
3
+
4
+ class NomicBertConfig(GPT2Config):
5
+ model_type = "nomic_bert"
6
+
7
+ def __init__(self,
8
+ prenorm=False,
9
+ parallel_block=False,
10
+ parallel_block_tied_norm=False,
11
+ rotary_emb_fraction=0.0,
12
+ fused_dropout_add_ln=False,
13
+ fused_bias_fc=False,
14
+ use_flash_attn=False,
15
+ use_xentropy=False,
16
+ qkv_proj_bias=True,
17
+ rotary_emb_base=1000,
18
+ rotary_emb_scale_base=None,
19
+ rotary_emb_interleaved=False,
20
+ mlp_fc1_bias=True,
21
+ mlp_fc2_bias=True,
22
+ use_rms_norm=False,
23
+ causal=False,
24
+ type_vocab_size=2,
25
+ dense_seq_output=True,
26
+ pad_vocab_size_multiple=1,
27
+ tie_word_embeddings=True,
28
+ rotary_scaling_factor=1.0,
29
+ **kwargs,
30
+ ):
31
+ self.prenorm = prenorm
32
+ self.parallel_block = parallel_block
33
+ self.parallel_block_tied_norm = parallel_block_tied_norm
34
+ self.rotary_emb_fraction = rotary_emb_fraction
35
+ self.tie_word_embeddings = tie_word_embeddings
36
+ self.fused_dropout_add_ln = fused_dropout_add_ln
37
+ self.fused_bias_fc = fused_bias_fc
38
+ self.use_flash_attn = use_flash_attn
39
+ self.use_xentropy = use_xentropy
40
+ self.qkv_proj_bias = qkv_proj_bias
41
+ self.rotary_emb_base = rotary_emb_base
42
+ self.rotary_emb_scale_base = rotary_emb_scale_base
43
+ self.rotary_emb_interleaved = rotary_emb_interleaved
44
+ self.mlp_fc1_bias = mlp_fc1_bias
45
+ self.mlp_fc2_bias = mlp_fc2_bias
46
+ self.use_rms_norm = use_rms_norm
47
+ self.causal = causal
48
+ self.type_vocab_size = type_vocab_size
49
+ self.dense_seq_output = dense_seq_output
50
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
51
+ self.rotary_scaling_factor = rotary_scaling_factor
52
+
53
+ super().__init__(**kwargs)
modeling_hf_nomic_bert.py ADDED
@@ -0,0 +1,1237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
+
6
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
7
+ import os
8
+ import logging
9
+ from functools import partial
10
+ from typing import Optional, List, Tuple, Union
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from einops import rearrange, repeat
16
+ from transformers import GPT2Config, PreTrainedModel
17
+ from transformers.models.bert.modeling_bert import (
18
+ BaseModelOutputWithPoolingAndCrossAttentions,
19
+ MaskedLMOutput,
20
+ SequenceClassifierOutput
21
+ )
22
+
23
+ import re
24
+ from collections import OrderedDict
25
+ from safetensors.torch import load_file as safe_load_file
26
+ from transformers.utils import (
27
+ SAFE_WEIGHTS_INDEX_NAME,
28
+ SAFE_WEIGHTS_NAME,
29
+ WEIGHTS_INDEX_NAME,
30
+ WEIGHTS_NAME,
31
+ )
32
+ from transformers.utils.hub import cached_file, get_checkpoint_shard_files
33
+
34
+
35
+ from .configuration_hf_nomic_bert import NomicBertConfig
36
+
37
+ logger = logging.getLogger(__name__)
38
+
39
+ # adapted from flash attention, added safe serialization option for hf models
40
+ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
41
+ # If not fp32, then we don't want to load directly to the GPU
42
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
43
+ is_sharded = False
44
+ load_safe = False
45
+ resolved_archive_file = None
46
+
47
+ weights_path = os.path.join(model_name, WEIGHTS_NAME)
48
+ weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
49
+ safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
50
+ safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
51
+
52
+ if os.path.isfile(weights_path):
53
+ resolved_archive_file = cached_file(
54
+ model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
55
+ )
56
+ elif os.path.isfile(weights_index_path):
57
+ resolved_archive_file = cached_file(
58
+ model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
59
+ )
60
+ is_sharded = True
61
+ elif os.path.isfile(safe_weights_path):
62
+ resolved_archive_file = cached_file(
63
+ model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
64
+ )
65
+ load_safe = True
66
+ elif os.path.isfile(safe_weights_index_path):
67
+ resolved_archive_file = cached_file(
68
+ model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
69
+ )
70
+ is_sharded = True
71
+ load_safe = True
72
+ else: # Try loading from HF hub instead of from local files
73
+ weight_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME
74
+ resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
75
+ if resolved_archive_file is None:
76
+ weight_index = WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
77
+ resolved_archive_file = cached_file(model_name, weight_index,
78
+ _raise_exceptions_for_missing_entries=False)
79
+ if resolved_archive_file is not None:
80
+ is_sharded = True
81
+
82
+ load_safe = safe_serialization
83
+
84
+ if resolved_archive_file is None:
85
+ raise EnvironmentError(f"Model name {model_name} was not found.")
86
+
87
+ if load_safe:
88
+ loader = partial(safe_load_file, device=mapped_device)
89
+ else:
90
+ loader = partial(torch.load, map_location=mapped_device)
91
+
92
+ if is_sharded:
93
+ # resolved_archive_file becomes a list of files that point to the different
94
+ # checkpoint shards in this case.
95
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
96
+ model_name, resolved_archive_file
97
+ )
98
+ state_dict = {}
99
+ for sharded_file in resolved_archive_file:
100
+ state_dict.update(loader(sharded_file))
101
+ else:
102
+ state_dict = loader(resolved_archive_file)
103
+ # Convert dtype before moving to GPU to save memory
104
+ if dtype is not None:
105
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
106
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
107
+ return state_dict
108
+
109
+
110
+ def filter_shapes(state_dict, model):
111
+ """
112
+ Filters the state dict to match the current model shape.
113
+ """
114
+ filtered_state_dict = {}
115
+ for key, value in state_dict.items():
116
+ if key in model.state_dict():
117
+ if value.shape == model.state_dict()[key].shape:
118
+ filtered_state_dict[key] = value
119
+ return filtered_state_dict
120
+
121
+
122
+ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
123
+ """
124
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
125
+ """
126
+ def add_bert_prefix(key):
127
+ # prepend bert. to the key
128
+ if key.startswith("bert.") or key.startswith("cls."):
129
+ return key
130
+ return f"bert.{key}"
131
+
132
+ state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
133
+
134
+ # LayerNorm
135
+ def key_mapping_ln_gamma_beta(key):
136
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
137
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
138
+ return key
139
+
140
+ state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
141
+
142
+ # Layers
143
+ def key_mapping_layers(key):
144
+ return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)
145
+
146
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
147
+
148
+ # LayerNorm
149
+ def key_mapping_ln(key):
150
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
151
+ key = re.sub(
152
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
153
+ r"bert.encoder.layers.\1.norm1.\2",
154
+ key,
155
+ )
156
+ key = re.sub(
157
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
158
+ r"bert.encoder.layers.\1.norm2.\2",
159
+ key,
160
+ )
161
+ key = re.sub(
162
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
163
+ r"cls.predictions.transform.layer_norm.\1",
164
+ key,
165
+ )
166
+ return key
167
+
168
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
169
+
170
+ # MLP
171
+ def key_mapping_mlp(key):
172
+ key = re.sub(
173
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
174
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
175
+ key,
176
+ )
177
+ key = re.sub(
178
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
179
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
180
+ key,
181
+ )
182
+ return key
183
+
184
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
185
+
186
+ # Attention
187
+ last_layer_subset = getattr(config, "last_layer_subset", False)
188
+ for d in range(config.num_hidden_layers):
189
+ if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict:
190
+ continue
191
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
192
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
193
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
194
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
195
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
196
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
197
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
198
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat(
199
+ [Wq, Wk, Wv], dim=0
200
+ )
201
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
202
+ else:
203
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
204
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
205
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq
206
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)
207
+
208
+ def key_mapping_attn(key):
209
+ return re.sub(
210
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
211
+ r"bert.encoder.layers.\1.attn.out_proj.\2",
212
+ key,
213
+ )
214
+
215
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
216
+
217
+ def key_mapping_decoder_bias(key):
218
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
219
+
220
+
221
+ # remove nsp weights, we don't use
222
+ state_dict.pop("cls.seq_relationship.weight", None)
223
+ state_dict.pop("cls.seq_relationship.bias", None)
224
+ state_dict.pop("bert.embeddings.position_ids", None)
225
+
226
+ state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
227
+
228
+ if remove_cls_weights:
229
+ cls_weights = ["cls.predictions.decoder.bias",
230
+ "cls.predictions.transform.dense.weight",
231
+ "cls.predictions.transform.dense.bias",
232
+ "cls.predictions.transform.layer_norm.weight",
233
+ "cls.predictions.transform.layer_norm.bias",
234
+ "cls.predictions.decoder.weight"]
235
+ for weight in cls_weights:
236
+ state_dict.pop(weight, None)
237
+
238
+ # Word embedding
239
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
240
+ if pad_vocab_size_multiple > 1:
241
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
242
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
243
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
244
+ )
245
+ if not remove_cls_weights:
246
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
247
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
248
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
249
+ )
250
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
251
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
252
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
253
+ if "cls.predictions.decoder.bias" in state_dict:
254
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
255
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
256
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
257
+ )
258
+
259
+ if add_pooling_layer is False:
260
+ pooler_weights = ["bert.pooler.dense.weight",
261
+ "bert.pooler.dense.bias",
262
+ ]
263
+ for key in pooler_weights:
264
+ state_dict.pop(key, None)
265
+
266
+ if remove_bert:
267
+ def remove_bert_prefix(key):
268
+ key = re.sub(r"^bert.", "", key)
269
+ return key
270
+
271
+ state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
272
+
273
+
274
+ return state_dict
275
+
276
+
277
+ class NomicBertPreTrainedModel(PreTrainedModel):
278
+ """An abstract class to handle weights initialization and
279
+ a simple interface for dowloading and loading pretrained models.
280
+ """
281
+ config_class = NomicBertConfig
282
+ base_model_prefix = "model"
283
+ supports_gradient_checkpointing = True
284
+ _no_split_modules = ["Block"]
285
+ _skip_keys_device_placement = "past_key_values"
286
+
287
+ def __init__(self, config, *inputs, **kwargs):
288
+ super().__init__(config)
289
+ if not isinstance(config, GPT2Config):
290
+ raise ValueError(
291
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
292
+ "To create a model from a Google pretrained model use "
293
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
294
+ self.__class__.__name__, self.__class__.__name__
295
+ )
296
+ )
297
+ self.config = config
298
+
299
+ @classmethod
300
+ def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
301
+ """
302
+ Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
303
+ Download and cache the pre-trained model file if needed.
304
+
305
+ Params:
306
+ pretrained_model_name_or_path: either:
307
+ - a path or url to a pretrained model archive containing:
308
+ . `bert_config.json` a configuration file for the model
309
+ . `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance
310
+ - a path or url to a pretrained model archive containing:
311
+ . `bert_config.json` a configuration file for the model
312
+ . `model.chkpt` a TensorFlow checkpoint
313
+ *inputs, **kwargs: additional input for the specific NomicBert class
314
+ (ex: num_labels for NomicBertForSequenceClassification)
315
+ """
316
+ # Instantiate model.
317
+ if config is None:
318
+ config = cls.config_class.from_pretrained(model_name)
319
+ remove_cls = cls != NomicBertForPreTraining
320
+ remove_bert_prefix = cls != NomicBertForPreTraining
321
+ ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
322
+ num_labels = kwargs.pop("num_labels", None)
323
+ rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
324
+ if rotary_scaling_factor:
325
+ config.rotary_scaling_factor = rotary_scaling_factor
326
+ else:
327
+ config.rotary_scaling_factor = None
328
+ if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
329
+ config.n_positions = 2048
330
+ if num_labels:
331
+ config.num_labels = num_labels
332
+
333
+ if "add_pooling_layer" in kwargs:
334
+ model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
335
+ else:
336
+ if cls == NomicBertModel:
337
+ model = cls(config, *inputs, add_pooling_layer=False)
338
+ else:
339
+ model = cls(config, *inputs)
340
+ # TODO: fix this
341
+ # Assuming we know what we're doing when loading from disk
342
+ # Prob a bad assumption but i'm tired and want to train this asap
343
+ if os.path.exists(model_name):
344
+ state_dict = torch.load(f"{model_name}/pytorch_model.bin")
345
+ if ignore_mismatched_shapes:
346
+ state_dict = filter_shapes(state_dict, model)
347
+ load_return = model.load_state_dict(state_dict, strict=False)
348
+ else:
349
+ # TODO: can probably check config class and see if we need to remap from a bert model
350
+ state_dict = state_dict_from_pretrained(model_name)
351
+ state_dict = remap_bert_state_dict(state_dict,
352
+ config,
353
+ remove_bert=remove_bert_prefix,
354
+ remove_cls_weights=remove_cls,
355
+ add_pooling_layer=getattr(config, "add_pooling_layer", False)
356
+ )
357
+ if ignore_mismatched_shapes:
358
+ state_dict = filter_shapes(state_dict, model)
359
+
360
+ load_return = model.load_state_dict(
361
+ state_dict,
362
+ strict=True
363
+ )
364
+ logger.warning(load_return)
365
+ return model
366
+
367
+ def _set_gradient_checkpointing(self, module, value=False):
368
+ if isinstance(module, NomicBertEncoder):
369
+ module.gradient_checkpointing = value
370
+
371
+
372
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
373
+ def _init_weights(module, initializer_range=0.02):
374
+ if isinstance(module, nn.Linear):
375
+ nn.init.normal_(module.weight, std=initializer_range)
376
+ if module.bias is not None:
377
+ nn.init.zeros_(module.bias)
378
+ elif isinstance(module, nn.Embedding):
379
+ nn.init.normal_(module.weight, std=initializer_range)
380
+ if module.padding_idx is not None:
381
+ nn.init.zeros_(module.weight[module.padding_idx])
382
+
383
+
384
+ class NomicBertEmbeddings(nn.Module):
385
+ def __init__(
386
+ self,
387
+ config
388
+ ):
389
+ """
390
+ If max_position_embeddings <= 0, there's no position embeddings
391
+ If type_vocab_size <= 0, there's no token type embeddings
392
+ """
393
+ super().__init__()
394
+ self.word_embeddings = nn.Embedding(
395
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
396
+ )
397
+ self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
398
+ self.type_vocab_size = config.type_vocab_size
399
+ if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
400
+ self.position_embeddings = nn.Embedding(
401
+ config.max_position_embeddings, config.hidden_size,
402
+ )
403
+ if self.type_vocab_size > 0:
404
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
405
+
406
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
407
+ """
408
+ input_ids: (batch, seqlen)
409
+ position_ids: (batch, seqlen)
410
+ token_type_ids: (batch, seqlen)
411
+ """
412
+ batch_size, seqlen = input_ids.shape
413
+ embeddings = self.word_embeddings(input_ids)
414
+
415
+ if self.type_vocab_size > 0:
416
+ if token_type_ids is None:
417
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
418
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
419
+ embeddings = embeddings + token_type_embeddings
420
+
421
+ if self.max_position_embeddings > 0:
422
+ if position_ids is None:
423
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
424
+ position_embeddings = self.position_embeddings(position_ids)
425
+ embeddings = embeddings + position_embeddings
426
+ return embeddings
427
+
428
+ class NomicBertMLP(nn.Module):
429
+ def __init__(
430
+ self,
431
+ in_features,
432
+ hidden_features=None,
433
+ out_features=None,
434
+ activation=F.gelu,
435
+ bias1=True,
436
+ bias2=True,
437
+ return_residual=False,
438
+ fused_bias_fc=False,
439
+ ):
440
+ super().__init__()
441
+ out_features = out_features if out_features is not None else in_features
442
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
443
+ self.return_residual = return_residual
444
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
445
+ approximate = (
446
+ "tanh"
447
+ if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
448
+ else "none"
449
+ )
450
+ self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
451
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
452
+
453
+ def forward(self, x):
454
+ y = self.fc1(x)
455
+ y = self.activation(y)
456
+ y = self.fc2(y)
457
+ return y if not self.return_residual else (y, x)
458
+
459
+
460
+ class NomciBertGatedMLP(nn.Module):
461
+ def __init__(
462
+ self,
463
+ in_features,
464
+ hidden_features=None,
465
+ out_features=None,
466
+ activation=F.sigmoid,
467
+ bias1=True,
468
+ bias2=True,
469
+ multiple_of=256,
470
+ return_residual=False,
471
+ fused_bias_fc=True,
472
+ device=None,
473
+ dtype=None,
474
+ ):
475
+ super().__init__()
476
+ out_features = out_features if out_features is not None else in_features
477
+ hidden_features = (
478
+ hidden_features if hidden_features is not None else int(8 * in_features / 3)
479
+ )
480
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
481
+ self.return_residual = return_residual
482
+
483
+ self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
484
+ self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
485
+ self.activation = activation
486
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
487
+
488
+ def forward(self, x):
489
+ y = self.fc11(x)
490
+ gate = self.fc12(x)
491
+ if self.activation == F.sigmoid: # Special case for GLU
492
+ y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
493
+ else:
494
+ y = y * self.activation(gate)
495
+ y = self.fc2(y)
496
+ return y if not self.return_residual else (y, x)
497
+
498
+
499
+ def rotate_half(x, interleaved=False):
500
+ if not interleaved:
501
+ x1, x2 = x.chunk(2, dim=-1)
502
+ return torch.cat((-x2, x1), dim=-1)
503
+ else:
504
+ x1, x2 = x[..., ::2], x[..., 1::2]
505
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
506
+
507
+
508
+ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
509
+ """
510
+ x: (batch_size, seqlen, nheads, headdim)
511
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
512
+ """
513
+ ro_dim = cos.shape[-1] * 2
514
+ assert ro_dim <= x.shape[-1]
515
+ cos, sin = (
516
+ cos[offset: offset + x.shape[1]],
517
+ sin[offset: offset + x.shape[1]],
518
+ )
519
+ cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
520
+ sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
521
+ return torch.cat(
522
+ [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
523
+ dim=-1,
524
+ )
525
+
526
+
527
+ class NomicBertRotaryEmbedding(nn.Module):
528
+ def __init__(
529
+ self,
530
+ dim: int,
531
+ base=10000.0,
532
+ interleaved=False,
533
+ scale_base=None,
534
+ pos_idx_in_fp32=True,
535
+ device=None,
536
+ ):
537
+ """
538
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
539
+ of 1st half and 2nd half (GPT-NeoX style).
540
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
541
+ otherwise they might be in lower precision.
542
+ This option was added because previously (before 2023-07-02), when we construct
543
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
544
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
545
+ self.inv_freq would be bf16, and the position indices are also in bf16.
546
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
547
+ embeddings for some positions will coincide.
548
+ To maintain compatibility with models previously trained in pure bf16,
549
+ we add this option.
550
+ """
551
+ super().__init__()
552
+ self.dim = dim
553
+ self.base = float(base)
554
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
555
+ # Generate and save the inverse frequency buffer (non trainable)
556
+ inv_freq = self._compute_inv_freq(device)
557
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
558
+ self.interleaved = interleaved
559
+ self.scale_base = scale_base
560
+ scale = (
561
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
562
+ if scale_base is not None
563
+ else None
564
+ )
565
+ self.register_buffer("scale", scale, persistent=False)
566
+
567
+ self._seq_len_cached = 0
568
+ self._cos_cached = None
569
+ self._sin_cached = None
570
+ self._cos_k_cached = None
571
+ self._sin_k_cached = None
572
+
573
+ def _compute_inv_freq(self, device=None):
574
+ return 1.0 / (
575
+ self.base
576
+ ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
577
+ )
578
+
579
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
580
+ # Reset the tables if the sequence length has changed,
581
+ # if we're on a new device (possibly due to tracing for instance),
582
+ # or if we're switching from inference mode to training
583
+ if (
584
+ seqlen > self._seq_len_cached
585
+ or self._cos_cached is None
586
+ or self._cos_cached.device != device
587
+ or self._cos_cached.dtype != dtype
588
+ or (self.training and self._cos_cached.is_inference())
589
+ ):
590
+ self._seq_len_cached = seqlen
591
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
592
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
593
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
594
+ if self.pos_idx_in_fp32:
595
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
596
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
597
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
598
+ # cos & sin output to change significantly.
599
+ # We want to recompute self.inv_freq if it was not loaded in fp32
600
+ if self.inv_freq.dtype != torch.float32:
601
+ inv_freq = self._compute_inv_freq(device=device)
602
+ else:
603
+ inv_freq = self.inv_freq
604
+ else:
605
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
606
+ inv_freq = self.inv_freq
607
+ # Don't do einsum, it converts fp32 to fp16 under AMP
608
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
609
+ freqs = torch.outer(t, inv_freq)
610
+ self._cos_cached = torch.cos(freqs).to(dtype)
611
+ self._sin_cached = torch.sin(freqs).to(dtype)
612
+
613
+ def forward(
614
+ self,
615
+ qkv: torch.Tensor,
616
+ kv: Optional[torch.Tensor] = None,
617
+ seqlen_offset: Union[int, torch.Tensor] = 0,
618
+ max_seqlen: Optional[int] = None,
619
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
620
+ """
621
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
622
+ else it's just q of shape (batch, seqlen, nheads, headdim)
623
+ kv: (batch, seqlen, 2, nheads, headdim)
624
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
625
+ Most commonly used in inference when we have KV cache.
626
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
627
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
628
+ Apply rotary embedding *inplace* to qkv and / or kv.
629
+ """
630
+ seqlen = qkv.shape[1]
631
+ if seqlen > self._seq_len_cached:
632
+ self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
633
+ elif max_seqlen is not None:
634
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
635
+ elif isinstance(seqlen_offset, int):
636
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
637
+
638
+ q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
639
+ k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
640
+ return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
641
+
642
+
643
+ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
644
+ def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
645
+ super().__init__(**kwargs)
646
+ self.rotary_scaling_factor = rotary_scaling_factor
647
+ self.max_position_embeddings = max_position_embeddings
648
+
649
+
650
+ def _compute_inv_freq(self, base=None, device=None):
651
+ if base is None:
652
+ base = self.base
653
+ return 1.0 / (
654
+ base
655
+ ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
656
+ )
657
+
658
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
659
+ # Reset the tables if the sequence length has changed,
660
+ # if we're on a new device (possibly due to tracing for instance),
661
+ # or if we're switching from inference mode to training
662
+ if seqlen > self.max_position_embeddings:
663
+ base = self.base * (
664
+ (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
665
+ ) ** (self.dim / (self.dim - 2))
666
+ inv_freq = self._compute_inv_freq(base=base, device=device)
667
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
668
+
669
+ if (
670
+ seqlen > self._seq_len_cached
671
+ or self._cos_cached is None
672
+ or self._cos_cached.device != device
673
+ or self._cos_cached.dtype != dtype
674
+ or (self.training and self._cos_cached.is_inference())
675
+ ):
676
+ self._seq_len_cached = seqlen
677
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
678
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
679
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
680
+ if self.pos_idx_in_fp32:
681
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
682
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
683
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
684
+ # cos & sin output to change significantly.
685
+ # We want to recompute self.inv_freq if it was not loaded in fp32
686
+ if self.inv_freq.dtype != torch.float32:
687
+ if seqlen > self.max_position_embeddings:
688
+ base = self.base * (
689
+ (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)
690
+ ) ** (self.dim / (self.dim - 2))
691
+ else:
692
+ base = self.base
693
+ inv_freq = self._compute_inv_freq(device=device, base=base)
694
+ else:
695
+ inv_freq = self.inv_freq
696
+ else:
697
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
698
+ inv_freq = self.inv_freq
699
+ # Don't do einsum, it converts fp32 to fp16 under AMP
700
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
701
+ freqs = torch.outer(t, inv_freq)
702
+ if self.scale is None:
703
+ self._cos_cached = torch.cos(freqs).to(dtype)
704
+ self._sin_cached = torch.sin(freqs).to(dtype)
705
+ else:
706
+ power = (
707
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
708
+ - seqlen // 2
709
+ ) / self.scale_base
710
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
711
+ # We want the multiplication by scale to happen in fp32
712
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
713
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
714
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
715
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
716
+
717
+ class NomicBertAttention(nn.Module):
718
+ """Multi-head self-attention and cross-attention"""
719
+
720
+ def __init__(
721
+ self,
722
+ config,
723
+ ) -> None:
724
+ """
725
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
726
+ return_residual: whether to return the input x along with the output. This is for
727
+ performance reason: for post-norm architecture, returning the input allows us
728
+ to fuse the backward of nn.Linear with the residual connection.
729
+ """
730
+ super().__init__()
731
+ self.embed_dim = config.n_embd
732
+ self.use_flash_attn = config.use_flash_attn
733
+ self.fused_bias_fc = config.fused_bias_fc
734
+
735
+ self.num_heads = config.n_head
736
+ self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
737
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
738
+ self.head_dim = self.embed_dim // self.num_heads
739
+ # we don't really support mqa / gqa for now
740
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
741
+
742
+ self.register_buffer(
743
+ "norm_factor",
744
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
745
+ persistent=False,
746
+ )
747
+
748
+ self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
749
+ if self.rotary_emb_dim > 0:
750
+ if config.rotary_scaling_factor:
751
+ self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
752
+ dim=self.rotary_emb_dim,
753
+ base=config.rotary_emb_base,
754
+ scale_base=config.rotary_emb_scale_base,
755
+ interleaved=config.rotary_emb_interleaved,
756
+ rotary_scaling_factor=config.rotary_scaling_factor,
757
+ max_position_embeddings=config.n_positions,
758
+ )
759
+ else:
760
+ self.rotary_emb = NomicBertRotaryEmbedding(
761
+ dim=self.rotary_emb_dim,
762
+ base=config.rotary_emb_base,
763
+ scale_base=config.rotary_emb_scale_base,
764
+ interleaved=config.rotary_emb_interleaved,
765
+ )
766
+ # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
767
+ # uses the head dimension instead of the sequence dimension
768
+ self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
769
+
770
+ self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
771
+
772
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
773
+ self.causal = config.causal
774
+ self.drop = nn.Dropout(config.attn_pdrop)
775
+
776
+ def forward(
777
+ self,
778
+ hidden_states: torch.Tensor,
779
+ attention_mask: Optional[torch.Tensor] = None,
780
+ position_ids: Optional[torch.LongTensor] = None,
781
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
782
+ output_attentions: bool = False,
783
+ use_cache: bool = False,
784
+ is_padded_inputs: Optional[bool] = True,
785
+ cu_seqlens: Optional[torch.Tensor] = None,
786
+ max_seq_len: Optional[int] = None,
787
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
788
+
789
+ has_layer_past = past_key_value is not None
790
+
791
+ if has_layer_past:
792
+ past_key_value = past_key_value[0]
793
+ past_len = past_key_value[1]
794
+ else:
795
+ past_len = 0
796
+
797
+ qkv = self.Wqkv(hidden_states)
798
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
799
+
800
+ past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
801
+
802
+ if self.rotary_emb_dim > 0:
803
+ if self.rotary_head_dim:
804
+ qkv = rearrange(qkv, "b s three h d -> b h three s d")
805
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
806
+
807
+ if self.rotary_head_dim:
808
+ qkv = rearrange(qkv, "b h three s d -> b s three h d")
809
+
810
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
811
+
812
+ query = query.permute(0, 2, 1, 3)
813
+ key = key.permute(0, 2, 1, 3)
814
+ value = value.permute(0, 2, 1, 3)
815
+
816
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
817
+ if attention_mask is not None:
818
+ attention_scores = attention_scores + attention_mask
819
+
820
+ attentions_probs = F.softmax(attention_scores, dim=-1)
821
+ attentions_probs = self.drop(attentions_probs)
822
+
823
+ attn_output = torch.matmul(attentions_probs, value)
824
+ attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
825
+
826
+ attn_output = self.out_proj(attn_output)
827
+
828
+ return attn_output
829
+
830
+
831
+ class NomicBertBlock(nn.Module):
832
+ def __init__(
833
+ self,
834
+ config,
835
+ ):
836
+ super().__init__()
837
+ self.prenorm = config.prenorm
838
+ self.fused_dropout_add_ln = config.fused_dropout_add_ln
839
+
840
+ self.attn = NomicBertAttention(config)
841
+ activation = (
842
+ F.sigmoid
843
+ if config.activation_function == "glu"
844
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
845
+ )
846
+ if config.activation_function in ["glu", "swiglu", "geglu"]:
847
+ self.mlp = NomciBertGatedMLP(config.n_embd, hidden_features=config.n_inner, bias1=config.mlp_fc1_bias, bias2=config.mlp_fc2_bias, activation=activation, fused_bias_fc=config.fused_bias_fc)
848
+ else:
849
+ self.mlp = NomicBertMLP(config.n_embd, hidden_features=config.n_inner, bias1=config.mlp_fc1_bias, bias2=config.mlp_fc2_bias, activation=activation, fused_bias_fc=config.fused_bias_fc)
850
+
851
+ self.dropout1 = nn.Dropout(config.resid_pdrop)
852
+ self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
853
+ self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
854
+ self.dropout2 = nn.Dropout(config.resid_pdrop)
855
+
856
+ def forward(
857
+ self,
858
+ hidden_states: torch.Tensor,
859
+ hidden_states2: torch.Tensor,
860
+ residual: Optional[torch.Tensor] = None,
861
+ attention_mask: Optional[torch.Tensor] = None,
862
+ position_ids: Optional[torch.LongTensor] = None,
863
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
864
+ is_padded_inputs: Optional[bool] = True,
865
+ output_attentions: Optional[bool] = False,
866
+ use_cache: Optional[bool] = False,
867
+ cu_seqlens: Optional[torch.Tensor] = None,
868
+ max_seq_len: Optional[int] = None,
869
+ ):
870
+ r"""Pass the input through the encoder layer.
871
+
872
+ Args:
873
+ hidden_states: the sequence to the encoder layer (required).
874
+ residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
875
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
876
+ before applying the query projection. Useful for e.g., ViT where we only care
877
+ about the CLS token in the last layer.
878
+ """
879
+ if self.prenorm:
880
+ dropped = self.dropout1(hidden_states)
881
+ residual = (dropped + residual) if residual is not None else dropped
882
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
883
+ hidden_states = self.attn(hidden_states, attention_mask=attention_mask, is_padded_inputs=is_padded_inputs, cu_seqlens=cu_seqlens, max_seq_len=max_seq_len)
884
+
885
+ dropped = self.dropout2(hidden_states)
886
+ residual = (dropped + residual) if residual is not None else dropped
887
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
888
+ hidden_states = self.mlp(hidden_states)
889
+
890
+ return hidden_states, None, residual
891
+ else:
892
+ assert residual is None
893
+ attn_outputs = self.attn(hidden_states,
894
+ attention_mask=attention_mask,
895
+ is_padded_inputs=is_padded_inputs,
896
+ cu_seqlens=cu_seqlens,
897
+ max_seq_len=max_seq_len)
898
+ hidden_states = self.norm1(
899
+ (self.dropout1(attn_outputs) + hidden_states).to(
900
+ dtype=self.norm1.weight.dtype
901
+ )
902
+ )
903
+ mlp_out = self.mlp(hidden_states)
904
+
905
+ hidden_states = self.norm2(
906
+ (self.dropout2(mlp_out) + hidden_states).to(
907
+ dtype=self.norm2.weight.dtype
908
+ )
909
+ )
910
+ return hidden_states, None, None
911
+
912
+
913
+ class NomicBertEncoder(nn.Module):
914
+ def __init__(self, config: GPT2Config):
915
+ super().__init__()
916
+ self.layers = nn.ModuleList(
917
+ [NomicBertBlock(config) for _ in range(config.n_layer)]
918
+ )
919
+ self.gradient_checkpointing = False
920
+ self.config = config
921
+
922
+ def forward(self,
923
+ hidden_states: torch.LongTensor = None,
924
+ attention_mask: Optional[torch.Tensor] = None,
925
+ position_ids: Optional[torch.LongTensor] = None,
926
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
927
+ inputs_embeds: Optional[torch.FloatTensor] = None,
928
+ use_cache: Optional[bool] = None,
929
+ output_attentions: Optional[bool] = None,
930
+ output_hidden_states: Optional[bool] = None,
931
+ return_dict: Optional[bool] = None,
932
+ is_padded_inputs: Optional[bool] = True,):
933
+
934
+ """If subset_mask is not None, we only want output for the subset of the sequence.
935
+ This means that we only compute the last layer output for these tokens.
936
+ subset_mask: (batch, seqlen), dtype=torch.bool
937
+ """
938
+ hidden_states2 = None
939
+ residual = None
940
+
941
+
942
+ for _, layer in enumerate(self.layers):
943
+ if self.gradient_checkpointing and self.training:
944
+
945
+ def create_custom_forward(module):
946
+ def custom_forward(*inputs):
947
+ # None for past_key_value
948
+ return module(*inputs)
949
+
950
+ return custom_forward
951
+
952
+ hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
953
+ create_custom_forward(layer),
954
+ hidden_states,
955
+ hidden_states2,
956
+ residual,
957
+ attention_mask,
958
+ None,
959
+ None,
960
+ is_padded_inputs,
961
+ # if you freeze ANY layers, you need `use_reentrant=False`
962
+ # https://github.com/huggingface/transformers/issues/21381
963
+ # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
964
+ use_reentrant=False,
965
+ )
966
+
967
+ else:
968
+ hidden_states, hidden_states2, residual = layer(
969
+ hidden_states,
970
+ hidden_states2,
971
+ residual,
972
+ attention_mask,
973
+ position_ids,
974
+ None,
975
+ is_padded_inputs,
976
+ output_attentions,
977
+ use_cache,
978
+ )
979
+ return hidden_states
980
+
981
+
982
+ class NomicBertPooler(nn.Module):
983
+ def __init__(self, config):
984
+ super().__init__()
985
+ self.dense = nn.Linear(config.n_embd, config.n_embd)
986
+ self.activation = nn.Tanh()
987
+
988
+ def forward(self, hidden_states, pool=True):
989
+ # We "pool" the model by simply taking the hidden state corresponding
990
+ # to the first token.
991
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
992
+ pooled_output = self.dense(first_token_tensor)
993
+ pooled_output = self.activation(pooled_output)
994
+ return pooled_output
995
+
996
+
997
+ class NomicBertPredictionHeadTransform(nn.Module):
998
+ def __init__(self, config):
999
+ super().__init__()
1000
+ self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
1001
+ approximate = (
1002
+ "tanh"
1003
+ if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
1004
+ else "none"
1005
+ )
1006
+ if config.activation_function == "swiglu":
1007
+ self.transform_act_fn = F.silu
1008
+ else:
1009
+ self.transform_act_fn = nn.GELU(approximate=approximate)
1010
+
1011
+ self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1012
+
1013
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1014
+ hidden_states = self.dense(hidden_states)
1015
+ hidden_states = self.transform_act_fn(hidden_states)
1016
+ hidden_states = self.layer_norm(hidden_states)
1017
+
1018
+ return hidden_states
1019
+
1020
+
1021
+ class NomicBertLMPredictionHead(nn.Module):
1022
+ def __init__(self, config):
1023
+ super().__init__()
1024
+
1025
+ self.transform = NomicBertPredictionHeadTransform(config)
1026
+
1027
+ self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)
1028
+
1029
+ def forward(self, hidden_states):
1030
+ hidden_states = self.transform(hidden_states)
1031
+ hidden_states = self.decoder(hidden_states)
1032
+ return hidden_states
1033
+
1034
+
1035
+ class NomicBertPreTrainingHeads(nn.Module):
1036
+ def __init__(self, config):
1037
+ super().__init__()
1038
+ self.predictions = NomicBertLMPredictionHead(config)
1039
+
1040
+ def forward(self, sequence_output):
1041
+ prediction_scores = self.predictions(sequence_output)
1042
+ return prediction_scores
1043
+
1044
+
1045
+ class NomicBertModel(NomicBertPreTrainedModel):
1046
+ def __init__(self, config: GPT2Config, add_pooling_layer=True):
1047
+ super().__init__(config)
1048
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1049
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
1050
+ config.vocab_size += self.pad_vocab_size_multiple - (
1051
+ config.vocab_size % self.pad_vocab_size_multiple
1052
+ )
1053
+
1054
+ assert config.activation_function in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu", "geglu", "glu"]
1055
+
1056
+ self.embeddings = NomicBertEmbeddings(
1057
+ config
1058
+ )
1059
+ self.emb_drop = nn.Dropout(config.resid_pdrop)
1060
+ self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1061
+ self.encoder = NomicBertEncoder(config)
1062
+ self.pooler = NomicBertPooler(config) if add_pooling_layer else None
1063
+
1064
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1065
+
1066
+ def forward(
1067
+ self,
1068
+ input_ids,
1069
+ position_ids=None,
1070
+ token_type_ids=None,
1071
+ attention_mask=None,
1072
+ ):
1073
+ if token_type_ids is None:
1074
+ token_type_ids = torch.zeros_like(input_ids)
1075
+ hidden_states = self.embeddings(
1076
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids
1077
+ )
1078
+ hidden_states = self.emb_ln(hidden_states)
1079
+ hidden_states = self.emb_drop(hidden_states)
1080
+
1081
+ attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1082
+ sequence_output = self.encoder(
1083
+ hidden_states, attention_mask=attention_mask
1084
+ )
1085
+
1086
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1087
+
1088
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1089
+ last_hidden_state=sequence_output,
1090
+ pooler_output=pooled_output,
1091
+ )
1092
+
1093
+
1094
+ class NomicBertForPreTraining(NomicBertPreTrainedModel):
1095
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1096
+
1097
+ def __init__(self, config: GPT2Config):
1098
+ super().__init__(config)
1099
+
1100
+ self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
1101
+ self.cls = NomicBertPreTrainingHeads(config)
1102
+ self.mlm_loss = nn.CrossEntropyLoss()
1103
+
1104
+ # Initialize weights and apply final processing
1105
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1106
+ self.tie_weights()
1107
+
1108
+ def tie_weights(self):
1109
+ self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
1110
+
1111
+ def forward(
1112
+ self,
1113
+ input_ids,
1114
+ position_ids=None,
1115
+ token_type_ids=None,
1116
+ attention_mask=None,
1117
+ labels=None,
1118
+ ):
1119
+ """
1120
+ If labels are provided, they must be -100 for masked out tokens (as specified in the attention
1121
+ mask).
1122
+ Outputs:
1123
+ if `labels` and `next_sentence_label` are not `None`:
1124
+ Outputs the total_loss which is the sum of the masked language modeling loss and the next
1125
+ sentence classification loss.
1126
+ if `labels` or `next_sentence_label` is `None`:
1127
+ Outputs a tuple comprising
1128
+ - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
1129
+ - the next sentence classification logits of shape [batch_size, 2].
1130
+
1131
+ """
1132
+ outputs = self.bert(
1133
+ input_ids,
1134
+ position_ids=position_ids,
1135
+ token_type_ids=token_type_ids,
1136
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
1137
+ )
1138
+ sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output
1139
+
1140
+ prediction_scores = self.cls(sequence_output)
1141
+
1142
+ total_loss = None
1143
+ if labels is not None:
1144
+ masked_lm_loss = self.mlm_loss(
1145
+ rearrange(prediction_scores, "... v -> (...) v"),
1146
+ rearrange(labels, "... -> (...)"),
1147
+ )
1148
+ total_loss = masked_lm_loss.float()
1149
+
1150
+ return MaskedLMOutput(
1151
+ loss=total_loss,
1152
+ logits=prediction_scores,
1153
+ hidden_states=outputs.hidden_states,
1154
+ attentions=None,
1155
+ )
1156
+
1157
+
1158
+ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1159
+ def __init__(self, config):
1160
+ super().__init__(config)
1161
+ self.num_labels = config.num_labels
1162
+ self.config = config
1163
+
1164
+ self.bert = NomicBertModel(config)
1165
+ classifier_dropout = (
1166
+ getattr(config, "classifier_dropout", config.embd_pdrop)
1167
+ )
1168
+ self.dropout = nn.Dropout(classifier_dropout)
1169
+ self.classifier = nn.Linear(config.n_embd, config.num_labels)
1170
+
1171
+ # Initialize weights and apply final processing
1172
+ self.post_init()
1173
+
1174
+ def forward(
1175
+ self,
1176
+ input_ids: Optional[torch.Tensor] = None,
1177
+ attention_mask: Optional[torch.Tensor] = None,
1178
+ token_type_ids: Optional[torch.Tensor] = None,
1179
+ position_ids: Optional[torch.Tensor] = None,
1180
+ head_mask: Optional[torch.Tensor] = None,
1181
+ inputs_embeds: Optional[torch.Tensor] = None,
1182
+ labels: Optional[torch.Tensor] = None,
1183
+ output_attentions: Optional[bool] = None,
1184
+ output_hidden_states: Optional[bool] = None,
1185
+ return_dict: Optional[bool] = None,
1186
+ ):
1187
+ r"""
1188
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1189
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1190
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1191
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1192
+ """
1193
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1194
+ outputs = self.bert(
1195
+ input_ids,
1196
+ position_ids=position_ids,
1197
+ token_type_ids=token_type_ids,
1198
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
1199
+ )
1200
+
1201
+ pooled_output = outputs[1]
1202
+
1203
+ pooled_output = self.dropout(pooled_output)
1204
+ logits = self.classifier(pooled_output)
1205
+
1206
+ loss = None
1207
+ if labels is not None:
1208
+ if self.config.problem_type is None:
1209
+ if self.num_labels == 1:
1210
+ self.config.problem_type = "regression"
1211
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1212
+ self.config.problem_type = "single_label_classification"
1213
+ else:
1214
+ self.config.problem_type = "multi_label_classification"
1215
+
1216
+ if self.config.problem_type == "regression":
1217
+ loss_fct = nn.MSELoss()
1218
+ if self.num_labels == 1:
1219
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1220
+ else:
1221
+ loss = loss_fct(logits, labels)
1222
+ elif self.config.problem_type == "single_label_classification":
1223
+ loss_fct = nn.CrossEntropyLoss()
1224
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1225
+ elif self.config.problem_type == "multi_label_classification":
1226
+ loss_fct = nn.BCEWithLogitsLoss()
1227
+ loss = loss_fct(logits, labels)
1228
+ if not return_dict:
1229
+ output = (logits,) + outputs[2:]
1230
+ return ((loss,) + output) if loss is not None else output
1231
+
1232
+ return SequenceClassifierOutput(
1233
+ loss=loss,
1234
+ logits=logits,
1235
+ hidden_states=outputs.hidden_states,
1236
+ attentions=outputs.attentions,
1237
+ )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9fc78c00133aac4e12f358cfe9546e893cb82bb9bb7956506fbbcaa1700ce17c
3
+ size 546961866