Markus28 commited on
Commit
c0b46cc
1 Parent(s): 3cb3930

fix BertForMaskedLM

Browse files
Files changed (1) hide show
  1. modeling_bert.py +8 -8
modeling_bert.py CHANGED
@@ -752,18 +752,18 @@ class BertForMaskedLM(BertPreTrainedModel):
752
  prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
753
 
754
  if (
755
- self.dense_seq_output and labels is not None
756
  ): # prediction_scores are already flattened
757
  masked_lm_loss = self.mlm_loss(
758
  prediction_scores, labels.flatten()[masked_token_idx]
759
  ).float()
760
-
761
- assert labels is not None
762
-
763
- masked_lm_loss = self.mlm_loss(
764
- rearrange(prediction_scores, "... v -> (...) v"),
765
- rearrange(labels, "... -> (...)"),
766
- ).float()
767
 
768
  return BertForPreTrainingOutput(
769
  loss=masked_lm_loss,
 
752
  prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
753
 
754
  if (
755
+ self.dense_seq_output and labels is not None
756
  ): # prediction_scores are already flattened
757
  masked_lm_loss = self.mlm_loss(
758
  prediction_scores, labels.flatten()[masked_token_idx]
759
  ).float()
760
+ elif labels is not None:
761
+ masked_lm_loss = self.mlm_loss(
762
+ rearrange(prediction_scores, "... v -> (...) v"),
763
+ rearrange(labels, "... -> (...)"),
764
+ ).float()
765
+ else:
766
+ raise ValueError('MLM labels must not be None')
767
 
768
  return BertForPreTrainingOutput(
769
  loss=masked_lm_loss,