zpn committed
Commit 404f7c2
1 Parent(s): c111a6a

Update modeling_hf_nomic_bert.py

Files changed (1)
  1. modeling_hf_nomic_bert.py +129 -147
modeling_hf_nomic_bert.py CHANGED
@@ -3,39 +3,34 @@
3
  # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
  # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
 
 
 
6
  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
7
  import os
8
- import logging
 
9
  from functools import partial
10
- from typing import Optional, List, Tuple, Union
11
 
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
  from einops import rearrange, repeat
 
16
  from transformers import GPT2Config, PreTrainedModel
17
  from transformers.models.bert.modeling_bert import (
18
  BaseModelOutputWithPoolingAndCrossAttentions,
19
  MaskedLMOutput,
20
- SequenceClassifierOutput
21
- )
22
-
23
- import re
24
- from collections import OrderedDict
25
- from safetensors.torch import load_file as safe_load_file
26
- from transformers.utils import (
27
- SAFE_WEIGHTS_INDEX_NAME,
28
- SAFE_WEIGHTS_NAME,
29
- WEIGHTS_INDEX_NAME,
30
- WEIGHTS_NAME,
31
  )
 
32
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
33
 
34
-
35
  from .configuration_hf_nomic_bert import NomicBertConfig
36
 
37
  logger = logging.getLogger(__name__)
38
 
 
39
  # adapted from flash attention, added safe serialization option for hf models
40
  def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
41
  # If not fp32, then we don't want to load directly to the GPU
@@ -50,18 +45,12 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
50
  safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
51
 
52
  if os.path.isfile(weights_path):
53
- resolved_archive_file = cached_file(
54
- model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
55
- )
56
  elif os.path.isfile(weights_index_path):
57
- resolved_archive_file = cached_file(
58
- model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
59
- )
60
  is_sharded = True
61
  elif os.path.isfile(safe_weights_path):
62
- resolved_archive_file = cached_file(
63
- model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
64
- )
65
  load_safe = True
66
  elif os.path.isfile(safe_weights_index_path):
67
  resolved_archive_file = cached_file(
@@ -74,8 +63,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
74
  resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
75
  if resolved_archive_file is None:
76
  weight_index = WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
77
- resolved_archive_file = cached_file(model_name, weight_index,
78
- _raise_exceptions_for_missing_entries=False)
79
  if resolved_archive_file is not None:
80
  is_sharded = True
81
 
@@ -92,9 +80,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
92
  if is_sharded:
93
  # resolved_archive_file becomes a list of files that point to the different
94
  # checkpoint shards in this case.
95
- resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
96
- model_name, resolved_archive_file
97
- )
98
  state_dict = {}
99
  for sharded_file in resolved_archive_file:
100
  state_dict.update(loader(sharded_file))
@@ -106,7 +92,7 @@ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None
106
  state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
107
  return state_dict
108
 
109
-
110
  def filter_shapes(state_dict, model):
111
  """
112
  Filters the state dict to match the current model shape.
@@ -118,11 +104,12 @@ def filter_shapes(state_dict, model):
118
  filtered_state_dict[key] = value
119
  return filtered_state_dict
120
 
121
-
122
  def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
123
  """
124
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
125
  """
 
126
  def add_bert_prefix(key):
127
  # prepend bert. to the key
128
  if key.startswith("bert.") or key.startswith("cls."):
@@ -130,7 +117,7 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
130
  return f"bert.{key}"
131
 
132
  state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
133
-
134
  # LayerNorm
135
  def key_mapping_ln_gamma_beta(key):
136
  key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
@@ -195,9 +182,7 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
195
  bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
196
  bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
197
  if not (last_layer_subset and d == config.num_hidden_layers - 1):
198
- state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat(
199
- [Wq, Wk, Wv], dim=0
200
- )
201
  state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
202
  else:
203
  state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
@@ -217,7 +202,6 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
217
  def key_mapping_decoder_bias(key):
218
  return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
219
 
220
-
221
  # remove nsp weights, we don't use
222
  state_dict.pop("cls.seq_relationship.weight", None)
223
  state_dict.pop("cls.seq_relationship.bias", None)
@@ -226,12 +210,14 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
226
  state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
227
 
228
  if remove_cls_weights:
229
- cls_weights = ["cls.predictions.decoder.bias",
230
- "cls.predictions.transform.dense.weight",
231
- "cls.predictions.transform.dense.bias",
232
- "cls.predictions.transform.layer_norm.weight",
233
- "cls.predictions.transform.layer_norm.bias",
234
- "cls.predictions.decoder.weight"]
235
  for weight in cls_weights:
236
  state_dict.pop(weight, None)
237
 
@@ -257,20 +243,21 @@ def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weig
257
  )
258
 
259
  if add_pooling_layer is False:
260
- pooler_weights = ["bert.pooler.dense.weight",
261
- "bert.pooler.dense.bias",
262
- ]
 
263
  for key in pooler_weights:
264
  state_dict.pop(key, None)
265
 
266
  if remove_bert:
 
267
  def remove_bert_prefix(key):
268
  key = re.sub(r"^bert.", "", key)
269
  return key
270
 
271
  state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
272
 
273
-
274
  return state_dict
275
 
276
 
@@ -278,6 +265,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
278
  """An abstract class to handle weights initialization and
279
  a simple interface for dowloading and loading pretrained models.
280
  """
 
281
  config_class = NomicBertConfig
282
  base_model_prefix = "model"
283
  supports_gradient_checkpointing = True
@@ -323,8 +311,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
323
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
324
  if rotary_scaling_factor:
325
  config.rotary_scaling_factor = rotary_scaling_factor
326
- else:
327
- config.rotary_scaling_factor = None
328
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
329
  config.n_positions = 2048
330
  if num_labels:
@@ -341,26 +328,32 @@ class NomicBertPreTrainedModel(PreTrainedModel):
341
  # Assuming we know what we're doing when loading from disk
342
  # Prob a bad assumption but i'm tired and want to train this asap
343
  if os.path.exists(model_name):
344
- state_dict = torch.load(f"{model_name}/pytorch_model.bin")
345
  if ignore_mismatched_shapes:
346
  state_dict = filter_shapes(state_dict, model)
347
  load_return = model.load_state_dict(state_dict, strict=False)
348
  else:
349
  # TODO: can probably check config class and see if we need to remap from a bert model
350
- state_dict = state_dict_from_pretrained(model_name)
351
- state_dict = remap_bert_state_dict(state_dict,
352
- config,
353
- remove_bert=remove_bert_prefix,
354
- remove_cls_weights=remove_cls,
355
- add_pooling_layer=getattr(config, "add_pooling_layer", False)
356
- )
 
357
  if ignore_mismatched_shapes:
358
  state_dict = filter_shapes(state_dict, model)
359
 
360
- load_return = model.load_state_dict(
361
- state_dict,
362
- strict=True
363
- )
364
  logger.warning(load_return)
365
  return model
366
 
@@ -380,25 +373,21 @@ def _init_weights(module, initializer_range=0.02):
380
  if module.padding_idx is not None:
381
  nn.init.zeros_(module.weight[module.padding_idx])
382
 
383
-
384
  class NomicBertEmbeddings(nn.Module):
385
- def __init__(
386
- self,
387
- config
388
- ):
389
  """
390
  If max_position_embeddings <= 0, there's no position embeddings
391
  If type_vocab_size <= 0, there's no token type embeddings
392
  """
393
  super().__init__()
394
- self.word_embeddings = nn.Embedding(
395
- config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
396
- )
397
  self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
398
  self.type_vocab_size = config.type_vocab_size
399
  if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
400
  self.position_embeddings = nn.Embedding(
401
- config.max_position_embeddings, config.hidden_size,
 
402
  )
403
  if self.type_vocab_size > 0:
404
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
@@ -425,6 +414,7 @@ class NomicBertEmbeddings(nn.Module):
425
  embeddings = embeddings + position_embeddings
426
  return embeddings
427
 
 
428
  class NomicBertMLP(nn.Module):
429
  def __init__(
430
  self,
@@ -442,11 +432,7 @@ class NomicBertMLP(nn.Module):
442
  hidden_features = hidden_features if hidden_features is not None else in_features * 4
443
  self.return_residual = return_residual
444
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
445
- approximate = (
446
- "tanh"
447
- if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
448
- else "none"
449
- )
450
  self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
451
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
452
 
@@ -456,7 +442,7 @@ class NomicBertMLP(nn.Module):
456
  y = self.fc2(y)
457
  return y if not self.return_residual else (y, x)
458
 
459
-
460
  class NomciBertGatedMLP(nn.Module):
461
  def __init__(
462
  self,
@@ -474,9 +460,7 @@ class NomciBertGatedMLP(nn.Module):
474
  ):
475
  super().__init__()
476
  out_features = out_features if out_features is not None else in_features
477
- hidden_features = (
478
- hidden_features if hidden_features is not None else int(8 * in_features / 3)
479
- )
480
  hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
481
  self.return_residual = return_residual
482
 
@@ -513,8 +497,8 @@ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
513
  ro_dim = cos.shape[-1] * 2
514
  assert ro_dim <= x.shape[-1]
515
  cos, sin = (
516
- cos[offset: offset + x.shape[1]],
517
- sin[offset: offset + x.shape[1]],
518
  )
519
  cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
520
  sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
@@ -571,10 +555,7 @@ class NomicBertRotaryEmbedding(nn.Module):
571
  self._sin_k_cached = None
572
 
573
  def _compute_inv_freq(self, device=None):
574
- return 1.0 / (
575
- self.base
576
- ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
577
- )
578
 
579
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
580
  # Reset the tables if the sequence length has changed,
@@ -646,14 +627,10 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
646
  self.rotary_scaling_factor = rotary_scaling_factor
647
  self.max_position_embeddings = max_position_embeddings
648
 
649
-
650
  def _compute_inv_freq(self, base=None, device=None):
651
  if base is None:
652
  base = self.base
653
- return 1.0 / (
654
- base
655
- ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
656
- )
657
 
658
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
659
  # Reset the tables if the sequence length has changed,
@@ -704,8 +681,7 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
704
  self._sin_cached = torch.sin(freqs).to(dtype)
705
  else:
706
  power = (
707
- torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
708
- - seqlen // 2
709
  ) / self.scale_base
710
  scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
711
  # We want the multiplication by scale to happen in fp32
@@ -714,6 +690,7 @@ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
714
  self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
715
  self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
716
 
 
717
  class NomicBertAttention(nn.Module):
718
  """Multi-head self-attention and cross-attention"""
719
 
@@ -755,7 +732,7 @@ class NomicBertAttention(nn.Module):
755
  interleaved=config.rotary_emb_interleaved,
756
  rotary_scaling_factor=config.rotary_scaling_factor,
757
  max_position_embeddings=config.n_positions,
758
- )
759
  else:
760
  self.rotary_emb = NomicBertRotaryEmbedding(
761
  dim=self.rotary_emb_dim,
@@ -826,7 +803,7 @@ class NomicBertAttention(nn.Module):
826
  attn_output = self.out_proj(attn_output)
827
 
828
  return attn_output
829
-
830
 
831
  class NomicBertBlock(nn.Module):
832
  def __init__(
@@ -836,17 +813,31 @@ class NomicBertBlock(nn.Module):
836
  super().__init__()
837
  self.prenorm = config.prenorm
838
  self.fused_dropout_add_ln = config.fused_dropout_add_ln
839
-
840
- self.attn = NomicBertAttention(config)
841
  activation = (
842
- F.sigmoid
843
- if config.activation_function == "glu"
844
- else (F.silu if config.activation_function == "swiglu" else F.gelu)
845
  )
846
  if config.activation_function in ["glu", "swiglu", "geglu"]:
847
- self.mlp = NomciBertGatedMLP(config.n_embd, hidden_features=config.n_inner, bias1=config.mlp_fc1_bias, bias2=config.mlp_fc2_bias, activation=activation, fused_bias_fc=config.fused_bias_fc)
848
  else:
849
- self.mlp = NomicBertMLP(config.n_embd, hidden_features=config.n_inner, bias1=config.mlp_fc1_bias, bias2=config.mlp_fc2_bias, activation=activation, fused_bias_fc=config.fused_bias_fc)
850
 
851
  self.dropout1 = nn.Dropout(config.resid_pdrop)
852
  self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -880,7 +871,13 @@ class NomicBertBlock(nn.Module):
880
  dropped = self.dropout1(hidden_states)
881
  residual = (dropped + residual) if residual is not None else dropped
882
  hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
883
- hidden_states = self.attn(hidden_states, attention_mask=attention_mask, is_padded_inputs=is_padded_inputs, cu_seqlens=cu_seqlens, max_seq_len=max_seq_len)
884
 
885
  dropped = self.dropout2(hidden_states)
886
  residual = (dropped + residual) if residual is not None else dropped
@@ -890,36 +887,29 @@ class NomicBertBlock(nn.Module):
890
  return hidden_states, None, residual
891
  else:
892
  assert residual is None
893
- attn_outputs = self.attn(hidden_states,
894
- attention_mask=attention_mask,
895
- is_padded_inputs=is_padded_inputs,
896
- cu_seqlens=cu_seqlens,
897
- max_seq_len=max_seq_len)
898
- hidden_states = self.norm1(
899
- (self.dropout1(attn_outputs) + hidden_states).to(
900
- dtype=self.norm1.weight.dtype
901
- )
902
  )
 
903
  mlp_out = self.mlp(hidden_states)
904
 
905
- hidden_states = self.norm2(
906
- (self.dropout2(mlp_out) + hidden_states).to(
907
- dtype=self.norm2.weight.dtype
908
- )
909
- )
910
  return hidden_states, None, None
911
 
912
 
913
  class NomicBertEncoder(nn.Module):
914
  def __init__(self, config: GPT2Config):
915
  super().__init__()
916
- self.layers = nn.ModuleList(
917
- [NomicBertBlock(config) for _ in range(config.n_layer)]
918
- )
919
  self.gradient_checkpointing = False
920
  self.config = config
921
 
922
- def forward(self,
 
923
  hidden_states: torch.LongTensor = None,
924
  attention_mask: Optional[torch.Tensor] = None,
925
  position_ids: Optional[torch.LongTensor] = None,
@@ -929,8 +919,8 @@ class NomicBertEncoder(nn.Module):
929
  output_attentions: Optional[bool] = None,
930
  output_hidden_states: Optional[bool] = None,
931
  return_dict: Optional[bool] = None,
932
- is_padded_inputs: Optional[bool] = True,):
933
-
934
  """If subset_mask is not None, we only want output for the subset of the sequence.
935
  This means that we only compute the last layer output for these tokens.
936
  subset_mask: (batch, seqlen), dtype=torch.bool
@@ -938,7 +928,6 @@ class NomicBertEncoder(nn.Module):
938
  hidden_states2 = None
939
  residual = None
940
 
941
-
942
  for _, layer in enumerate(self.layers):
943
  if self.gradient_checkpointing and self.training:
944
 
@@ -998,11 +987,7 @@ class NomicBertPredictionHeadTransform(nn.Module):
998
  def __init__(self, config):
999
  super().__init__()
1000
  self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
1001
- approximate = (
1002
- "tanh"
1003
- if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
1004
- else "none"
1005
- )
1006
  if config.activation_function == "swiglu":
1007
  self.transform_act_fn = F.silu
1008
  else:
@@ -1047,15 +1032,19 @@ class NomicBertModel(NomicBertPreTrainedModel):
1047
  super().__init__(config)
1048
  self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1049
  if config.vocab_size % self.pad_vocab_size_multiple != 0:
1050
- config.vocab_size += self.pad_vocab_size_multiple - (
1051
- config.vocab_size % self.pad_vocab_size_multiple
1052
- )
1053
-
1054
- assert config.activation_function in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh", "swiglu", "geglu", "glu"]
1055
-
1056
- self.embeddings = NomicBertEmbeddings(
1057
- config
1058
- )
1059
  self.emb_drop = nn.Dropout(config.resid_pdrop)
1060
  self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1061
  self.encoder = NomicBertEncoder(config)
@@ -1069,20 +1058,15 @@ class NomicBertModel(NomicBertPreTrainedModel):
1069
  position_ids=None,
1070
  token_type_ids=None,
1071
  attention_mask=None,
1072
- return_dict=None,
1073
  ):
1074
  if token_type_ids is None:
1075
  token_type_ids = torch.zeros_like(input_ids)
1076
- hidden_states = self.embeddings(
1077
- input_ids, position_ids=position_ids, token_type_ids=token_type_ids
1078
- )
1079
  hidden_states = self.emb_ln(hidden_states)
1080
  hidden_states = self.emb_drop(hidden_states)
1081
 
1082
  attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1083
- sequence_output = self.encoder(
1084
- hidden_states, attention_mask=attention_mask, return_dict=return_dict,
1085
- )
1086
 
1087
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1088
 
@@ -1152,10 +1136,10 @@ class NomicBertForPreTraining(NomicBertPreTrainedModel):
1152
  loss=total_loss,
1153
  logits=prediction_scores,
1154
  hidden_states=outputs.hidden_states,
1155
- attentions=None,
1156
  )
1157
 
1158
-
1159
  class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1160
  def __init__(self, config):
1161
  super().__init__(config)
@@ -1163,9 +1147,7 @@ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1163
  self.config = config
1164
 
1165
  self.bert = NomicBertModel(config)
1166
- classifier_dropout = (
1167
- getattr(config, "classifier_dropout", config.embd_pdrop)
1168
- )
1169
  self.dropout = nn.Dropout(classifier_dropout)
1170
  self.classifier = nn.Linear(config.n_embd, config.num_labels)
1171
 
 
3
  # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
  # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
 
6
+ import logging
7
+
8
  # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
  import os
10
+ import re
11
+ from collections import OrderedDict
12
  from functools import partial
13
+ from typing import List, Optional, Tuple, Union
14
 
15
  import torch
16
  import torch.nn as nn
17
  import torch.nn.functional as F
18
  from einops import rearrange, repeat
19
+ from safetensors.torch import load_file as safe_load_file
20
  from transformers import GPT2Config, PreTrainedModel
21
  from transformers.models.bert.modeling_bert import (
22
  BaseModelOutputWithPoolingAndCrossAttentions,
23
  MaskedLMOutput,
24
+ SequenceClassifierOutput,
25
  )
26
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
27
  from transformers.utils.hub import cached_file, get_checkpoint_shard_files
28
 
 
29
  from .configuration_hf_nomic_bert import NomicBertConfig
30
 
31
  logger = logging.getLogger(__name__)
32
 
33
+
34
  # adapted from flash attention, added safe serialization option for hf models
35
  def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
36
  # If not fp32, then we don't want to load directly to the GPU
 
45
  safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
46
 
47
  if os.path.isfile(weights_path):
48
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
49
  elif os.path.isfile(weights_index_path):
50
+ resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
51
  is_sharded = True
52
  elif os.path.isfile(safe_weights_path):
53
+ resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
54
  load_safe = True
55
  elif os.path.isfile(safe_weights_index_path):
56
  resolved_archive_file = cached_file(
 
63
  resolved_archive_file = cached_file(model_name, weight_name, _raise_exceptions_for_missing_entries=False)
64
  if resolved_archive_file is None:
65
  weight_index = WEIGHTS_INDEX_NAME if not safe_serialization else SAFE_WEIGHTS_INDEX_NAME
66
+ resolved_archive_file = cached_file(model_name, weight_index, _raise_exceptions_for_missing_entries=False)
 
67
  if resolved_archive_file is not None:
68
  is_sharded = True
69
 
 
80
  if is_sharded:
81
  # resolved_archive_file becomes a list of files that point to the different
82
  # checkpoint shards in this case.
83
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
84
  state_dict = {}
85
  for sharded_file in resolved_archive_file:
86
  state_dict.update(loader(sharded_file))
 
92
  state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
93
  return state_dict
94
 
95
+
96
  def filter_shapes(state_dict, model):
97
  """
98
  Filters the state dict to match the current model shape.
 
104
  filtered_state_dict[key] = value
105
  return filtered_state_dict
106
 
107
+
108
  def remap_bert_state_dict(state_dict, config, remove_bert=False, remove_cls_weights=False, add_pooling_layer=False):
109
  """
110
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
111
  """
112
+
113
  def add_bert_prefix(key):
114
  # prepend bert. to the key
115
  if key.startswith("bert.") or key.startswith("cls."):
 
117
  return f"bert.{key}"
118
 
119
  state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
120
+
121
  # LayerNorm
122
  def key_mapping_ln_gamma_beta(key):
123
  key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
 
182
  bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
183
  bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
184
  if not (last_layer_subset and d == config.num_hidden_layers - 1):
185
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
 
 
186
  state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
187
  else:
188
  state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
 
202
  def key_mapping_decoder_bias(key):
203
  return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
204
 
 
205
  # remove nsp weights, we don't use
206
  state_dict.pop("cls.seq_relationship.weight", None)
207
  state_dict.pop("cls.seq_relationship.bias", None)
 
210
  state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
211
 
212
  if remove_cls_weights:
213
+ cls_weights = [
214
+ "cls.predictions.decoder.bias",
215
+ "cls.predictions.transform.dense.weight",
216
+ "cls.predictions.transform.dense.bias",
217
+ "cls.predictions.transform.layer_norm.weight",
218
+ "cls.predictions.transform.layer_norm.bias",
219
+ "cls.predictions.decoder.weight",
220
+ ]
221
  for weight in cls_weights:
222
  state_dict.pop(weight, None)
223
 
 
243
  )
244
 
245
  if add_pooling_layer is False:
246
+ pooler_weights = [
247
+ "bert.pooler.dense.weight",
248
+ "bert.pooler.dense.bias",
249
+ ]
250
  for key in pooler_weights:
251
  state_dict.pop(key, None)
252
 
253
  if remove_bert:
254
+
255
  def remove_bert_prefix(key):
256
  key = re.sub(r"^bert.", "", key)
257
  return key
258
 
259
  state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
260
 
 
261
  return state_dict
262
 
263
 
 
265
  """An abstract class to handle weights initialization and
266
  a simple interface for dowloading and loading pretrained models.
267
  """
268
+
269
  config_class = NomicBertConfig
270
  base_model_prefix = "model"
271
  supports_gradient_checkpointing = True
 
311
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
312
  if rotary_scaling_factor:
313
  config.rotary_scaling_factor = rotary_scaling_factor
314
+
 
315
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
316
  config.n_positions = 2048
317
  if num_labels:
 
328
  # Assuming we know what we're doing when loading from disk
329
  # Prob a bad assumption but i'm tired and want to train this asap
330
  if os.path.exists(model_name):
331
+ model_path = f"{model_name}/pytorch_model.bin"
332
+ if os.path.exists(model_path):
333
+ state_dict = torch.load(f"{model_name}/pytorch_model.bin")
334
+ else:
335
+ model_path = f"{model_name}/model.safetensors"
336
+ if not os.path.exists(model_path):
337
+ raise ValueError(f"Model path {model_path} not found")
338
+ state_dict = safe_load_file(model_path)
339
+
340
  if ignore_mismatched_shapes:
341
  state_dict = filter_shapes(state_dict, model)
342
  load_return = model.load_state_dict(state_dict, strict=False)
343
  else:
344
  # TODO: can probably check config class and see if we need to remap from a bert model
345
+ state_dict = state_dict_from_pretrained(model_name, safe_serialization=kwargs.get("safe_serialization", False))
346
+ state_dict = remap_bert_state_dict(
347
+ state_dict,
348
+ config,
349
+ remove_bert=remove_bert_prefix,
350
+ remove_cls_weights=remove_cls,
351
+ add_pooling_layer=getattr(config, "add_pooling_layer", False),
352
+ )
353
  if ignore_mismatched_shapes:
354
  state_dict = filter_shapes(state_dict, model)
355
 
356
+ load_return = model.load_state_dict(state_dict, strict=True)
357
  logger.warning(load_return)
358
  return model
359
 
 
373
  if module.padding_idx is not None:
374
  nn.init.zeros_(module.weight[module.padding_idx])
375
 
376
+
377
  class NomicBertEmbeddings(nn.Module):
378
+ def __init__(self, config):
379
  """
380
  If max_position_embeddings <= 0, there's no position embeddings
381
  If type_vocab_size <= 0, there's no token type embeddings
382
  """
383
  super().__init__()
384
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
385
  self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
386
  self.type_vocab_size = config.type_vocab_size
387
  if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
388
  self.position_embeddings = nn.Embedding(
389
+ config.max_position_embeddings,
390
+ config.hidden_size,
391
  )
392
  if self.type_vocab_size > 0:
393
  self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
 
414
  embeddings = embeddings + position_embeddings
415
  return embeddings
416
 
417
+
418
  class NomicBertMLP(nn.Module):
419
  def __init__(
420
  self,
 
432
  hidden_features = hidden_features if hidden_features is not None else in_features * 4
433
  self.return_residual = return_residual
434
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
435
+ approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
436
  self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
437
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
438
 
 
442
  y = self.fc2(y)
443
  return y if not self.return_residual else (y, x)
444
 
445
+
446
  class NomciBertGatedMLP(nn.Module):
447
  def __init__(
448
  self,
 
460
  ):
461
  super().__init__()
462
  out_features = out_features if out_features is not None else in_features
463
+ hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
464
  hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
465
  self.return_residual = return_residual
466
 
 
497
  ro_dim = cos.shape[-1] * 2
498
  assert ro_dim <= x.shape[-1]
499
  cos, sin = (
500
+ cos[offset : offset + x.shape[1]],
501
+ sin[offset : offset + x.shape[1]],
502
  )
503
  cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
504
  sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
 
555
  self._sin_k_cached = None
556
 
557
  def _compute_inv_freq(self, device=None):
558
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
559
 
560
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
561
  # Reset the tables if the sequence length has changed,
 
627
  self.rotary_scaling_factor = rotary_scaling_factor
628
  self.max_position_embeddings = max_position_embeddings
629
 
 
630
  def _compute_inv_freq(self, base=None, device=None):
631
  if base is None:
632
  base = self.base
633
+ return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
634
 
635
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
636
  # Reset the tables if the sequence length has changed,
 
681
  self._sin_cached = torch.sin(freqs).to(dtype)
682
  else:
683
  power = (
684
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
 
685
  ) / self.scale_base
686
  scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
687
  # We want the multiplication by scale to happen in fp32
 
690
  self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
691
  self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
692
 
693
+
694
  class NomicBertAttention(nn.Module):
695
  """Multi-head self-attention and cross-attention"""
696
 
 
732
  interleaved=config.rotary_emb_interleaved,
733
  rotary_scaling_factor=config.rotary_scaling_factor,
734
  max_position_embeddings=config.n_positions,
735
+ )
736
  else:
737
  self.rotary_emb = NomicBertRotaryEmbedding(
738
  dim=self.rotary_emb_dim,
 
803
  attn_output = self.out_proj(attn_output)
804
 
805
  return attn_output
806
+
807
 
808
  class NomicBertBlock(nn.Module):
809
  def __init__(
 
813
  super().__init__()
814
  self.prenorm = config.prenorm
815
  self.fused_dropout_add_ln = config.fused_dropout_add_ln
816
+
817
+ self.attn = NomicBertAttention(config)
818
  activation = (
819
+ F.sigmoid
820
+ if config.activation_function == "glu"
821
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
822
  )
823
  if config.activation_function in ["glu", "swiglu", "geglu"]:
824
+ self.mlp = NomciBertGatedMLP(
825
+ config.n_embd,
826
+ hidden_features=config.n_inner,
827
+ bias1=config.mlp_fc1_bias,
828
+ bias2=config.mlp_fc2_bias,
829
+ activation=activation,
830
+ fused_bias_fc=config.fused_bias_fc,
831
+ )
832
  else:
833
+ self.mlp = NomicBertMLP(
834
+ config.n_embd,
835
+ hidden_features=config.n_inner,
836
+ bias1=config.mlp_fc1_bias,
837
+ bias2=config.mlp_fc2_bias,
838
+ activation=activation,
839
+ fused_bias_fc=config.fused_bias_fc,
840
+ )
841
 
842
  self.dropout1 = nn.Dropout(config.resid_pdrop)
843
  self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
871
  dropped = self.dropout1(hidden_states)
872
  residual = (dropped + residual) if residual is not None else dropped
873
  hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
874
+ hidden_states = self.attn(
875
+ hidden_states,
876
+ attention_mask=attention_mask,
877
+ is_padded_inputs=is_padded_inputs,
878
+ cu_seqlens=cu_seqlens,
879
+ max_seq_len=max_seq_len,
880
+ )
881
 
882
  dropped = self.dropout2(hidden_states)
883
  residual = (dropped + residual) if residual is not None else dropped
 
887
  return hidden_states, None, residual
888
  else:
889
  assert residual is None
890
+ attn_outputs = self.attn(
891
+ hidden_states,
892
+ attention_mask=attention_mask,
893
+ is_padded_inputs=is_padded_inputs,
894
+ cu_seqlens=cu_seqlens,
895
+ max_seq_len=max_seq_len,
896
  )
897
+ hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
898
  mlp_out = self.mlp(hidden_states)
899
 
900
+ hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
901
  return hidden_states, None, None
902
 
903
 
904
  class NomicBertEncoder(nn.Module):
905
  def __init__(self, config: GPT2Config):
906
  super().__init__()
907
+ self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
908
  self.gradient_checkpointing = False
909
  self.config = config
910
 
911
+ def forward(
912
+ self,
913
  hidden_states: torch.LongTensor = None,
914
  attention_mask: Optional[torch.Tensor] = None,
915
  position_ids: Optional[torch.LongTensor] = None,
 
919
  output_attentions: Optional[bool] = None,
920
  output_hidden_states: Optional[bool] = None,
921
  return_dict: Optional[bool] = None,
922
+ is_padded_inputs: Optional[bool] = True,
923
+ ):
924
  """If subset_mask is not None, we only want output for the subset of the sequence.
925
  This means that we only compute the last layer output for these tokens.
926
  subset_mask: (batch, seqlen), dtype=torch.bool
 
928
  hidden_states2 = None
929
  residual = None
930
 
 
931
  for _, layer in enumerate(self.layers):
932
  if self.gradient_checkpointing and self.training:
933
 
 
987
  def __init__(self, config):
988
  super().__init__()
989
  self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
990
+ approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
991
  if config.activation_function == "swiglu":
992
  self.transform_act_fn = F.silu
993
  else:
 
1032
  super().__init__(config)
1033
  self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1034
  if config.vocab_size % self.pad_vocab_size_multiple != 0:
1035
+ config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
1036
+
1037
+ assert config.activation_function in [
1038
+ "gelu",
1039
+ "gelu_new",
1040
+ "gelu_fast",
1041
+ "gelu_pytorch_tanh",
1042
+ "swiglu",
1043
+ "geglu",
1044
+ "glu",
1045
+ ]
1046
+
1047
+ self.embeddings = NomicBertEmbeddings(config)
1048
  self.emb_drop = nn.Dropout(config.resid_pdrop)
1049
  self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1050
  self.encoder = NomicBertEncoder(config)
 
1058
  position_ids=None,
1059
  token_type_ids=None,
1060
  attention_mask=None,
 
1061
  ):
1062
  if token_type_ids is None:
1063
  token_type_ids = torch.zeros_like(input_ids)
1064
+ hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
1065
  hidden_states = self.emb_ln(hidden_states)
1066
  hidden_states = self.emb_drop(hidden_states)
1067
 
1068
  attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1069
+ sequence_output = self.encoder(hidden_states, attention_mask=attention_mask)
1070
 
1071
  pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1072
 
 
1136
  loss=total_loss,
1137
  logits=prediction_scores,
1138
  hidden_states=outputs.hidden_states,
1139
+ attentions=None,
1140
  )
1141
 
1142
+
1143
  class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1144
  def __init__(self, config):
1145
  super().__init__(config)
 
1147
  self.config = config
1148
 
1149
  self.bert = NomicBertModel(config)
1150
+ classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
1151
  self.dropout = nn.Dropout(classifier_dropout)
1152
  self.classifier = nn.Linear(config.n_embd, config.num_labels)
1153