Update modeling_hf_nomic_bert.py
Browse files- modeling_hf_nomic_bert.py +5 -11
modeling_hf_nomic_bert.py
CHANGED
@@ -315,9 +315,8 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
315 |
ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
|
316 |
num_labels = kwargs.pop("num_labels", None)
|
317 |
rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
|
322 |
config.n_positions = 2048
|
323 |
if num_labels:
|
@@ -326,10 +325,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
326 |
if "add_pooling_layer" in kwargs:
|
327 |
model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
|
328 |
else:
|
329 |
-
|
330 |
-
model = cls(config, *inputs, add_pooling_layer=False)
|
331 |
-
else:
|
332 |
-
model = cls(config, *inputs)
|
333 |
# TODO: fix this
|
334 |
# Assuming we know what we're doing when loading from disk
|
335 |
# Prob a bad assumption but i'm tired and want to train this asap
|
@@ -348,9 +344,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
348 |
load_return = model.load_state_dict(state_dict, strict=False)
|
349 |
else:
|
350 |
# TODO: can probably check config class and see if we need to remap from a bert model
|
351 |
-
state_dict = state_dict_from_pretrained(
|
352 |
-
model_name, safe_serialization=kwargs.get("safe_serialization", False)
|
353 |
-
)
|
354 |
state_dict = remap_bert_state_dict(
|
355 |
state_dict,
|
356 |
config,
|
@@ -361,7 +355,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
|
|
361 |
if ignore_mismatched_shapes:
|
362 |
state_dict = filter_shapes(state_dict, model)
|
363 |
|
364 |
-
load_return = model.load_state_dict(state_dict, strict=
|
365 |
logger.warning(load_return)
|
366 |
return model
|
367 |
|
|
|
315 |
ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
|
316 |
num_labels = kwargs.pop("num_labels", None)
|
317 |
rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
|
318 |
+
strict = kwargs.pop("strict", True)
|
319 |
+
config.rotary_scaling_factor = rotary_scaling_factor
|
|
|
320 |
if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
|
321 |
config.n_positions = 2048
|
322 |
if num_labels:
|
|
|
325 |
if "add_pooling_layer" in kwargs:
|
326 |
model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
|
327 |
else:
|
328 |
+
model = cls(config, *inputs)
|
|
|
|
|
|
|
329 |
# TODO: fix this
|
330 |
# Assuming we know what we're doing when loading from disk
|
331 |
# Prob a bad assumption but i'm tired and want to train this asap
|
|
|
344 |
load_return = model.load_state_dict(state_dict, strict=False)
|
345 |
else:
|
346 |
# TODO: can probably check config class and see if we need to remap from a bert model
|
347 |
+
state_dict = state_dict_from_pretrained(model_name)
|
|
|
|
|
348 |
state_dict = remap_bert_state_dict(
|
349 |
state_dict,
|
350 |
config,
|
|
|
355 |
if ignore_mismatched_shapes:
|
356 |
state_dict = filter_shapes(state_dict, model)
|
357 |
|
358 |
+
load_return = model.load_state_dict(state_dict, strict=strict)
|
359 |
logger.warning(load_return)
|
360 |
return model
|
361 |
|