Markus28 committed on
Commit
0f43653
·
1 Parent(s): 3160695

feat: updated modeling_bert.py to allow MLM-only training

Browse files
Files changed (1) hide show
  1. modeling_bert.py +19 -15
modeling_bert.py CHANGED
@@ -494,24 +494,28 @@ class BertForPreTraining(BertPreTrainedModel):
494
  )
495
  prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
496
 
497
- total_loss = None
498
- if labels is not None and next_sentence_label is not None:
499
- if (
500
- self.dense_seq_output and labels is not None
501
- ): # prediction_scores are already flattened
502
- masked_lm_loss = self.mlm_loss(
503
- prediction_scores, labels.flatten()[masked_token_idx]
504
- )
505
- else:
506
- masked_lm_loss = self.mlm_loss(
507
- rearrange(prediction_scores, "... v -> (...) v"),
508
- rearrange(labels, "... -> (...)"),
509
- )
 
510
  next_sentence_loss = self.nsp_loss(
511
  rearrange(seq_relationship_score, "... t -> (...) t"),
512
  rearrange(next_sentence_label, "... -> (...)"),
513
- )
514
- total_loss = masked_lm_loss.float() + next_sentence_loss.float()
 
 
 
515
 
516
  return BertForPreTrainingOutput(
517
  loss=total_loss,
 
494
  )
495
  prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
496
 
497
+ if (
498
+ self.dense_seq_output and labels is not None
499
+ ): # prediction_scores are already flattened
500
+ masked_lm_loss = self.mlm_loss(
501
+ prediction_scores, labels.flatten()[masked_token_idx]
502
+ ).float()
503
+ elif labels is not None:
504
+ masked_lm_loss = self.mlm_loss(
505
+ rearrange(prediction_scores, "... v -> (...) v"),
506
+ rearrange(labels, "... -> (...)"),
507
+ ).float()
508
+ else:
509
+ masked_lm_loss = 0
510
+ if next_sentence_label is not None:
511
  next_sentence_loss = self.nsp_loss(
512
  rearrange(seq_relationship_score, "... t -> (...) t"),
513
  rearrange(next_sentence_label, "... -> (...)"),
514
+ ).float()
515
+ else:
516
+ next_sentence_loss = 0
517
+
518
+ total_loss = masked_lm_loss + next_sentence_loss
519
 
520
  return BertForPreTrainingOutput(
521
  loss=total_loss,