feat: added from_config, also pass additional kwargs from config to model
Browse files- modeling_bert.py +9 -1
modeling_bert.py
CHANGED
@@ -328,13 +328,17 @@ class BertPreTrainedModel(nn.Module):
|
|
328 |
(ex: num_labels for BertForSequenceClassification)
|
329 |
"""
|
330 |
# Instantiate model.
|
331 |
-
model = cls(config
|
332 |
load_return = model.load_state_dict(
|
333 |
remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
|
334 |
)
|
335 |
logger.info(load_return)
|
336 |
return model
|
337 |
|
|
|
|
|
|
|
|
|
338 |
|
339 |
class BertModel(BertPreTrainedModel):
|
340 |
def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
|
@@ -523,6 +527,10 @@ class BertForPreTraining(BertPreTrainedModel):
|
|
523 |
seq_relationship_logits=seq_relationship_score,
|
524 |
)
|
525 |
|
|
|
|
|
|
|
|
|
526 |
|
527 |
def remap_state_dict(state_dict, config: PretrainedConfig):
|
528 |
"""
|
|
|
328 |
(ex: num_labels for BertForSequenceClassification)
|
329 |
"""
|
330 |
# Instantiate model.
|
331 |
+
model = cls(config, *inputs, **kwargs)
|
332 |
load_return = model.load_state_dict(
|
333 |
remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
|
334 |
)
|
335 |
logger.info(load_return)
|
336 |
return model
|
337 |
|
338 |
+
@classmethod
|
339 |
+
def from_config(cls, config, *inputs, **kwargs):
|
340 |
+
model = cls(config, *inputs, **kwargs)
|
341 |
+
return model
|
342 |
|
343 |
class BertModel(BertPreTrainedModel):
|
344 |
def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
|
|
|
527 |
seq_relationship_logits=seq_relationship_score,
|
528 |
)
|
529 |
|
530 |
+
@classmethod
|
531 |
+
def _from_config(cls, config, **kwargs):
|
532 |
+
pass
|
533 |
+
|
534 |
|
535 |
def remap_state_dict(state_dict, config: PretrainedConfig):
|
536 |
"""
|