davda54 commited on
Commit
d06b660
1 Parent(s): 796b28d

Update modeling_norbert.py

Browse files
Files changed (1) hide show
  1. modeling_norbert.py +13 -13
modeling_norbert.py CHANGED
@@ -277,12 +277,12 @@ class NorbertPreTrainedModel(PreTrainedModel):
277
 
278
 
279
  class NorbertModel(NorbertPreTrainedModel):
280
- def __init__(self, config, add_mlm_layer=False):
281
- super().__init__(config)
282
  self.config = config
283
 
284
  self.embedding = Embedding(config)
285
- self.transformer = Encoder(config, activation_checkpointing=False)
286
  self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
287
 
288
  def get_input_embeddings(self):
@@ -352,8 +352,8 @@ class NorbertModel(NorbertPreTrainedModel):
352
  class NorbertForMaskedLM(NorbertModel):
353
  _keys_to_ignore_on_load_unexpected = ["head"]
354
 
355
- def __init__(self, config):
356
- super().__init__(config, add_mlm_layer=True)
357
 
358
  def get_output_embeddings(self):
359
  return self.classifier.nonlinearity[-1].weight
@@ -432,8 +432,8 @@ class NorbertForSequenceClassification(NorbertModel):
432
  _keys_to_ignore_on_load_unexpected = ["classifier"]
433
  _keys_to_ignore_on_load_missing = ["head"]
434
 
435
- def __init__(self, config):
436
- super().__init__(config, add_mlm_layer=False)
437
 
438
  self.num_labels = config.num_labels
439
  self.head = Classifier(config, self.num_labels)
@@ -498,8 +498,8 @@ class NorbertForTokenClassification(NorbertModel):
498
  _keys_to_ignore_on_load_unexpected = ["classifier"]
499
  _keys_to_ignore_on_load_missing = ["head"]
500
 
501
- def __init__(self, config):
502
- super().__init__(config, add_mlm_layer=False)
503
 
504
  self.num_labels = config.num_labels
505
  self.head = Classifier(config, self.num_labels)
@@ -546,8 +546,8 @@ class NorbertForQuestionAnswering(NorbertModel):
546
  _keys_to_ignore_on_load_unexpected = ["classifier"]
547
  _keys_to_ignore_on_load_missing = ["head"]
548
 
549
- def __init__(self, config):
550
- super().__init__(config, add_mlm_layer=False)
551
 
552
  self.num_labels = config.num_labels
553
  self.head = Classifier(config, self.num_labels)
@@ -614,8 +614,8 @@ class NorbertForMultipleChoice(NorbertModel):
614
  _keys_to_ignore_on_load_unexpected = ["classifier"]
615
  _keys_to_ignore_on_load_missing = ["head"]
616
 
617
- def __init__(self, config):
618
- super().__init__(config, add_mlm_layer=False)
619
 
620
  self.num_labels = getattr(config, "num_labels", 2)
621
  self.head = Classifier(config, self.num_labels)
 
277
 
278
 
279
  class NorbertModel(NorbertPreTrainedModel):
280
+ def __init__(self, config, add_mlm_layer=False, gradient_checkpointing=False, **kwargs):
281
+ super().__init__(config, **kwargs)
282
  self.config = config
283
 
284
  self.embedding = Embedding(config)
285
+ self.transformer = Encoder(config, activation_checkpointing=gradient_checkpointing)
286
  self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
287
 
288
  def get_input_embeddings(self):
 
352
  class NorbertForMaskedLM(NorbertModel):
353
  _keys_to_ignore_on_load_unexpected = ["head"]
354
 
355
+ def __init__(self, config, **kwargs):
356
+ super().__init__(config, add_mlm_layer=True, **kwargs)
357
 
358
  def get_output_embeddings(self):
359
  return self.classifier.nonlinearity[-1].weight
 
432
  _keys_to_ignore_on_load_unexpected = ["classifier"]
433
  _keys_to_ignore_on_load_missing = ["head"]
434
 
435
+ def __init__(self, config, **kwargs):
436
+ super().__init__(config, add_mlm_layer=False, **kwargs)
437
 
438
  self.num_labels = config.num_labels
439
  self.head = Classifier(config, self.num_labels)
 
498
  _keys_to_ignore_on_load_unexpected = ["classifier"]
499
  _keys_to_ignore_on_load_missing = ["head"]
500
 
501
+ def __init__(self, config, **kwargs):
502
+ super().__init__(config, add_mlm_layer=False, **kwargs)
503
 
504
  self.num_labels = config.num_labels
505
  self.head = Classifier(config, self.num_labels)
 
546
  _keys_to_ignore_on_load_unexpected = ["classifier"]
547
  _keys_to_ignore_on_load_missing = ["head"]
548
 
549
+ def __init__(self, config, **kwargs):
550
+ super().__init__(config, add_mlm_layer=False, **kwargs)
551
 
552
  self.num_labels = config.num_labels
553
  self.head = Classifier(config, self.num_labels)
 
614
  _keys_to_ignore_on_load_unexpected = ["classifier"]
615
  _keys_to_ignore_on_load_missing = ["head"]
616
 
617
+ def __init__(self, config, **kwargs):
618
+ super().__init__(config, add_mlm_layer=False, **kwargs)
619
 
620
  self.num_labels = getattr(config, "num_labels", 2)
621
  self.head = Classifier(config, self.num_labels)