jaygala24 committed
Commit 258faca
1 Parent(s): e2012ed

Create configuration_indictrans.py

Files changed (1):
  1. configuration_indictrans.py +307 -0
configuration_indictrans.py ADDED
@@ -0,0 +1,307 @@
+ # coding=utf-8
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ PyTorch IndicTrans config."""
+
+
+ from collections import OrderedDict
+ from typing import Any, Mapping, Optional
+
+ from transformers import PreTrainedTokenizer
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
+ from transformers.onnx.utils import compute_effective_axis_dimension
+ from transformers.utils import TensorType, is_torch_available
+
+
+ # Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
+ class IndicTransConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of an [`IT2Model`]. It is used to instantiate an
+     IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of the IT2 architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         encoder_vocab_size (`int`, *optional*):
+             Vocabulary size of the encoder. Defines the number of different tokens that can be represented by the
+             `input_ids` passed when calling [`IT2Model`].
+         decoder_vocab_size (`int`, *optional*):
+             Vocabulary size of the decoder. Defines the number of different tokens that can be represented by the
+             `decoder_input_ids` passed when calling [`IT2Model`].
+         encoder_embed_dim (`int`, *optional*, defaults to 512):
+             Dimensionality of the encoder layers and embeddings.
+         decoder_embed_dim (`int`, *optional*, defaults to 512):
+             Dimensionality of the decoder layers and embeddings.
+         encoder_layers (`int`, *optional*, defaults to 6):
+             Number of encoder layers.
+         decoder_layers (`int`, *optional*, defaults to 6):
+             Number of decoder layers.
+         encoder_attention_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         decoder_attention_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer decoder.
+         decoder_ffn_dim (`int`, *optional*, defaults to 2048):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
+         encoder_ffn_dim (`int`, *optional*, defaults to 2048):
+             Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
+         activation_function (`str` or `function`, *optional*, defaults to `"relu"`):
+             The non-linear activation function (function or string) in the encoder and decoder. If string, `"gelu"`,
+             `"relu"`, `"silu"` and `"gelu_new"` are supported.
+         dropout (`float`, *optional*, defaults to 0.1):
+             The dropout probability for all fully connected layers in the embeddings, encoder, and decoder.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         activation_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for activations inside the fully connected layer.
+         max_source_positions (`int`, *optional*, defaults to 210):
+             The maximum source sequence length that this model might ever be used with.
+         max_target_positions (`int`, *optional*, defaults to 210):
+             The maximum target sequence length that this model might ever be used with.
+         init_std (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         encoder_layerdrop (`float`, *optional*, defaults to 0.0):
+             The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+             for more details.
+         decoder_layerdrop (`float`, *optional*, defaults to 0.0):
+             The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
+             for more details.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models).
+     """
+     model_type = "IndicTrans"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {
+         "num_attention_heads": "encoder_attention_heads",
+         "hidden_size": "d_model",
+     }
+
+     def __init__(
+         self,
+         encoder_vocab_size=None,
+         decoder_vocab_size=None,
+         encoder_embed_dim=512,
+         decoder_embed_dim=512,
+         max_source_positions=210,
+         max_target_positions=210,
+         encoder_layers=6,
+         encoder_ffn_dim=2048,
+         encoder_attention_heads=8,
+         decoder_layers=6,
+         decoder_ffn_dim=2048,
+         decoder_attention_heads=8,
+         encoder_layerdrop=0.00,
+         decoder_layerdrop=0.00,
+         use_cache=True,
+         is_encoder_decoder=True,
+         activation_function="relu",
+         encoder_normalize_before=False,
+         decoder_normalize_before=False,
+         layernorm_embedding=False,
+         share_decoder_input_output_embed=False,
+         dropout=0.1,
+         attention_dropout=0.0,
+         activation_dropout=0.0,
+         init_std=0.02,
+         scale_embedding=True,
+         decoder_start_token_id=2,
+         pad_token_id=1,
+         bos_token_id=0,
+         eos_token_id=2,
+         **kwargs,
+     ):
+         self.encoder_vocab_size = encoder_vocab_size
+         self.decoder_vocab_size = decoder_vocab_size
+         self.encoder_normalize_before = encoder_normalize_before
+         self.decoder_normalize_before = decoder_normalize_before
+         self.layernorm_embedding = layernorm_embedding
+         self.max_source_positions = max_source_positions
+         self.max_target_positions = max_target_positions
+         self.encoder_embed_dim = encoder_embed_dim
+         self.decoder_embed_dim = decoder_embed_dim
+         self.encoder_ffn_dim = encoder_ffn_dim
+         self.encoder_layers = encoder_layers
+         self.encoder_attention_heads = encoder_attention_heads
+         self.decoder_ffn_dim = decoder_ffn_dim
+         self.decoder_layers = decoder_layers
+         self.decoder_attention_heads = decoder_attention_heads
+         self.dropout = dropout
+         self.attention_dropout = attention_dropout
+         self.activation_dropout = activation_dropout
+         self.activation_function = activation_function
+         self.init_std = init_std
+         self.encoder_layerdrop = encoder_layerdrop
+         self.decoder_layerdrop = decoder_layerdrop
+         self.use_cache = use_cache
+         self.num_hidden_layers = encoder_layers
+         self.scale_embedding = scale_embedding
+         self.share_decoder_input_output_embed = share_decoder_input_output_embed
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             is_encoder_decoder=is_encoder_decoder,
+             decoder_start_token_id=decoder_start_token_id,
+             **kwargs,
+         )
+
+
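For illustration only (not part of this commit): a minimal sketch of instantiating the configuration defined above, with made-up vocabulary sizes, and saving it to disk.

from configuration_indictrans import IndicTransConfig

# Hypothetical vocabulary sizes; real checkpoints ship their own values in config.json.
config = IndicTransConfig(
    encoder_vocab_size=32000,
    decoder_vocab_size=32000,
    encoder_layers=6,
    decoder_layers=6,
)
print(config.num_attention_heads)  # resolved to encoder_attention_heads via attribute_map -> 8
config.save_pretrained("indictrans2-config")  # writes config.json to that directory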
+ class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
+     @property
+     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+         common_inputs = OrderedDict(
+             [
+                 ("input_ids", {0: "batch", 1: "encoder_sequence"}),
+                 ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
+             ]
+         )
+
+         if self.use_past:
+             common_inputs["decoder_input_ids"] = {0: "batch"}
+             common_inputs["decoder_attention_mask"] = {
+                 0: "batch",
+                 1: "past_decoder_sequence + sequence",
+             }
+         else:
+             common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+             common_inputs["decoder_attention_mask"] = {
+                 0: "batch",
+                 1: "decoder_sequence",
+             }
+
+         if self.use_past:
+             self.fill_with_past_key_values_(common_inputs, direction="inputs")
+         return common_inputs
+
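For illustration only (not part of this commit): a sketch of the dynamic-axis mapping the `inputs` property above yields when `use_past` is left at its default of `False`.

from configuration_indictrans import IndicTransConfig, IndicTransOnnxConfig

onnx_config = IndicTransOnnxConfig(IndicTransConfig(), task="seq2seq-lm")
print(onnx_config.inputs)
# Expected mapping (reformatted for readability):
# OrderedDict([
#     ("input_ids", {0: "batch", 1: "encoder_sequence"}),
#     ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
#     ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
#     ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
# ])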
+     # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
+     # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
+     # answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
+     # was done for BART so that it can be updated if need be.
+     def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
+         self,
+         tokenizer: PreTrainedTokenizer,
+         batch_size: int = -1,
+         seq_length: int = -1,
+         is_pair: bool = False,
+         framework: Optional[TensorType] = None,
+     ) -> Mapping[str, Any]:
+         # Copied from OnnxConfig.generate_dummy_inputs
+         # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
+         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+         batch_size = compute_effective_axis_dimension(
+             batch_size,
+             fixed_dimension=OnnxConfig.default_fixed_batch,
+             num_token_to_add=0,
+         )
+
+         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
+         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
+         seq_length = compute_effective_axis_dimension(
+             seq_length,
+             fixed_dimension=OnnxConfig.default_fixed_sequence,
+             num_token_to_add=token_to_add,
+         )
+
+         # Generate dummy inputs according to compute batch and sequence
+         dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
+         common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
+         return common_inputs
+
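For illustration only (not part of this commit): how the fixed fallback dimensions used above behave for dynamic (-1) axes; the two added special tokens are an assumption for the sake of the example.

from transformers.onnx import OnnxConfig
from transformers.onnx.utils import compute_effective_axis_dimension

# Dynamic axes (-1) are swapped for small fixed sizes so ONNX tracing sees concrete shapes.
batch_size = compute_effective_axis_dimension(-1, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0)
seq_length = compute_effective_axis_dimension(-1, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=2)
print(batch_size, seq_length)  # 2 6  (the fixed sequence of 8, minus the 2 assumed special tokens)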
+     # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
+     def _generate_dummy_inputs_for_default_and_seq2seq_lm(
+         self,
+         tokenizer: PreTrainedTokenizer,
+         batch_size: int = -1,
+         seq_length: int = -1,
+         is_pair: bool = False,
+         framework: Optional[TensorType] = None,
+     ) -> Mapping[str, Any]:
+         encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+             tokenizer, batch_size, seq_length, is_pair, framework
+         )
+
+         # Generate decoder inputs
+         decoder_seq_length = seq_length if not self.use_past else 1
+         decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
+             tokenizer, batch_size, decoder_seq_length, is_pair, framework
+         )
+         decoder_inputs = {
+             f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
+         }
+         common_inputs = dict(**encoder_inputs, **decoder_inputs)
+
+         if self.use_past:
+             if not is_torch_available():
+                 raise ValueError(
+                     "Cannot generate dummy past_keys inputs without PyTorch installed."
+                 )
+             else:
+                 import torch
+             batch, encoder_seq_length = common_inputs["input_ids"].shape
+             decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
+             (
+                 num_encoder_attention_heads,
+                 num_decoder_attention_heads,
+             ) = self.num_attention_heads
+             encoder_shape = (
+                 batch,
+                 num_encoder_attention_heads,
+                 encoder_seq_length,
+                 self._config.hidden_size // num_encoder_attention_heads,
+             )
+             decoder_past_length = decoder_seq_length + 3
+             decoder_shape = (
+                 batch,
+                 num_decoder_attention_heads,
+                 decoder_past_length,
+                 self._config.hidden_size // num_decoder_attention_heads,
+             )
+
+             common_inputs["decoder_attention_mask"] = torch.cat(
+                 [
+                     common_inputs["decoder_attention_mask"],
+                     torch.ones(batch, decoder_past_length),
+                 ],
+                 dim=1,
+             )
+
+             common_inputs["past_key_values"] = []
+             # If the number of encoder and decoder layers are present in the model configuration, both are considered
+             num_encoder_layers, num_decoder_layers = self.num_layers
+             min_num_layers = min(num_encoder_layers, num_decoder_layers)
+             max_num_layers = (
+                 max(num_encoder_layers, num_decoder_layers) - min_num_layers
+             )
+             remaining_side_name = (
+                 "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
+             )
+
+             for _ in range(min_num_layers):
+                 common_inputs["past_key_values"].append(
+                     (
+                         torch.zeros(decoder_shape),
+                         torch.zeros(decoder_shape),
+                         torch.zeros(encoder_shape),
+                         torch.zeros(encoder_shape),
+                     )
+                 )
+             # TODO: test this.
+             shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
+             for _ in range(min_num_layers, max_num_layers):
+                 common_inputs["past_key_values"].append(
+                     (torch.zeros(shape), torch.zeros(shape))
+                 )
+         return common_inputs
+
+     generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
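
For illustration only (not part of this commit): a hedged sketch of how a custom configuration like this one is typically exposed through `AutoConfig`; the registration below is an assumption about downstream usage, not something this file performs.

from transformers import AutoConfig
from configuration_indictrans import IndicTransConfig

# Register the config under its model_type so AutoConfig can resolve "IndicTrans" locally.
AutoConfig.register("IndicTrans", IndicTransConfig)
config = AutoConfig.for_model("IndicTrans", encoder_layers=6, decoder_layers=6)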