Markus28 commited on
Commit
2b23340
1 Parent(s): a0c289c

Try to subclass PretrainedModel

Browse files
Files changed (1) hide show
  1. modeling_bert.py +2 -29
modeling_bert.py CHANGED
@@ -22,7 +22,7 @@ import torch
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
  from einops import rearrange
25
- from transformers import PretrainedConfig
26
  from .configuration_bert import JinaBertConfig
27
  from transformers.models.bert.modeling_bert import (
28
  BaseModelOutputWithPoolingAndCrossAttentions,
@@ -295,7 +295,7 @@ class BertPreTrainingHeads(nn.Module):
295
  return prediction_scores, seq_relationship_score
296
 
297
 
298
- class BertPreTrainedModel(nn.Module):
299
  """An abstract class to handle weights initialization and
300
  a simple interface for dowloading and loading pretrained models.
301
  """
@@ -310,33 +310,6 @@ class BertPreTrainedModel(nn.Module):
310
  )
311
  self.config = config
312
 
313
- @classmethod
314
- def from_pretrained(cls, model_name, config, *inputs, **kwargs):
315
- """
316
- Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
317
- Download and cache the pre-trained model file if needed.
318
-
319
- Params:
320
- pretrained_model_name_or_path: either:
321
- - a path or url to a pretrained model archive containing:
322
- . `bert_config.json` a configuration file for the model
323
- . `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
324
- - a path or url to a pretrained model archive containing:
325
- . `bert_config.json` a configuration file for the model
326
- . `model.chkpt` a TensorFlow checkpoint
327
- *inputs, **kwargs: additional input for the specific Bert class
328
- (ex: num_labels for BertForSequenceClassification)
329
- """
330
- # Instantiate model.
331
- model = cls(config, *inputs, **kwargs)
332
- load_return = model.load_state_dict(state_dict_from_pretrained(model_name), strict=True)
333
- logger.info(load_return)
334
- return model
335
-
336
- @classmethod
337
- def _from_config(cls, config, **kwargs):
338
- return cls(config, **kwargs)
339
-
340
 
341
  class BertModel(BertPreTrainedModel):
342
  def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
 
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
  from einops import rearrange
25
+ from transformers import PretrainedModel
26
  from .configuration_bert import JinaBertConfig
27
  from transformers.models.bert.modeling_bert import (
28
  BaseModelOutputWithPoolingAndCrossAttentions,
 
295
  return prediction_scores, seq_relationship_score
296
 
297
 
298
+ class BertPreTrainedModel(PretrainedModel):
299
  """An abstract class to handle weights initialization and
300
  a simple interface for dowloading and loading pretrained models.
301
  """
 
310
  )
311
  self.config = config
312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  class BertModel(BertPreTrainedModel):
315
  def __init__(self, config: JinaBertConfig, add_pooling_layer=True):