izhx commited on
Commit
d0284a3
1 Parent(s): b7ea01b

Update modeling.py

Browse files
Files changed (1) hide show
  1. modeling.py +8 -5
modeling.py CHANGED
@@ -975,8 +975,6 @@ class NewForMaskedLM(NewPreTrainedModel):
975
  self.lm_head = NewLMPredictionHead(config)
976
  self.loss_fct = nn.CrossEntropyLoss()
977
 
978
- self.pretraining = True
979
-
980
  # Initialize weights and apply final processing
981
  self.post_init()
982
 
@@ -1009,13 +1007,13 @@ class NewForMaskedLM(NewPreTrainedModel):
1009
 
1010
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1011
 
1012
- if labels is None:
1013
  length = None
1014
  subset_indices = None
1015
  else:
1016
  length = attention_mask.sum(-1).tolist()
1017
  labels = labels[attention_mask.bool()].unsqueeze(0)
1018
- subset_indices = labels > -100 if self.pretraining else None
1019
 
1020
  outputs = self.new(
1021
  input_ids,
@@ -1037,7 +1035,12 @@ class NewForMaskedLM(NewPreTrainedModel):
1037
 
1038
  masked_lm_loss = None
1039
  if labels is not None:
1040
- labels = labels[subset_indices]
 
 
 
 
 
1041
  masked_lm_loss = self.loss_fct(prediction_scores, labels)
1042
 
1043
  if not return_dict:
 
975
  self.lm_head = NewLMPredictionHead(config)
976
  self.loss_fct = nn.CrossEntropyLoss()
977
 
 
 
978
  # Initialize weights and apply final processing
979
  self.post_init()
980
 
 
1007
 
1008
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1009
 
1010
+ if labels is None or not self.new.config.unpad_inputs:
1011
  length = None
1012
  subset_indices = None
1013
  else:
1014
  length = attention_mask.sum(-1).tolist()
1015
  labels = labels[attention_mask.bool()].unsqueeze(0)
1016
+ subset_indices = labels > -100
1017
 
1018
  outputs = self.new(
1019
  input_ids,
 
1035
 
1036
  masked_lm_loss = None
1037
  if labels is not None:
1038
+ if subset_indices is None:
1039
+ mask = attention_mask.bool()
1040
+ prediction_scores = prediction_scores[mask]
1041
+ labels = labels[mask]
1042
+ else:
1043
+ labels = labels[subset_indices]
1044
  masked_lm_loss = self.loss_fct(prediction_scores, labels)
1045
 
1046
  if not return_dict: