Raghavan committed
Commit
72769c4
1 Parent(s): d33947c

Upload 7 files

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) AI4Bharat.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,74 @@
+ ---
+ language:
+ - as
+ - bn
+ - brx
+ - doi
+ - en
+ - gom
+ - gu
+ - hi
+ - kn
+ - ks
+ - kas
+ - mai
+ - ml
+ - mr
+ - mni
+ - mnb
+ - ne
+ - or
+ - pa
+ - sa
+ - sat
+ - sd
+ - snd
+ - ta
+ - te
+ - ur
+ language_details: >-
+ asm_Beng, ben_Beng, brx_Deva, doi_Deva, eng_Latn, gom_Deva, guj_Gujr,
+ hin_Deva, kan_Knda, kas_Arab, kas_Deva, mai_Deva, mal_Mlym, mar_Deva,
+ mni_Beng, mni_Mtei, npi_Deva, ory_Orya, pan_Guru, san_Deva, sat_Olck,
+ snd_Arab, snd_Deva, tam_Taml, tel_Telu, urd_Arab
+ tags:
+ - indictrans2
+ - translation
+ - ai4bharat
+ - multilingual
+ license: mit
+ datasets:
+ - flores-200
+ - IN22-Gen
+ - IN22-Conv
+ metrics:
+ - bleu
+ - chrf
+ - chrf++
+ - comet
+ inference: false
+ ---
+
+ # IndicTrans2
+
+ This is the model card of the IndicTrans2 En-Indic Distilled 200M variant.
+
+ Please refer to [section 7.6: Distilled Models](https://openreview.net/forum?id=vfT4YuzAYA) in the TMLR submission for further details on model training, data, and metrics.
+
+ ### Usage Instructions
+
+ Please refer to the [GitHub repository](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_inference) for a detailed description of how to use the HF-compatible IndicTrans2 models for inference; a minimal sketch is shown below.
+
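A minimal inference sketch, not the official example from the repository: it assumes the HF-compatible tokenizer files uploaded alongside this model, uses only `transformers` calls, and skips the `IndicProcessor`-based normalization, entity handling, and postprocessing that the GitHub repository applies. The pre-tagged input format (`src_lang tgt_lang sentence` with FLORES-200 codes) follows the repository's preprocessing convention; the example sentence and generation settings are illustrative.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "ai4bharat/indictrans2-en-indic-dist-200M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

# English -> Hindi: the source sentence is prefixed with the source and target
# language tags (FLORES-200 codes), as done by the repository's preprocessing step.
sentence = "eng_Latn hin_Deva This is a test sentence."
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=256)

with torch.no_grad():
    generated = model.generate(**inputs, num_beams=5, max_length=256, use_cache=True)

# The repository normally applies IndicProcessor.postprocess_batch on top of this.
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```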
+ ### Citation
+
+ If you use our work, please cite:
+
+ ```
+ @article{ai4bharat2023indictrans2,
+   title = {IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
+   author = {AI4Bharat and Jay Gala and Pranjal A. Chitale and Raghavan AK and Sumanth Doddapaneni and Varun Gumma and Aswanth Kumar and Janki Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M. Khapra and Raj Dabre and Anoop Kunchukuttan},
+   year = {2023},
+   journal = {arXiv preprint arXiv:2305.16307}
+ }
+ ```
config.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "_name_or_path": "ai4bharat/indictrans2-en-indic-dist-200M",
+   "activation_dropout": 0.0,
+   "activation_function": "gelu",
+   "architectures": [
+     "IndicTransForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_indictrans.IndicTransConfig",
+     "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
+   },
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "decoder_attention_heads": 8,
+   "decoder_embed_dim": 512,
+   "decoder_ffn_dim": 2048,
+   "decoder_layerdrop": 0,
+   "decoder_layers": 18,
+   "decoder_normalize_before": true,
+   "decoder_start_token_id": 2,
+   "decoder_vocab_size": 122672,
+   "dropout": 0.2,
+   "encoder_attention_heads": 8,
+   "encoder_embed_dim": 512,
+   "encoder_ffn_dim": 2048,
+   "encoder_layerdrop": 0,
+   "encoder_layers": 18,
+   "encoder_normalize_before": true,
+   "encoder_vocab_size": 32322,
+   "eos_token_id": 2,
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "layernorm_embedding": true,
+   "max_source_positions": 256,
+   "max_target_positions": 256,
+   "model_type": "IndicTrans",
+   "num_hidden_layers": 18,
+   "pad_token_id": 1,
+   "scale_embedding": true,
+   "share_decoder_input_output_embed": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.32.1",
+   "use_cache": true
+ }
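The `auto_map` entry above is what routes the generic `AutoConfig` / `AutoModelForSeq2SeqLM` classes to the custom `IndicTrans*` implementations shipped in this repository when `trust_remote_code=True` is passed. A small sketch for inspecting the architecture described by this config (assumes `transformers` >= 4.32 and access to the Hub):

```python
from transformers import AutoConfig, AutoModelForSeq2SeqLM

config = AutoConfig.from_pretrained(
    "ai4bharat/indictrans2-en-indic-dist-200M", trust_remote_code=True
)
print(type(config).__name__)                                  # IndicTransConfig
print(config.encoder_layers, config.decoder_layers)           # 18 18
print(config.encoder_vocab_size, config.decoder_vocab_size)   # 32322 122672

# Building from the config alone yields a randomly initialised model with the same
# architecture, which is enough to check the parameter count without the weights.
model = AutoModelForSeq2SeqLM.from_config(config, trust_remote_code=True)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")
```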
configuration_indictrans.py ADDED
@@ -0,0 +1,307 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans config."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, Mapping, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
24
+ from transformers.onnx.utils import compute_effective_axis_dimension
25
+ from transformers.utils import TensorType, is_torch_available
26
+
27
+
28
+ # Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
29
+ class IndicTransConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
32
+ IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
+ with the defaults will yield a configuration similar to that of the IndicTrans2 model.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50265):
41
+ Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
42
+ `input_ids` passed when calling [`IT2Model`].
43
+ d_model (`int`, *optional*, defaults to 1024):
44
+ Dimensionality of the layers and the pooler layer.
45
+ encoder_layers (`int`, *optional*, defaults to 12):
46
+ Number of encoder layers.
47
+ decoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of decoder layers.
49
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
54
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
55
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
57
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ activation_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for activations inside the fully connected layer.
66
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for classifier.
68
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
69
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
70
+ just in case (e.g., 512 or 1024 or 2048).
71
+ init_std (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
74
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
75
+ for more details.
76
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
77
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
78
+ for more details.
79
+ use_cache (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the model should return the last key/values attentions (not used by all models).
81
+ """
82
+ model_type = "IndicTrans"
83
+ keys_to_ignore_at_inference = ["past_key_values"]
84
+ attribute_map = {
85
+ "num_attention_heads": "encoder_attention_heads",
86
+ "hidden_size": "d_model",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ encoder_vocab_size=None,
92
+ decoder_vocab_size=None,
93
+ encoder_embed_dim=512,
94
+ decoder_embed_dim=512,
95
+ max_source_positions=210,
96
+ max_target_positions=210,
97
+ encoder_layers=6,
98
+ encoder_ffn_dim=2048,
99
+ encoder_attention_heads=8,
100
+ decoder_layers=6,
101
+ decoder_ffn_dim=2048,
102
+ decoder_attention_heads=8,
103
+ encoder_layerdrop=0.00,
104
+ decoder_layerdrop=0.00,
105
+ use_cache=True,
106
+ is_encoder_decoder=True,
107
+ activation_function="relu",
108
+ encoder_normalize_before=False,
109
+ decoder_normalize_before=False,
110
+ layernorm_embedding=False,
111
+ share_decoder_input_output_embed=False,
112
+ dropout=0.1,
113
+ attention_dropout=0.0,
114
+ activation_dropout=0.0,
115
+ init_std=0.02,
116
+ scale_embedding=True,
117
+ decoder_start_token_id=2,
118
+ pad_token_id=1,
119
+ bos_token_id=0,
120
+ eos_token_id=2,
121
+ **kwargs,
122
+ ):
123
+ self.encoder_vocab_size = encoder_vocab_size
124
+ self.decoder_vocab_size = decoder_vocab_size
125
+ self.encoder_normalize_before = encoder_normalize_before
126
+ self.decoder_normalize_before = decoder_normalize_before
127
+ self.layernorm_embedding = layernorm_embedding
128
+ self.max_source_positions = max_source_positions
129
+ self.max_target_positions = max_target_positions
130
+ self.encoder_embed_dim = encoder_embed_dim
131
+ self.decoder_embed_dim = decoder_embed_dim
132
+ self.encoder_ffn_dim = encoder_ffn_dim
133
+ self.encoder_layers = encoder_layers
134
+ self.encoder_attention_heads = encoder_attention_heads
135
+ self.decoder_ffn_dim = decoder_ffn_dim
136
+ self.decoder_layers = decoder_layers
137
+ self.decoder_attention_heads = decoder_attention_heads
138
+ self.dropout = dropout
139
+ self.attention_dropout = attention_dropout
140
+ self.activation_dropout = activation_dropout
141
+ self.activation_function = activation_function
142
+ self.init_std = init_std
143
+ self.encoder_layerdrop = encoder_layerdrop
144
+ self.decoder_layerdrop = decoder_layerdrop
145
+ self.use_cache = use_cache
146
+ self.num_hidden_layers = encoder_layers
147
+ self.scale_embedding = scale_embedding
148
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
149
+
150
+ super().__init__(
151
+ pad_token_id=pad_token_id,
152
+ bos_token_id=bos_token_id,
153
+ eos_token_id=eos_token_id,
154
+ is_encoder_decoder=is_encoder_decoder,
155
+ decoder_start_token_id=decoder_start_token_id,
156
+ **kwargs,
157
+ )
158
+
159
+
160
+ class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
161
+ @property
162
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
163
+ common_inputs = OrderedDict(
164
+ [
165
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
166
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
167
+ ]
168
+ )
169
+
170
+ if self.use_past:
171
+ common_inputs["decoder_input_ids"] = {0: "batch"}
172
+ common_inputs["decoder_attention_mask"] = {
173
+ 0: "batch",
174
+ 1: "past_decoder_sequence + sequence",
175
+ }
176
+ else:
177
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
178
+ common_inputs["decoder_attention_mask"] = {
179
+ 0: "batch",
180
+ 1: "decoder_sequence",
181
+ }
182
+
183
+ if self.use_past:
184
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
185
+ return common_inputs
186
+
187
+ # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
188
+ # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
189
+ # answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
190
+ # was done for BART so that it can be updated if need be.
191
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
192
+ self,
193
+ tokenizer: PreTrainedTokenizer,
194
+ batch_size: int = -1,
195
+ seq_length: int = -1,
196
+ is_pair: bool = False,
197
+ framework: Optional[TensorType] = None,
198
+ ) -> Mapping[str, Any]:
199
+ # Copied from OnnxConfig.generate_dummy_inputs
200
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
201
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
202
+ batch_size = compute_effective_axis_dimension(
203
+ batch_size,
204
+ fixed_dimension=OnnxConfig.default_fixed_batch,
205
+ num_token_to_add=0,
206
+ )
207
+
208
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
209
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
210
+ seq_length = compute_effective_axis_dimension(
211
+ seq_length,
212
+ fixed_dimension=OnnxConfig.default_fixed_sequence,
213
+ num_token_to_add=token_to_add,
214
+ )
215
+
216
+ # Generate dummy inputs according to compute batch and sequence
217
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
218
+ common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
219
+ return common_inputs
220
+
221
+ # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
222
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
223
+ self,
224
+ tokenizer: PreTrainedTokenizer,
225
+ batch_size: int = -1,
226
+ seq_length: int = -1,
227
+ is_pair: bool = False,
228
+ framework: Optional[TensorType] = None,
229
+ ) -> Mapping[str, Any]:
230
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
231
+ tokenizer, batch_size, seq_length, is_pair, framework
232
+ )
233
+
234
+ # Generate decoder inputs
235
+ decoder_seq_length = seq_length if not self.use_past else 1
236
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
237
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
238
+ )
239
+ decoder_inputs = {
240
+ f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
241
+ }
242
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
243
+
244
+ if self.use_past:
245
+ if not is_torch_available():
246
+ raise ValueError(
247
+ "Cannot generate dummy past_keys inputs without PyTorch installed."
248
+ )
249
+ else:
250
+ import torch
251
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
252
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
253
+ (
254
+ num_encoder_attention_heads,
255
+ num_decoder_attention_heads,
256
+ ) = self.num_attention_heads
257
+ encoder_shape = (
258
+ batch,
259
+ num_encoder_attention_heads,
260
+ encoder_seq_length,
261
+ self._config.hidden_size // num_encoder_attention_heads,
262
+ )
263
+ decoder_past_length = decoder_seq_length + 3
264
+ decoder_shape = (
265
+ batch,
266
+ num_decoder_attention_heads,
267
+ decoder_past_length,
268
+ self._config.hidden_size // num_decoder_attention_heads,
269
+ )
270
+
271
+ common_inputs["decoder_attention_mask"] = torch.cat(
272
+ [
273
+ common_inputs["decoder_attention_mask"],
274
+ torch.ones(batch, decoder_past_length),
275
+ ],
276
+ dim=1,
277
+ )
278
+
279
+ common_inputs["past_key_values"] = []
280
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
281
+ num_encoder_layers, num_decoder_layers = self.num_layers
282
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
283
+ max_num_layers = (
284
+ max(num_encoder_layers, num_decoder_layers) - min_num_layers
285
+ )
286
+ remaining_side_name = (
287
+ "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
288
+ )
289
+
290
+ for _ in range(min_num_layers):
291
+ common_inputs["past_key_values"].append(
292
+ (
293
+ torch.zeros(decoder_shape),
294
+ torch.zeros(decoder_shape),
295
+ torch.zeros(encoder_shape),
296
+ torch.zeros(encoder_shape),
297
+ )
298
+ )
299
+ # TODO: test this.
300
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
301
+ for _ in range(min_num_layers, max_num_layers):
302
+ common_inputs["past_key_values"].append(
303
+ (torch.zeros(shape), torch.zeros(shape))
304
+ )
305
+ return common_inputs
306
+
307
+ generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
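For completeness, a hypothetical sketch of constructing the config class directly, assuming `configuration_indictrans.py` has been downloaded and is importable locally; the values mirror `config.json` above, and anything not passed falls back to the defaults in `__init__` (e.g. `activation_function="relu"`, `dropout=0.1`):

```python
from configuration_indictrans import IndicTransConfig

config = IndicTransConfig(
    encoder_vocab_size=32322,
    decoder_vocab_size=122672,
    encoder_layers=18,
    decoder_layers=18,
    encoder_normalize_before=True,
    decoder_normalize_before=True,
    layernorm_embedding=True,
    share_decoder_input_output_embed=True,
    max_source_positions=256,
    max_target_positions=256,
    activation_function="gelu",
    dropout=0.2,
)
print(config.num_hidden_layers)  # 18 -- mirrored from encoder_layers in __init__
```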
generation_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "decoder_start_token_id": 2,
+   "eos_token_id": 2,
+   "pad_token_id": 1,
+   "transformers_version": "4.32.1"
+ }
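These ids mirror the ones in `config.json` and are picked up automatically by `generate()`. A small sketch for inspecting them (assumes access to the Hub):

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("ai4bharat/indictrans2-en-indic-dist-200M")
# Expected from the file above: decoder_start_token_id=2, eos_token_id=2, pad_token_id=1, bos_token_id=0
print(gen_cfg.decoder_start_token_id, gen_cfg.eos_token_id, gen_cfg.pad_token_id, gen_cfg.bos_token_id)
```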
modeling_indictrans.py ADDED
@@ -0,0 +1,1267 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans model."""
16
+
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutput,
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ Seq2SeqLMOutput,
31
+ Seq2SeqModelOutput,
32
+ )
33
+
34
+ from transformers.utils import logging
35
+ from transformers.modeling_utils import PreTrainedModel
36
+
37
+ from .configuration_indictrans import IndicTransConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CONFIG_FOR_DOC = "IndicTransConfig"
43
+
44
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
45
+
46
+
47
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
48
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
49
+ """
50
+ Shift input ids one token to the right.
51
+ """
52
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
53
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
54
+ shifted_input_ids[:, 0] = decoder_start_token_id
55
+
56
+ if pad_token_id is None:
57
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
58
+ # replace possible -100 values in labels by `pad_token_id`
59
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
60
+
61
+ return shifted_input_ids
62
+
63
+
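An illustration added for this card (not part of the uploaded file) of how `shift_tokens_right` builds decoder inputs from labels, using `pad_token_id=1` and `decoder_start_token_id=2` as in `config.json`:

```python
import torch

labels = torch.tensor([[11, 12, 13, 2],
                       [21, 22, 2, -100]])  # -100 marks positions ignored by the loss
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2))
# tensor([[ 2, 11, 12, 13],
#         [ 2, 21, 22,  2]])
```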
64
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
65
+ def _make_causal_mask(
66
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
67
+ ):
68
+ """
69
+ Make the causal mask used for uni-directional (decoder) self-attention.
70
+ """
71
+ bsz, tgt_len = input_ids_shape
72
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
73
+ mask_cond = torch.arange(mask.size(-1), device=device)
74
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
75
+ mask = mask.to(dtype)
76
+
77
+ if past_key_values_length > 0:
78
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
79
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
80
+
81
+
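An illustration (not part of the uploaded file) of the mask produced above for a 3-token target: zeros on and below the diagonal, `finfo(dtype).min` above it, so future positions are suppressed once the mask is added to the attention scores.

```python
import torch

mask = _make_causal_mask((1, 3), dtype=torch.float32, device=torch.device("cpu"))
print(mask.shape)              # torch.Size([1, 1, 3, 3])
print((mask == 0).int()[0, 0])
# tensor([[1, 0, 0],
#         [1, 1, 0],
#         [1, 1, 1]], dtype=torch.int32)
```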
82
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
83
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
84
+ """
85
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
86
+ """
87
+ bsz, src_len = mask.size()
88
+ tgt_len = tgt_len if tgt_len is not None else src_len
89
+
90
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
91
+
92
+ inverted_mask = 1.0 - expanded_mask
93
+
94
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
95
+
96
+
97
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
98
+ """
99
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
100
+ are ignored. This is modified from fairseq's `utils.make_positions`.
101
+ """
102
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
103
+ mask = input_ids.ne(padding_idx).int()
104
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
105
+ return incremental_indices.long() + padding_idx
106
+
107
+
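An illustration (not part of the uploaded file): positions start at `padding_idx + 1` and padded slots keep the padding index, with `padding_idx=1` matching `pad_token_id` in `config.json`.

```python
import torch

ids = torch.tensor([[5, 6, 7, 1, 1]])  # 1 = <pad>
print(create_position_ids_from_input_ids(ids, padding_idx=1))
# tensor([[2, 3, 4, 1, 1]])
```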
108
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
109
+ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
110
+ """This module produces sinusoidal positional embeddings of any length."""
111
+
112
+ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
113
+ super().__init__()
114
+ self.offset = 2
115
+ self.embedding_dim = embedding_dim
116
+ self.padding_idx = padding_idx
117
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
118
+
119
+ def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
120
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
121
+ if hasattr(self, "weights"):
122
+ # in forward put the weights on the correct dtype and device of the param
123
+ emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
124
+
125
+ self.register_buffer("weights", emb_weights, persistent=False)
126
+
127
+ @staticmethod
128
+ def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
129
+ """
130
+ Build sinusoidal embeddings.
131
+
132
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
133
+ "Attention Is All You Need".
134
+ """
135
+ half_dim = embedding_dim // 2
136
+ emb = math.log(10000) / (half_dim - 1)
137
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
138
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
139
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
140
+ if embedding_dim % 2 == 1:
141
+ # zero pad
142
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
143
+ if padding_idx is not None:
144
+ emb[padding_idx, :] = 0
145
+
146
+ return emb.to(torch.get_default_dtype())
147
+
148
+ @torch.no_grad()
149
+ def forward(
150
+ self, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, past_key_values_length: int = 0
151
+ ):
152
+ if input_ids is not None:
153
+ bsz, seq_len = input_ids.size()
154
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
155
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
156
+ input_ids.device
157
+ )
158
+ else:
159
+ bsz, seq_len = inputs_embeds.size()[:-1]
160
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
161
+
162
+ # expand embeddings if needed
163
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
164
+ if max_pos > self.weights.size(0):
165
+ self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
166
+
167
+ return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
168
+
169
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
170
+ """
171
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
172
+
173
+ Args:
174
+ inputs_embeds: torch.Tensor
175
+
176
+ Returns: torch.Tensor
177
+ """
178
+ input_shape = inputs_embeds.size()[:-1]
179
+ sequence_length = input_shape[1]
180
+
181
+ position_ids = torch.arange(
182
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
183
+ )
184
+ return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
185
+
186
+
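An illustration (not part of the uploaded file) of the embedding module above, with `embedding_dim=512` and `padding_idx=1` as used by this checkpoint: padded positions are mapped to the all-zero row stored at the padding index.

```python
import torch

pos_emb = IndicTransSinusoidalPositionalEmbedding(num_positions=256, embedding_dim=512, padding_idx=1)
ids = torch.tensor([[5, 6, 7, 1, 1]])   # 1 = <pad>
out = pos_emb(input_ids=ids)
print(out.shape)                        # torch.Size([1, 5, 512])
print(out[0, -1].abs().sum())           # tensor(0.) -- padding rows are zero
```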
187
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
188
+ class IndicTransAttention(nn.Module):
189
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
190
+
191
+ def __init__(
192
+ self,
193
+ embed_dim: int,
194
+ num_heads: int,
195
+ dropout: float = 0.0,
196
+ is_decoder: bool = False,
197
+ bias: bool = True,
198
+ ):
199
+ super().__init__()
200
+ self.embed_dim = embed_dim
201
+ self.num_heads = num_heads
202
+ self.dropout = dropout
203
+ self.head_dim = embed_dim // num_heads
204
+
205
+ if (self.head_dim * num_heads) != self.embed_dim:
206
+ raise ValueError(
207
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
208
+ f" and `num_heads`: {num_heads})."
209
+ )
210
+ self.scaling = self.head_dim**-0.5
211
+ self.is_decoder = is_decoder
212
+
213
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
214
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
215
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
216
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
217
+
218
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
219
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
220
+
221
+ def forward(
222
+ self,
223
+ hidden_states: torch.Tensor,
224
+ key_value_states: Optional[torch.Tensor] = None,
225
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
226
+ attention_mask: Optional[torch.Tensor] = None,
227
+ layer_head_mask: Optional[torch.Tensor] = None,
228
+ output_attentions: bool = False,
229
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
230
+ """Input shape: Batch x Time x Channel"""
231
+
232
+ # if key_value_states are provided this layer is used as a cross-attention layer
233
+ # for the decoder
234
+ is_cross_attention = key_value_states is not None
235
+
236
+ bsz, tgt_len, _ = hidden_states.size()
237
+
238
+ # get query proj
239
+ query_states = self.q_proj(hidden_states) * self.scaling
240
+ # get key, value proj
241
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
242
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
243
+ # the provided `key_value_states` to support prefix tuning
244
+ if (
245
+ is_cross_attention
246
+ and past_key_value is not None
247
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
248
+ ):
249
+ # reuse k,v, cross_attentions
250
+ key_states = past_key_value[0]
251
+ value_states = past_key_value[1]
252
+ elif is_cross_attention:
253
+ # cross_attentions
254
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
255
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
256
+ elif past_key_value is not None:
257
+ # reuse k, v, self_attention
258
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
259
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
260
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
261
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
262
+ else:
263
+ # self_attention
264
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
265
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
266
+
267
+ if self.is_decoder:
268
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
269
+ # Further calls to cross_attention layer can then reuse all cross-attention
270
+ # key/value_states (first "if" case)
271
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
272
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
273
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
274
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
275
+ past_key_value = (key_states, value_states)
276
+
277
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
278
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
279
+ key_states = key_states.reshape(*proj_shape)
280
+ value_states = value_states.reshape(*proj_shape)
281
+
282
+ src_len = key_states.size(1)
283
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
284
+
285
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
286
+ raise ValueError(
287
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
288
+ f" {attn_weights.size()}"
289
+ )
290
+
291
+ if attention_mask is not None:
292
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
293
+ raise ValueError(
294
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
295
+ )
296
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
297
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
298
+
299
+ attn_weights = F.softmax(attn_weights, dim=-1)
300
+
301
+ if layer_head_mask is not None:
302
+ if layer_head_mask.size() != (self.num_heads,):
303
+ raise ValueError(
304
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
305
+ f" {layer_head_mask.size()}"
306
+ )
307
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
308
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
309
+
310
+ if output_attentions:
311
+ # this operation is a bit awkward, but it's required to
312
+ # make sure that attn_weights keeps its gradient.
313
+ # In order to do so, attn_weights have to be reshaped
314
+ # twice and have to be reused in the following
315
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
316
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
317
+ else:
318
+ attn_weights_reshaped = None
319
+
320
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
321
+
322
+ attn_output = torch.bmm(attn_probs, value_states)
323
+
324
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
325
+ raise ValueError(
326
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
327
+ f" {attn_output.size()}"
328
+ )
329
+
330
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
331
+ attn_output = attn_output.transpose(1, 2)
332
+
333
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
334
+ # partitioned across GPUs when using tensor-parallelism.
335
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
336
+
337
+ attn_output = self.out_proj(attn_output)
338
+
339
+ return attn_output, attn_weights_reshaped, past_key_value
340
+
341
+
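An illustration (not part of the uploaded file) of a standalone forward pass through the attention block with the dimensions used by this checkpoint (`embed_dim=512`, 8 heads); the random input is purely for shape checking.

```python
import torch

attn = IndicTransAttention(embed_dim=512, num_heads=8)
x = torch.randn(2, 7, 512)                      # (batch, seq_len, embed_dim)
out, weights, past = attn(hidden_states=x, output_attentions=True)
print(out.shape, weights.shape)                 # torch.Size([2, 7, 512]) torch.Size([2, 8, 7, 7])
print(past)                                     # None -- key/value caching only applies when is_decoder=True
```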
342
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
343
+ class IndicTransEncoderLayer(nn.Module):
344
+ def __init__(self, config: IndicTransConfig):
345
+ super().__init__()
346
+ self.embed_dim = config.encoder_embed_dim
347
+ self.self_attn = IndicTransAttention(
348
+ embed_dim=self.embed_dim,
349
+ num_heads=config.encoder_attention_heads,
350
+ dropout=config.attention_dropout,
351
+ )
352
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
353
+ self.dropout = config.dropout
354
+ self.activation_fn = ACT2FN[config.activation_function]
355
+ self.activation_dropout = config.activation_dropout
356
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
357
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
358
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
359
+ self.normalize_before = config.encoder_normalize_before
360
+
361
+ def forward(
362
+ self,
363
+ hidden_states: torch.Tensor,
364
+ attention_mask: torch.Tensor,
365
+ layer_head_mask: torch.Tensor,
366
+ output_attentions: bool = False,
367
+ ) -> torch.Tensor:
368
+ """
369
+ Args:
370
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
371
+ attention_mask (`torch.FloatTensor`): attention mask of size
372
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
373
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
374
+ `(encoder_attention_heads,)`.
375
+ output_attentions (`bool`, *optional*):
376
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
377
+ returned tensors for more detail.
378
+ """
379
+ residual = hidden_states
380
+ if self.normalize_before:
381
+ hidden_states = self.self_attn_layer_norm(hidden_states)
382
+ hidden_states, attn_weights, _ = self.self_attn(
383
+ hidden_states=hidden_states,
384
+ attention_mask=attention_mask,
385
+ layer_head_mask=layer_head_mask,
386
+ output_attentions=output_attentions,
387
+ )
388
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
389
+ hidden_states = residual + hidden_states
390
+ if not self.normalize_before:
391
+ hidden_states = self.self_attn_layer_norm(hidden_states)
392
+
393
+ residual = hidden_states
394
+ if self.normalize_before:
395
+ hidden_states = self.final_layer_norm(hidden_states)
396
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
397
+ hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
398
+ hidden_states = self.fc2(hidden_states)
399
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
400
+ hidden_states = residual + hidden_states
401
+ if not self.normalize_before:
402
+ hidden_states = self.final_layer_norm(hidden_states)
403
+
404
+ if hidden_states.dtype == torch.float16 and (
405
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
406
+ ):
407
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
408
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
409
+
410
+ outputs = (hidden_states,)
411
+
412
+ if output_attentions:
413
+ outputs += (attn_weights,)
414
+
415
+ return outputs
416
+
417
+
418
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
419
+ class IndicTransDecoderLayer(nn.Module):
420
+ def __init__(self, config: IndicTransConfig):
421
+ super().__init__()
422
+ self.embed_dim = config.decoder_embed_dim
423
+
424
+ self.self_attn = IndicTransAttention(
425
+ embed_dim=self.embed_dim,
426
+ num_heads=config.decoder_attention_heads,
427
+ dropout=config.attention_dropout,
428
+ is_decoder=True,
429
+ )
430
+ self.dropout = config.dropout
431
+ self.activation_fn = ACT2FN[config.activation_function]
432
+ self.activation_dropout = config.activation_dropout
433
+
434
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
435
+ self.encoder_attn = IndicTransAttention(
436
+ self.embed_dim,
437
+ config.decoder_attention_heads,
438
+ dropout=config.attention_dropout,
439
+ is_decoder=True,
440
+ )
441
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
442
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
443
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
444
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
445
+ self.normalize_before = config.decoder_normalize_before
446
+
447
+ def forward(
448
+ self,
449
+ hidden_states: torch.Tensor,
450
+ attention_mask: Optional[torch.Tensor] = None,
451
+ encoder_hidden_states: Optional[torch.Tensor] = None,
452
+ encoder_attention_mask: Optional[torch.Tensor] = None,
453
+ layer_head_mask: Optional[torch.Tensor] = None,
454
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
455
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
456
+ output_attentions: Optional[bool] = False,
457
+ use_cache: Optional[bool] = True,
458
+ ) -> torch.Tensor:
459
+ """
460
+ Args:
461
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
462
+ attention_mask (`torch.FloatTensor`): attention mask of size
463
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
464
+ encoder_hidden_states (`torch.FloatTensor`):
465
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
466
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
467
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
468
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
469
+ `(encoder_attention_heads,)`.
470
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
471
+ size `(decoder_attention_heads,)`.
472
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
473
+ output_attentions (`bool`, *optional*):
474
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
475
+ returned tensors for more detail.
476
+ """
477
+ residual = hidden_states
478
+ if self.normalize_before:
479
+ hidden_states = self.self_attn_layer_norm(hidden_states)
480
+
481
+ # Self Attention
482
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
483
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
484
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
485
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
486
+ hidden_states=hidden_states,
487
+ past_key_value=self_attn_past_key_value,
488
+ attention_mask=attention_mask,
489
+ layer_head_mask=layer_head_mask,
490
+ output_attentions=output_attentions,
491
+ )
492
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
493
+ hidden_states = residual + hidden_states
494
+ if not self.normalize_before:
495
+ hidden_states = self.self_attn_layer_norm(hidden_states)
496
+
497
+ # Cross-Attention Block
498
+ cross_attn_present_key_value = None
499
+ cross_attn_weights = None
500
+ if encoder_hidden_states is not None:
501
+ residual = hidden_states
502
+ if self.normalize_before:
503
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
504
+
505
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
506
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
507
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
508
+ hidden_states=hidden_states,
509
+ key_value_states=encoder_hidden_states,
510
+ attention_mask=encoder_attention_mask,
511
+ layer_head_mask=cross_attn_layer_head_mask,
512
+ past_key_value=cross_attn_past_key_value,
513
+ output_attentions=output_attentions,
514
+ )
515
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
516
+ hidden_states = residual + hidden_states
517
+ if not self.normalize_before:
518
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
519
+
520
+ # add cross-attn to positions 3,4 of present_key_value tuple
521
+ present_key_value = present_key_value + cross_attn_present_key_value
522
+
523
+ # Fully Connected
524
+ residual = hidden_states
525
+ if self.normalize_before:
526
+ hidden_states = self.final_layer_norm(hidden_states)
527
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
528
+ hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training)
529
+ hidden_states = self.fc2(hidden_states)
530
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
531
+ hidden_states = residual + hidden_states
532
+ if not self.normalize_before:
533
+ hidden_states = self.final_layer_norm(hidden_states)
534
+
535
+ outputs = (hidden_states,)
536
+
537
+ if output_attentions:
538
+ outputs += (self_attn_weights, cross_attn_weights)
539
+
540
+ if use_cache:
541
+ outputs += (present_key_value,)
542
+
543
+ return outputs
544
+
545
+
546
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->IndicTrans
547
+ class IndicTransPreTrainedModel(PreTrainedModel):
548
+ config_class = IndicTransConfig
549
+ base_model_prefix = "model"
550
+ supports_gradient_checkpointing = True
551
+ _no_split_modules = ["IndicTransAttention"]
552
+
553
+ def _init_weights(self, module):
554
+ std = self.config.init_std
555
+ if isinstance(module, nn.Linear):
556
+ module.weight.data.normal_(mean=0.0, std=std)
557
+ if module.bias is not None:
558
+ module.bias.data.zero_()
559
+ elif isinstance(module, nn.Embedding):
560
+ module.weight.data.normal_(mean=0.0, std=std)
561
+ if module.padding_idx is not None:
562
+ module.weight.data[module.padding_idx].zero_()
563
+
564
+ def _set_gradient_checkpointing(self, module, value=False):
565
+ if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
566
+ module.gradient_checkpointing = value
567
+
568
+
569
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->IndicTrans
570
+ class IndicTransEncoder(IndicTransPreTrainedModel):
571
+ """
572
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
573
+ [`IndicTransEncoderLayer`].
574
+
575
+ Args:
576
+ config: IndicTransConfig
577
+ embed_tokens (nn.Embedding): output embedding
578
+ """
579
+
580
+ def __init__(self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None):
581
+ super().__init__(config)
582
+
583
+ self.dropout = config.dropout
584
+ self.layerdrop = config.encoder_layerdrop
585
+
586
+ embed_dim = config.encoder_embed_dim
587
+ self.padding_idx = config.pad_token_id
588
+ self.max_source_positions = config.max_source_positions
589
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
590
+
591
+ self.embed_tokens = nn.Embedding(config.encoder_vocab_size, embed_dim, self.padding_idx)
592
+
593
+ if embed_tokens is not None:
594
+ self.embed_tokens.weight = embed_tokens.weight
595
+
596
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
597
+ config.max_source_positions,
598
+ embed_dim,
599
+ self.padding_idx,
600
+ )
601
+ self.layers = nn.ModuleList([IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)])
602
+ self.layer_norm = nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
603
+ self.layernorm_embedding = nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
604
+
605
+ self.gradient_checkpointing = False
606
+ # Initialize weights and apply final processing
607
+ self.post_init()
608
+
609
+ def forward(
610
+ self,
611
+ input_ids: Optional[torch.Tensor] = None,
612
+ attention_mask: Optional[torch.Tensor] = None,
613
+ head_mask: Optional[torch.Tensor] = None,
614
+ inputs_embeds: Optional[torch.Tensor] = None,
615
+ output_attentions: Optional[bool] = None,
616
+ output_hidden_states: Optional[bool] = None,
617
+ return_dict: Optional[bool] = None,
618
+ ):
619
+ r"""
620
+ Args:
621
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
622
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
623
+ provide it.
624
+
625
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
626
+ [`PreTrainedTokenizer.__call__`] for details.
627
+
628
+ [What are input IDs?](../glossary#input-ids)
629
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
630
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
631
+
632
+ - 1 for tokens that are **not masked**,
633
+ - 0 for tokens that are **masked**.
634
+
635
+ [What are attention masks?](../glossary#attention-mask)
636
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
637
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
638
+
639
+ - 1 indicates the head is **not masked**,
640
+ - 0 indicates the head is **masked**.
641
+
642
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
643
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
644
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
645
+ than the model's internal embedding lookup matrix.
646
+ output_attentions (`bool`, *optional*):
647
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
648
+ returned tensors for more detail.
649
+ output_hidden_states (`bool`, *optional*):
650
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
651
+ for more detail.
652
+ return_dict (`bool`, *optional*):
653
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
654
+ """
655
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
656
+ output_hidden_states = (
657
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
658
+ )
659
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
660
+
661
+ # retrieve input_ids and inputs_embeds
662
+ if input_ids is not None and inputs_embeds is not None:
663
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
664
+ elif input_ids is not None:
665
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
666
+ input_shape = input_ids.size()
667
+ input_ids = input_ids.view(-1, input_shape[-1])
668
+ elif inputs_embeds is not None:
669
+ input_shape = inputs_embeds.size()[:-1]
670
+ else:
671
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
672
+
673
+ if inputs_embeds is None:
674
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
675
+
676
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
677
+ embed_pos = embed_pos.to(inputs_embeds.device)
678
+
679
+ hidden_states = inputs_embeds + embed_pos
680
+ if self.layernorm_embedding is not None:
681
+ hidden_states = self.layernorm_embedding(hidden_states)
682
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
683
+
684
+ # expand attention_mask
685
+ if attention_mask is not None:
686
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
687
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
688
+
689
+ encoder_states = () if output_hidden_states else None
690
+ all_attentions = () if output_attentions else None
691
+
692
+ # check if head_mask has a correct number of layers specified if desired
693
+ if head_mask is not None:
694
+ if head_mask.size()[0] != len(self.layers):
695
+ raise ValueError(
696
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
697
+ f" {head_mask.size()[0]}."
698
+ )
699
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
700
+
701
+ for idx, encoder_layer in enumerate(self.layers):
702
+ if output_hidden_states:
703
+ encoder_states = encoder_states + (hidden_states,)
704
+
705
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
706
+ dropout_probability = torch.rand([])
707
+
708
+ skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
709
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
710
+ # under deepspeed zero3 all gpus must run in sync
711
+
712
+ if self.gradient_checkpointing and self.training:
713
+ # create gradient checkpointing function
714
+ def create_custom_forward(module):
715
+ def custom_forward(*inputs):
716
+ return module(*inputs, output_attentions)
717
+
718
+ return custom_forward
719
+
720
+ layer_outputs = torch.utils.checkpoint.checkpoint(
721
+ create_custom_forward(encoder_layer),
722
+ hidden_states,
723
+ attention_mask,
724
+ (head_mask[idx] if head_mask is not None else None),
725
+ )
726
+ else:
727
+ layer_outputs = encoder_layer(
728
+ hidden_states,
729
+ attention_mask,
730
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
731
+ output_attentions=output_attentions,
732
+ )
733
+
734
+ hidden_states = layer_outputs[0]
735
+
736
+ if skip_the_layer:
737
+ layer_outputs = (None, None)
738
+
739
+ if output_attentions:
740
+ all_attentions = all_attentions + (layer_outputs[1],)
741
+
742
+ if self.layer_norm is not None:
743
+ hidden_states = self.layer_norm(hidden_states)
744
+
745
+ if output_hidden_states:
746
+ encoder_states = encoder_states + (hidden_states,)
747
+
748
+ if not return_dict:
749
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
750
+ return BaseModelOutput(
751
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
752
+ )
753
+
754
+
755
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->IndicTrans
756
+ class IndicTransDecoder(IndicTransPreTrainedModel):
757
+ """
758
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
759
+
760
+ Args:
761
+ config: IndicTransConfig
762
+ embed_tokens (nn.Embedding): output embedding
763
+ """
764
+
765
+ def __init__(self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None):
766
+ super().__init__(config)
767
+ self.dropout = config.dropout
768
+ self.layerdrop = config.decoder_layerdrop
769
+
770
+ embed_dim = config.encoder_embed_dim
771
+ self.padding_idx = config.pad_token_id
772
+ self.max_target_positions = config.max_target_positions
773
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
774
+
775
+ self.embed_tokens = nn.Embedding(config.decoder_vocab_size, embed_dim, self.padding_idx)
776
+
777
+ if embed_tokens is not None:
778
+ self.embed_tokens.weight = embed_tokens.weight
779
+
780
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
781
+ config.max_target_positions,
782
+ embed_dim,
783
+ self.padding_idx,
784
+ )
785
+ self.layers = nn.ModuleList([IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)])
786
+ self.layer_norm = nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
787
+ self.layernorm_embedding = nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
788
+
789
+ self.gradient_checkpointing = False
790
+ # Initialize weights and apply final processing
791
+ self.post_init()
792
+
793
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         encoder_attention_mask: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         r"""
+         Args:
+             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                 Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+                 provide it.
+
+                 Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+                 [`PreTrainedTokenizer.__call__`] for details.
+
+                 [What are input IDs?](../glossary#input-ids)
+             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                 Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+
+                 [What are attention masks?](../glossary#attention-mask)
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+                 Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
+                 of the decoder.
+             encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+                 Mask to avoid performing cross-attention on padding token indices of encoder input_ids. Mask values
+                 selected in `[0, 1]`:
+
+                 - 1 for tokens that are **not masked**,
+                 - 0 for tokens that are **masked**.
+
+                 [What are attention masks?](../glossary#attention-mask)
+             head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                 Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+
+             cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+                 Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
+                 cross-attention on hidden heads. Mask values selected in `[0, 1]`:
+
+                 - 1 indicates the head is **not masked**,
+                 - 0 indicates the head is **masked**.
+
+             past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+                 Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+                 shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
+                 shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+                 Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+                 cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+                 If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+                 that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+                 This is useful if you want more control over how to convert `input_ids` indices into associated
+                 vectors than the model's internal embedding lookup matrix.
+             output_attentions (`bool`, *optional*):
+                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                 returned tensors for more detail.
+             output_hidden_states (`bool`, *optional*):
+                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+                 for more detail.
+             return_dict (`bool`, *optional*):
+                 Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+         """
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+         elif input_ids is not None:
+             input_shape = input_ids.size()
+             input_ids = input_ids.view(-1, input_shape[-1])
+         elif inputs_embeds is not None:
+             input_shape = inputs_embeds.size()[:-1]
+         else:
+             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+         # past_key_values_length
+         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+
+         # create causal mask
+         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+         combined_attention_mask = None
+         if input_shape[-1] > 1:
+             combined_attention_mask = _make_causal_mask(
+                 input_shape,
+                 inputs_embeds.dtype,
+                 device=inputs_embeds.device,
+                 past_key_values_length=past_key_values_length,
+             )
+
+         if attention_mask is not None and combined_attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             combined_attention_mask = combined_attention_mask + _expand_mask(
+                 attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+             )
+
+         # expand encoder attention mask
+         if encoder_hidden_states is not None and encoder_attention_mask is not None:
+             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+             encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+
+         # embed positions
+         positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
+         positions = positions.to(inputs_embeds.device)
+
+         hidden_states = inputs_embeds + positions
+         if self.layernorm_embedding is not None:
+             hidden_states = self.layernorm_embedding(hidden_states)
+
+         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
+
+         if self.gradient_checkpointing and self.training:
+             if use_cache:
+                 logger.warning_once(
+                     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                 )
+                 use_cache = False
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         all_cross_attentions = () if output_attentions else None
+         next_decoder_cache = () if use_cache else None
+
+         # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
+         for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
+             if attn_mask is not None:
+                 if attn_mask.size()[0] != len(self.layers):
+                     raise ValueError(
+                         f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
+                         f" {attn_mask.size()[0]}."
+                     )
+         deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
+
+         for idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+             dropout_probability = torch.rand([])
+
+             skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
+             if not skip_the_layer or deepspeed_zero3_is_enabled:
+                 # under deepspeed zero3 all gpus must run in sync
+
+                 past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+                 if self.gradient_checkpointing and self.training:
+
+                     def create_custom_forward(module):
+                         def custom_forward(*inputs):
+                             # None for past_key_value
+                             return module(*inputs, output_attentions, use_cache)
+
+                         return custom_forward
+
+                     layer_outputs = torch.utils.checkpoint.checkpoint(
+                         create_custom_forward(decoder_layer),
+                         hidden_states,
+                         combined_attention_mask,
+                         encoder_hidden_states,
+                         encoder_attention_mask,
+                         head_mask[idx] if head_mask is not None else None,
+                         cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
+                         None,
+                     )
+                 else:
+                     layer_outputs = decoder_layer(
+                         hidden_states,
+                         attention_mask=combined_attention_mask,
+                         encoder_hidden_states=encoder_hidden_states,
+                         encoder_attention_mask=encoder_attention_mask,
+                         layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+                         cross_attn_layer_head_mask=(
+                             cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
+                         ),
+                         past_key_value=past_key_value,
+                         output_attentions=output_attentions,
+                         use_cache=use_cache,
+                     )
+
+                 hidden_states = layer_outputs[0]
+
+             if skip_the_layer:
+                 continue
+
+             if use_cache:
+                 next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+                 all_cross_attentions += (layer_outputs[2],)
+
+         if self.layer_norm is not None:
+             hidden_states = self.layer_norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         next_cache = next_decoder_cache if use_cache else None
+         if not return_dict:
+             return tuple(
+                 v
+                 for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+                 if v is not None
+             )
+         return BaseModelOutputWithPastAndCrossAttentions(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+             cross_attentions=all_cross_attentions,
+         )
+
+
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTrans
+ class IndicTransModel(IndicTransPreTrainedModel):
+     _tied_weights_keys = None
+
+     def __init__(self, config: IndicTransConfig):
+         super().__init__(config)
+
+         self.encoder = IndicTransEncoder(config)
+         self.decoder = IndicTransDecoder(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_encoder(self):
+         return self.encoder
+
+     def get_decoder(self):
+         return self.decoder
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if encoder_outputs is None:
+             encoder_outputs = self.encoder(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+             encoder_outputs = BaseModelOutput(
+                 last_hidden_state=encoder_outputs[0],
+                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+             )
+
+         # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             encoder_hidden_states=encoder_outputs[0],
+             encoder_attention_mask=attention_mask,
+             head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=decoder_inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if not return_dict:
+             return decoder_outputs + encoder_outputs
+
+         return Seq2SeqModelOutput(
+             last_hidden_state=decoder_outputs.last_hidden_state,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+         )
+
+
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
+ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
+     base_model_prefix = "model"
+     _tied_weights_keys = None
+
+     def __init__(self, config: IndicTransConfig):
+         super().__init__(config)
+         self.model = IndicTransModel(config)
+         self.lm_head = nn.Linear(config.decoder_embed_dim, config.decoder_vocab_size, bias=False)
+
+         if config.share_decoder_input_output_embed:
+             self.lm_head.weight = self.model.decoder.embed_tokens.weight
+
+         self.post_init()
+
+     def tie_weights(self):
+         pass
+
+     def get_encoder(self):
+         return self.model.get_encoder()
+
+     def get_decoder(self):
+         return self.model.get_decoder()
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         decoder_head_mask: Optional[torch.Tensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.decoder_vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are
+             ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.decoder_vocab_size]`.
+
+         Returns:
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if labels is not None:
+             if decoder_input_ids is None:
+                 decoder_input_ids = shift_tokens_right(
+                     labels, self.config.pad_token_id, self.config.decoder_start_token_id
+                 )
+
+         outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             decoder_input_ids=decoder_input_ids,
+             encoder_outputs=encoder_outputs,
+             decoder_attention_mask=decoder_attention_mask,
+             head_mask=head_mask,
+             decoder_head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             decoder_inputs_embeds=decoder_inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         lm_logits = self.lm_head(outputs[0])
+
+         masked_lm_loss = None
+         if labels is not None:
+             # move labels to the correct device to enable PP
+             labels = labels.to(lm_logits.device)
+             loss_fct = nn.CrossEntropyLoss()
+             masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (lm_logits,) + outputs[1:]
+             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+         return Seq2SeqLMOutput(
+             loss=masked_lm_loss,
+             logits=lm_logits,
+             past_key_values=outputs.past_key_values,
+             decoder_hidden_states=outputs.decoder_hidden_states,
+             decoder_attentions=outputs.decoder_attentions,
+             cross_attentions=outputs.cross_attentions,
+             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+             encoder_hidden_states=outputs.encoder_hidden_states,
+             encoder_attentions=outputs.encoder_attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self,
+         decoder_input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         head_mask=None,
+         decoder_head_mask=None,
+         cross_attn_head_mask=None,
+         use_cache=None,
+         encoder_outputs=None,
+         **kwargs,
+     ):
+         # cut decoder_input_ids if past is used
+         if past_key_values is not None:
+             decoder_input_ids = decoder_input_ids[:, -1:]
+
+         return {
+             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+             "encoder_outputs": encoder_outputs,
+             "past_key_values": past_key_values,
+             "decoder_input_ids": decoder_input_ids,
+             "attention_mask": attention_mask,
+             "head_mask": head_mask,
+             "decoder_head_mask": decoder_head_mask,
+             "cross_attn_head_mask": cross_attn_head_mask,
+             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+         }
+
+     @staticmethod
+     def _reorder_cache(past_key_values, beam_idx):
+         reordered_past = ()
+         for layer_past in past_key_values:
+             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+         return reordered_past
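
For quick reference, here is a minimal inference sketch that exercises the `IndicTransForConditionalGeneration` class defined above through the standard `transformers` Auto classes. It is illustrative only: the checkpoint id, the tokenizer loading path, and the pre-tagged input format (`<src_lang> <tgt_lang> <text>`) are assumptions; the supported preprocessing and inference pipeline is the one described in the github repository referenced earlier in this model card.

```
# Illustrative sketch, not the official pipeline. Assumptions: the repo id below,
# that the tokenizer loads with trust_remote_code, and that the input sentence
# has already been preprocessed into the tag format the model expects.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ckpt = "ai4bharat/indictrans2-en-indic-dist-200M"  # assumed checkpoint id
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)
# resolves to IndicTransForConditionalGeneration via the modeling file above
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, trust_remote_code=True)

batch = tokenizer(
    ["eng_Latn hin_Deva This is a test sentence."],  # assumed pre-tagged input
    padding="longest",
    return_tensors="pt",
)

with torch.inference_mode():
    generated = model.generate(**batch, num_beams=5, max_length=256, use_cache=True)

print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```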
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:52671a26e4e5f06a86b2f890dea7163db28c01fc1449a3da80d7cc990c41fef0
+ size 1098589001
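
The `pytorch_model.bin` entry above is a Git LFS pointer rather than the weights themselves: `oid` is the SHA-256 of the actual ~1.1 GB weight file and `size` is its length in bytes. A downloaded copy can be sanity-checked against those two fields; below is a small stdlib-only sketch, where the local path is a placeholder for wherever the file was saved.

```
# Verify a downloaded pytorch_model.bin against the Git LFS pointer above.
# Expected digest and size are copied from the pointer; the path is a placeholder.
import hashlib
import os

EXPECTED_SHA256 = "52671a26e4e5f06a86b2f890dea7163db28c01fc1449a3da80d7cc990c41fef0"
EXPECTED_SIZE = 1098589001  # bytes, from the "size" field of the pointer

path = "pytorch_model.bin"  # placeholder local path

digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

assert os.path.getsize(path) == EXPECTED_SIZE, "size does not match the LFS pointer"
assert digest.hexdigest() == EXPECTED_SHA256, "sha256 does not match the LFS pointer"
print("pytorch_model.bin matches the LFS pointer")
```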