Raghavan committed on
Commit
929b4c7
1 Parent(s): 8657564

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +5 -2
modeling_indictrans.py CHANGED
@@ -61,7 +61,8 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
61
 
62
 
63
  def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
64
- labels = decoder_input_ids[:, 1:]
 
65
 
66
  labels_mask = labels == 1
67
  labels[labels_mask] = -100
@@ -70,6 +71,8 @@ def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
70
  decoder_input_ids[mask] = 1
71
  decoder_attention_mask[mask] = 0
72
 
 
 
73
  return decoder_input_ids, decoder_attention_mask, labels
74
 
75
 
@@ -1247,7 +1250,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1247
  # move labels to the correct device to enable PP
1248
  labels = labels.to(lm_logits.device)
1249
  loss_fct = nn.CrossEntropyLoss()
1250
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
1251
 
1252
  if not return_dict:
1253
  output = (lm_logits,) + outputs[1:]
 
61
 
62
 
63
  def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
64
+ labels = decoder_input_ids.full_size(decoder_input_ids.size(), -100)
65
+ labels[:, :-1] = decoder_input_ids[:, 1:]
66
 
67
  labels_mask = labels == 1
68
  labels[labels_mask] = -100
 
71
  decoder_input_ids[mask] = 1
72
  decoder_attention_mask[mask] = 0
73
 
74
+ labels = decoder_input_ids[:, 1:]
75
+
76
  return decoder_input_ids, decoder_attention_mask, labels
77
 
78
 
 
1250
  # move labels to the correct device to enable PP
1251
  labels = labels.to(lm_logits.device)
1252
  loss_fct = nn.CrossEntropyLoss()
1253
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.reshape(-1))
1254
 
1255
  if not return_dict:
1256
  output = (lm_logits,) + outputs[1:]