Update modeling_hf_nomic_bert.py
Browse files
modeling_hf_nomic_bert.py
CHANGED
@@ -316,7 +316,9 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
316 |
num_labels = kwargs.pop("num_labels", None)
|
317 |
rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
|
318 |
strict = kwargs.pop("strict", True)
|
319 |
-
|
|
|
|
|
320 |
if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
|
321 |
config.n_positions = 2048
|
322 |
if num_labels:
|
@@ -325,7 +327,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
325 |
if "add_pooling_layer" in kwargs:
|
326 |
model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
|
327 |
else:
|
328 |
-
|
|
|
|
|
|
|
329 |
# TODO: fix this
|
330 |
# Assuming we know what we're doing when loading from disk
|
331 |
# Prob a bad assumption but i'm tired and want to train this asap
|
@@ -344,7 +349,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
344 |
load_return = model.load_state_dict(state_dict, strict=False)
|
345 |
else:
|
346 |
# TODO: can probably check config class and see if we need to remap from a bert model
|
347 |
-
state_dict = state_dict_from_pretrained(model_name)
|
348 |
state_dict = remap_bert_state_dict(
|
349 |
state_dict,
|
350 |
config,
|
@@ -1057,9 +1062,9 @@ class NomicBertModel(NomicBertPreTrainedModel):
|
|
1057 |
def forward(
|
1058 |
self,
|
1059 |
input_ids,
|
|
|
1060 |
position_ids=None,
|
1061 |
token_type_ids=None,
|
1062 |
-
attention_mask=None,
|
1063 |
return_dict=None,
|
1064 |
matryoshka_dim=None,
|
1065 |
):
|
|
|
316 |
num_labels = kwargs.pop("num_labels", None)
|
317 |
rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
|
318 |
strict = kwargs.pop("strict", True)
|
319 |
+
if rotary_scaling_factor:
|
320 |
+
config.rotary_scaling_factor = rotary_scaling_factor
|
321 |
+
|
322 |
if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
|
323 |
config.n_positions = 2048
|
324 |
if num_labels:
|
|
|
327 |
if "add_pooling_layer" in kwargs:
|
328 |
model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
|
329 |
else:
|
330 |
+
if cls == NomicBertModel:
|
331 |
+
model = cls(config, *inputs, add_pooling_layer=False)
|
332 |
+
else:
|
333 |
+
model = cls(config, *inputs)
|
334 |
# TODO: fix this
|
335 |
# Assuming we know what we're doing when loading from disk
|
336 |
# Prob a bad assumption but i'm tired and want to train this asap
|
|
|
349 |
load_return = model.load_state_dict(state_dict, strict=False)
|
350 |
else:
|
351 |
# TODO: can probably check config class and see if we need to remap from a bert model
|
352 |
+
state_dict = state_dict_from_pretrained(model_name, safe_serialization=kwargs.get("safe_serialization", False))
|
353 |
state_dict = remap_bert_state_dict(
|
354 |
state_dict,
|
355 |
config,
|
|
|
1062 |
def forward(
|
1063 |
self,
|
1064 |
input_ids,
|
1065 |
+
attention_mask=None,
|
1066 |
position_ids=None,
|
1067 |
token_type_ids=None,
|
|
|
1068 |
return_dict=None,
|
1069 |
matryoshka_dim=None,
|
1070 |
):
|