Raghavan committed on
Commit
8b9ed51
·
verified ·
1 Parent(s): 182ddfb

Upload 7 files

Browse files
Files changed (1) hide show
  1. modeling_indictrans.py +14 -0
modeling_indictrans.py CHANGED
@@ -40,6 +40,7 @@ logger = logging.get_logger(__name__)
40
  _CONFIG_FOR_DOC = "IndicTransConfig"
41
 
42
  INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
 
43
 
44
 
45
  # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
@@ -59,6 +60,16 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
59
  return shifted_input_ids
60
 
61
 
 
 
 
 
 
 
 
 
 
 
62
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
63
  def _make_causal_mask(
64
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -1206,6 +1217,9 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1206
  # labels, self.config.pad_token_id, self.config.decoder_start_token_id
1207
  # )
1208
 
 
 
 
1209
  outputs = self.model(
1210
  input_ids,
1211
  attention_mask=attention_mask,
 
40
  _CONFIG_FOR_DOC = "IndicTransConfig"
41
 
42
  INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
43
+ eos_token_id = 2
44
 
45
 
46
  # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
 
60
  return shifted_input_ids
61
 
62
 
63
def prepare_decoder_input_ids_label(
    decoder_input_ids,
    decoder_attention_mask,
    *,
    eos_token_id: int = 2,
    pad_token_id: int = 1,
):
    """Blank out EOS tokens in the decoder inputs and derive labels from them.

    Every position of ``decoder_input_ids`` equal to ``eos_token_id`` is
    replaced with ``pad_token_id``, and the same positions of
    ``decoder_attention_mask`` are set to 0.  ``labels`` is the resulting id
    tensor with its first column dropped (``[:, 1:]``).

    Unlike a boolean-index assignment, this does not modify the tensors passed
    in; new tensors are returned, so a batch that is reused by the caller is
    left intact.

    Args:
        decoder_input_ids: 2-D integer tensor of decoder token ids (it is
            sliced as ``[:, 1:]``), or ``None``.
        decoder_attention_mask: Tensor broadcastable to the shape of
            ``decoder_input_ids``, or ``None``.  ``None`` is passed through
            unchanged instead of raising.
        eos_token_id: Token id to blank out.  Defaults to 2.
        pad_token_id: Replacement id written where EOS was found.
            Defaults to 1.

    Returns:
        A ``(decoder_input_ids, decoder_attention_mask, labels)`` tuple.  When
        ``decoder_input_ids`` is ``None`` the result is
        ``(None, decoder_attention_mask, None)``.
    """
    if decoder_input_ids is None:
        # Nothing to rewrite and nothing to derive labels from.
        return None, decoder_attention_mask, None

    eos_positions = decoder_input_ids == eos_token_id

    # masked_fill returns a new tensor, leaving the caller's tensors untouched.
    decoder_input_ids = decoder_input_ids.masked_fill(eos_positions, pad_token_id)
    if decoder_attention_mask is not None:
        decoder_attention_mask = decoder_attention_mask.masked_fill(eos_positions, 0)

    # NOTE(review): labels are taken AFTER the EOS replacement, so they never
    # contain eos_token_id and are one column shorter than the inputs --
    # confirm this matches what the loss computation expects.
    labels = decoder_input_ids[:, 1:]

    return decoder_input_ids, decoder_attention_mask, labels
71
+
72
+
73
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
74
  def _make_causal_mask(
75
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 
1217
  # labels, self.config.pad_token_id, self.config.decoder_start_token_id
1218
  # )
1219
 
1220
+ decoder_input_ids, decoder_attention_mask, labels = prepare_decoder_input_ids_label(decoder_input_ids,
1221
+ decoder_attention_mask)
1222
+
1223
  outputs = self.model(
1224
  input_ids,
1225
  attention_mask=attention_mask,