Markus28 commited on
Commit
4164fd6
1 Parent(s): 0f43653

feat: added from_config, also pass additional kwargs from config to model

Browse files
Files changed (1) hide show
  1. modeling_bert.py +9 -1
modeling_bert.py CHANGED
@@ -328,13 +328,17 @@ class BertPreTrainedModel(nn.Module):
328
  (ex: num_labels for BertForSequenceClassification)
329
  """
330
  # Instantiate model.
331
- model = cls(config) #cls(config, *inputs, **kwargs)
332
  load_return = model.load_state_dict(
333
  remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
334
  )
335
  logger.info(load_return)
336
  return model
337
 
 
 
 
 
338
 
339
  class BertModel(BertPreTrainedModel):
340
  def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
@@ -523,6 +527,10 @@ class BertForPreTraining(BertPreTrainedModel):
523
  seq_relationship_logits=seq_relationship_score,
524
  )
525
 
 
 
 
 
526
 
527
  def remap_state_dict(state_dict, config: PretrainedConfig):
528
  """
 
328
  (ex: num_labels for BertForSequenceClassification)
329
  """
330
  # Instantiate model.
331
+ model = cls(config, *inputs, **kwargs)
332
  load_return = model.load_state_dict(
333
  remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
334
  )
335
  logger.info(load_return)
336
  return model
337
 
338
+ @classmethod
339
+ def from_config(cls, config, *inputs, **kwargs):
340
+ model = cls(config, *inputs, **kwargs)
341
+ return model
342
 
343
  class BertModel(BertPreTrainedModel):
344
  def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
 
527
  seq_relationship_logits=seq_relationship_score,
528
  )
529
 
530
+ @classmethod
531
+ def _from_config(cls, config, **kwargs):
532
+ pass
533
+
534
 
535
  def remap_state_dict(state_dict, config: PretrainedConfig):
536
  """