Upload 7 files
Browse files- modeling_indictrans.py +5 -2
modeling_indictrans.py
CHANGED
@@ -61,7 +61,8 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
|
|
61 |
|
62 |
|
63 |
def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
64 |
-
labels = decoder_input_ids
|
|
|
65 |
|
66 |
labels_mask = labels == 1
|
67 |
labels[labels_mask] = -100
|
@@ -70,6 +71,8 @@ def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
|
70 |
decoder_input_ids[mask] = 1
|
71 |
decoder_attention_mask[mask] = 0
|
72 |
|
|
|
|
|
73 |
return decoder_input_ids, decoder_attention_mask, labels
|
74 |
|
75 |
|
@@ -1247,7 +1250,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
|
1247 |
# move labels to the correct device to enable PP
|
1248 |
labels = labels.to(lm_logits.device)
|
1249 |
loss_fct = nn.CrossEntropyLoss()
|
1250 |
-
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.
|
1251 |
|
1252 |
if not return_dict:
|
1253 |
output = (lm_logits,) + outputs[1:]
|
|
|
61 |
|
62 |
|
63 |
def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
64 |
+
labels = decoder_input_ids.full_size(decoder_input_ids.size(), -100)
|
65 |
+
labels[:, :-1] = decoder_input_ids[:, 1:]
|
66 |
|
67 |
labels_mask = labels == 1
|
68 |
labels[labels_mask] = -100
|
|
|
71 |
decoder_input_ids[mask] = 1
|
72 |
decoder_attention_mask[mask] = 0
|
73 |
|
74 |
+
labels = decoder_input_ids[:, 1:]
|
75 |
+
|
76 |
return decoder_input_ids, decoder_attention_mask, labels
|
77 |
|
78 |
|
|
|
1250 |
# move labels to the correct device to enable PP
|
1251 |
labels = labels.to(lm_logits.device)
|
1252 |
loss_fct = nn.CrossEntropyLoss()
|
1253 |
+
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.reshape(-1))
|
1254 |
|
1255 |
if not return_dict:
|
1256 |
output = (lm_logits,) + outputs[1:]
|