zhihan1996 committed on
Commit 6041066
1 Parent(s): 69b2c8f

Update bert_layers.py

Files changed (1)
  1. bert_layers.py +5 -68
bert_layers.py CHANGED
@@ -698,38 +698,6 @@ class BertForMaskedLM(BertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    @classmethod
-    def from_composer(cls,
-                      pretrained_checkpoint,
-                      state_dict=None,
-                      cache_dir=None,
-                      from_tf=False,
-                      config=None,
-                      *inputs,
-                      **kwargs):
-        """Load from pre-trained."""
-        model = cls(config, *inputs, **kwargs)
-        if from_tf:
-            raise ValueError(
-                'Mosaic BERT does not support loading TensorFlow weights.')
-
-        state_dict = torch.load(pretrained_checkpoint)
-        # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
-        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
-        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
-                                                               strict=False)
-
-        if len(missing_keys) > 0:
-            logger.warning(
-                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
-            )
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
-            )
-
-        return model
-
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
@@ -786,7 +754,7 @@ class BertForMaskedLM(BertPreTrainedModel):
             return_dict=return_dict,
             masked_tokens_mask=masked_tokens_mask,
         )
-
+
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
 
@@ -813,8 +781,8 @@ class BertForMaskedLM(BertPreTrainedModel):
         return MaskedLMOutput(
             loss=loss,
             logits=prediction_scores,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attention,
+            hidden_states=outputs[0],
+            attentions=None,
         )
 
     def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
@@ -868,37 +836,6 @@ class BertForSequenceClassification(BertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-    @classmethod
-    def from_composer(cls,
-                      pretrained_checkpoint,
-                      state_dict=None,
-                      cache_dir=None,
-                      from_tf=False,
-                      config=None,
-                      *inputs,
-                      **kwargs):
-        """Load from pre-trained."""
-        model = cls(config, *inputs, **kwargs)
-        if from_tf:
-            raise ValueError(
-                'Mosaic BERT does not support loading TensorFlow weights.')
-
-        state_dict = torch.load(pretrained_checkpoint)
-        # If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
-        consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
-        missing_keys, unexpected_keys = model.load_state_dict(state_dict,
-                                                               strict=False)
-
-        if len(missing_keys) > 0:
-            logger.warning(
-                f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
-            )
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
-            )
-
-        return model
 
     def forward(
         self,
@@ -972,7 +909,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         return SequenceClassifierOutput(
             loss=loss,
             logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attention,
+            hidden_states=outputs[0],
+            attentions=None,
         )
 
 
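Since this commit removes the from_composer helpers shown in the diff, Composer-format checkpoints would now have to be loaded by hand if anyone still needs that path. The sketch below simply reproduces the removed logic as a standalone function; it is an illustrative example only, not part of the repository, and the model object and checkpoint path are placeholders.

# Illustrative sketch: manual replacement for the removed from_composer() helper.
# `model` is an already-constructed BertForMaskedLM / BertForSequenceClassification
# from bert_layers.py; `pretrained_checkpoint` is a placeholder path.
import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present


def load_composer_checkpoint(model, pretrained_checkpoint):
    """Load weights saved via composer.HuggingFaceModel into an existing model."""
    state_dict = torch.load(pretrained_checkpoint, map_location='cpu')
    # Checkpoints saved after wrapping with `composer.HuggingFaceModel` carry a `model.` prefix.
    consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        print(f"Missing keys in the checkpoint: {', '.join(missing_keys)}")
    if unexpected_keys:
        print(f"Unexpected keys in the checkpoint: {', '.join(unexpected_keys)}")
    return model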