pavanhitloop committed on
Commit 9d452e1
1 Parent(s): 1521005

codebase added

.gitattributes CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ IndicTransTokenizer/en-indic/model.TGT filter=lfs diff=lfs merge=lfs -text
37
+ *.TGT filter=lfs diff=lfs merge=lfs -text
38
+ *.SRC filter=lfs diff=lfs merge=lfs -text
39
+ *.SRC.json filter=lfs diff=lfs merge=lfs -text
40
+ *.TGT.json filter=lfs diff=lfs merge=lfs -text
41
+ IndicTransTokenizer/en-indic/dict.SRC.json filter=lfs diff=lfs merge=lfs -text
42
+ IndicTransTokenizer/en-indic/dict.TGT.json filter=lfs diff=lfs merge=lfs -text
43
+ IndicTransTokenizer/en-indic/model.SRC filter=lfs diff=lfs merge=lfs -text
44
+ IndicTransTokenizer/indic-en/dict.SRC.json filter=lfs diff=lfs merge=lfs -text
45
+ IndicTransTokenizer/indic-en/dict.TGT.json filter=lfs diff=lfs merge=lfs -text
46
+ IndicTransTokenizer/indic-en/model.SRC filter=lfs diff=lfs merge=lfs -text
47
+ IndicTransTokenizer/indic-en/model.TGT filter=lfs diff=lfs merge=lfs -text
IndicTransTokenizer/__init__.py ADDED
File without changes
IndicTransTokenizer/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (156 Bytes).
IndicTransTokenizer/__pycache__/tokenizer.cpython-39.pyc ADDED
Binary file (10.1 kB).
IndicTransTokenizer/__pycache__/utils.cpython-39.pyc ADDED
Binary file (14.1 kB).
IndicTransTokenizer/en-indic/dict.SRC.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99cabf338bf3db11eafae2769584b8b5d3aa579989feb7e9f72236bdf69810e1
3
+ size 645274
IndicTransTokenizer/en-indic/dict.TGT.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7817850c9e4b99c59fad57d0611c7720f1921f215e6f247cf25d52eff7f1146
3
+ size 3390440
IndicTransTokenizer/en-indic/model.SRC ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cedc5cbcc740369b76201942a0f096fec7287fee039b55bdb956f301235b914
3
+ size 759425
IndicTransTokenizer/en-indic/model.TGT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
3
+ size 3256903
IndicTransTokenizer/indic-en/dict.SRC.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f9cdb988b42c4e0f4fce5e44cc66975e5088a96d111a149b9ac7d55059b8ec1
3
+ size 3391183
IndicTransTokenizer/indic-en/dict.TGT.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13c3a162fe655dbe99c790a413675c5d0634cd771fadcefe8d407676a7d1a311
3
+ size 644755
IndicTransTokenizer/indic-en/model.SRC ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
3
+ size 3256903
IndicTransTokenizer/indic-en/model.TGT ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cedc5cbcc740369b76201942a0f096fec7287fee039b55bdb956f301235b914
3
+ size 759425
IndicTransTokenizer/tokenizer.py ADDED
@@ -0,0 +1,259 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import numpy as np
5
+ from transformers import BatchEncoding
6
+ from typing import Dict, List, Tuple, Union
7
+ from sentencepiece import SentencePieceProcessor
8
+
9
+ _PATH = os.path.dirname(os.path.realpath(__file__))
10
+
11
+
12
+ class IndicTransTokenizer:
13
+ def __init__(
14
+ self,
15
+ src_vocab_fp=None,
16
+ tgt_vocab_fp=None,
17
+ src_spm_fp=None,
18
+ tgt_spm_fp=None,
19
+ unk_token="<unk>",
20
+ bos_token="<s>",
21
+ eos_token="</s>",
22
+ pad_token="<pad>",
23
+ direction="indic-en",
24
+ model_max_length=256,
25
+ ):
26
+ self.model_max_length = model_max_length
27
+
28
+ self.supported_langs = [
29
+ "asm_Beng",
30
+ "ben_Beng",
31
+ "brx_Deva",
32
+ "doi_Deva",
33
+ "eng_Latn",
34
+ "gom_Deva",
35
+ "guj_Gujr",
36
+ "hin_Deva",
37
+ "kan_Knda",
38
+ "kas_Arab",
39
+ "kas_Deva",
40
+ "mai_Deva",
41
+ "mal_Mlym",
42
+ "mar_Deva",
43
+ "mni_Beng",
44
+ "mni_Mtei",
45
+ "npi_Deva",
46
+ "ory_Orya",
47
+ "pan_Guru",
48
+ "san_Deva",
49
+ "sat_Olck",
50
+ "snd_Arab",
51
+ "snd_Deva",
52
+ "tam_Taml",
53
+ "tel_Telu",
54
+ "urd_Arab",
55
+ ]
56
+
57
+ self.src_vocab_fp = (
58
+ src_vocab_fp
59
+ if (src_vocab_fp is not None)
60
+ else os.path.join(_PATH, direction, "dict.SRC.json")
61
+ )
62
+ self.tgt_vocab_fp = (
63
+ tgt_vocab_fp
64
+ if (tgt_vocab_fp is not None)
65
+ else os.path.join(_PATH, direction, "dict.TGT.json")
66
+ )
67
+ self.src_spm_fp = (
68
+ src_spm_fp
69
+ if (src_spm_fp is not None)
70
+ else os.path.join(_PATH, direction, "model.SRC")
71
+ )
72
+ self.tgt_spm_fp = (
73
+ tgt_spm_fp
74
+ if (tgt_spm_fp is not None)
75
+ else os.path.join(_PATH, direction, "model.TGT")
76
+ )
77
+
78
+ self.unk_token = unk_token
79
+ self.pad_token = pad_token
80
+ self.eos_token = eos_token
81
+ self.bos_token = bos_token
82
+
83
+ self.encoder = self._load_json(self.src_vocab_fp)
84
+ if self.unk_token not in self.encoder:
85
+ raise KeyError("<unk> token must be in vocab")
86
+ assert self.pad_token in self.encoder
87
+ self.encoder_rev = {v: k for k, v in self.encoder.items()}
88
+
89
+ self.decoder = self._load_json(self.tgt_vocab_fp)
90
+ if self.unk_token not in self.decoder:
91
+ raise KeyError("<unk> token must be in vocab")
92
+ assert self.pad_token in self.decoder
93
+ self.decoder_rev = {v: k for k, v in self.decoder.items()}
94
+
95
+ # load SentencePiece model for pre-processing
96
+ self.src_spm = self._load_spm(self.src_spm_fp)
97
+ self.tgt_spm = self._load_spm(self.tgt_spm_fp)
98
+
99
+ def is_special_token(self, x: str):
100
+ return (x == self.pad_token) or (x == self.bos_token) or (x == self.eos_token)
101
+
102
+ def get_vocab_size(self, src: bool) -> int:
103
+ """Returns the size of the vocabulary"""
104
+ return len(self.encoder) if src else len(self.decoder)
105
+
106
+ def _load_spm(self, path: str) -> SentencePieceProcessor:
107
+ return SentencePieceProcessor(model_file=path)
108
+
109
+ def _save_json(self, data, path: str) -> None:
110
+ with open(path, "w", encoding="utf-8") as f:
111
+ json.dump(data, f, indent=2)
112
+
113
+ def _load_json(self, path: str) -> Union[Dict, List]:
114
+ with open(path, "r", encoding="utf-8") as f:
115
+ return json.load(f)
116
+
117
+ def _convert_token_to_id(self, token: str, src: bool) -> int:
118
+ """Converts a token (str) into an index (integer) using the source/target vocabulary map."""
119
+ return (
120
+ self.encoder.get(token, self.encoder[self.unk_token])
121
+ if src
122
+ else self.decoder.get(token, self.decoder[self.unk_token])
123
+ )
124
+
125
+ def _convert_id_to_token(self, index: int, src: bool) -> str:
126
+ """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
127
+ return (
128
+ self.encoder_rev.get(index, self.unk_token)
129
+ if src
130
+ else self.decoder_rev.get(index, self.unk_token)
131
+ )
132
+
133
+ def _convert_tokens_to_string(self, tokens: List[str], src: bool) -> str:
134
+ """Uses sentencepiece model for detokenization"""
135
+ if src:
136
+ if tokens[0] in self.supported_langs and tokens[1] in self.supported_langs:
137
+ tokens = tokens[2:]
138
+ return " ".join(tokens)
139
+ else:
140
+ return " ".join(tokens)
141
+
142
+ def _remove_translation_tags(self, text: str) -> Tuple[List, str]:
143
+ """Removes the translation tags before text normalization and tokenization."""
144
+ tokens = text.split(" ")
145
+ return tokens[:2], " ".join(tokens[2:])
146
+
147
+ def _tokenize_src_line(self, line: str) -> List[str]:
148
+ """Tokenizes a source line."""
149
+ tags, text = self._remove_translation_tags(line)
150
+ tokens = self.src_spm.encode(text, out_type=str)
151
+ return tags + tokens
152
+
153
+ def _tokenize_tgt_line(self, line: str) -> List[str]:
154
+ """Tokenizes a target line."""
155
+ return self.tgt_spm.encode(line, out_type=str)
156
+
157
+ def tokenize(self, text: str, src: bool) -> List[str]:
158
+ """Tokenizes a string into tokens using the source/target vocabulary."""
159
+ return self._tokenize_src_line(text) if src else self._tokenize_tgt_line(text)
160
+
161
+ def batch_tokenize(self, batch: List[str], src: bool) -> List[List[str]]:
162
+ """Tokenizes a list of strings into tokens using the source/target vocabulary."""
163
+ return [self.tokenize(line, src) for line in batch]
164
+
165
+ def _create_attention_mask(self, ids: List[int], max_seq_len: int) -> List[int]:
166
+ """Creates an attention mask for the input sequence."""
167
+ return ([0] * (max_seq_len - len(ids))) + ([1] * (len(ids) + 1))
168
+
169
+ def _pad_batch(self, tokens: List[str], max_seq_len: int) -> List[str]:
170
+ """Pads a batch of tokens and adds BOS/EOS tokens."""
171
+ return (
172
+ ([self.pad_token] * (max_seq_len - len(tokens))) + tokens + [self.eos_token]
173
+ )
174
+
175
+ def _decode_line(self, ids: List[int], src: bool) -> List[str]:
176
+ return [self._convert_id_to_token(_id, src) for _id in ids]
177
+
178
+ def _encode_line(self, tokens: List[str], src: bool) -> List[int]:
179
+ return [self._convert_token_to_id(token, src) for token in tokens]
180
+
181
+ def _strip_special_tokens(self, tokens: List[str]) -> List[str]:
182
+ return [token for token in tokens if not self.is_special_token(token)]
183
+
184
+ def _single_input_preprocessing(
185
+ self, tokens: List[str], src: bool, max_seq_len: int
186
+ ) -> Tuple[List[int], List[int]]:
187
+ """Pads a list of tokens and converts them into input ids and an attention mask using the source/target vocabulary map."""
188
+ attention_mask = self._create_attention_mask(tokens, max_seq_len)
189
+ padded_tokens = self._pad_batch(tokens, max_seq_len)
190
+ input_ids = self._encode_line(padded_tokens, src)
191
+ return input_ids, attention_mask
192
+
193
+ def _single_output_postprocessing(self, ids: List[int], src: bool) -> str:
194
+ """Detokenizes a list of integer ids into a string using the source/target vocabulary."""
195
+ tokens = self._decode_line(ids, src)
196
+ tokens = self._strip_special_tokens(tokens)
197
+ return self._convert_tokens_to_string(tokens, src)
198
+
199
+ def __call__(
200
+ self,
201
+ batch: Union[list, str],
202
+ src: bool,
203
+ truncation: bool = False,
204
+ padding: str = "longest",
205
+ max_length: int = None,
206
+ return_tensors: str = "pt",
207
+ return_attention_mask: bool = True,
208
+ return_length: bool = False,
209
+ ) -> BatchEncoding:
210
+ """Tokenizes a batch of strings and converts them into input ids using the source/target vocabulary map."""
211
+ assert padding in [
212
+ "longest",
213
+ "max_length",
214
+ ], "padding should be either 'longest' or 'max_length'"
215
+
216
+ if not isinstance(batch, list):
217
+ raise TypeError(
218
+ f"batch must be a list, but current batch is of type {type(batch)}"
219
+ )
220
+
221
+ # tokenize the source sentences
222
+ batch = self.batch_tokenize(batch, src)
223
+
224
+ # truncate the sentences if needed
225
+ if truncation and max_length is not None:
226
+ batch = [ids[:max_length] for ids in batch]
227
+
228
+ lengths = [len(ids) for ids in batch]
229
+
230
+ max_seq_len = max(lengths) if padding == "longest" else max_length
231
+
232
+ input_ids, attention_mask = zip(
233
+ *[
234
+ self._single_input_preprocessing(
235
+ tokens=tokens, src=src, max_seq_len=max_seq_len
236
+ )
237
+ for tokens in batch
238
+ ]
239
+ )
240
+
241
+ _data = {"input_ids": input_ids}
242
+
243
+ if return_attention_mask:
244
+ _data["attention_mask"] = attention_mask
245
+
246
+ if return_length:
247
+ _data["lengths"] = lengths
248
+
249
+ return BatchEncoding(_data, tensor_type=return_tensors)
250
+
251
+ def batch_decode(
252
+ self, batch: Union[list, torch.Tensor], src: bool
253
+ ) -> List[str]:
254
+ """Detokenizes a list of integer ids or a tensor into a list of strings using the source/target vocabulary."""
255
+
256
+ if isinstance(batch, torch.Tensor):
257
+ batch = batch.detach().cpu().tolist()
258
+
259
+ return [self._single_output_postprocessing(ids=ids, src=src) for ids in batch]
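For orientation, here is a minimal usage sketch of the class above (not part of the commit). It assumes the LFS vocabulary/SentencePiece assets under `IndicTransTokenizer/en-indic/` have been fetched and that source sentences already carry the `src_lang tgt_lang` tags the source-side methods expect (the helpers in `utils.py` below produce them):

```python
# Minimal sketch: drive IndicTransTokenizer directly on a tagged source sentence.
from IndicTransTokenizer.tokenizer import IndicTransTokenizer

tokenizer = IndicTransTokenizer(direction="en-indic")

# Source inputs are expected to start with "<src_lang> <tgt_lang>" tags.
batch = ["eng_Latn hin_Deva When I was young, I used to go to the park every day."]

enc = tokenizer(batch, src=True, return_tensors="pt")
print(enc["input_ids"].shape)             # (batch, longest_seq + 1); EOS appended, padding on the left
print(tokenizer.get_vocab_size(src=True))

# Round-trip through the source-side decode path: special tokens and lang tags are
# stripped and space-joined SentencePiece pieces come back.
print(tokenizer.batch_decode(enc["input_ids"], src=True))
```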
IndicTransTokenizer/utils.py ADDED
@@ -0,0 +1,591 @@
1
+ import regex as re
2
+ from joblib import Parallel, delayed
3
+ from nltk.tokenize import sent_tokenize
4
+ from typing import List, Tuple, Union
5
+
6
+ from sacremoses import MosesPunctNormalizer
7
+ from indicnlp.normalize import indic_normalize
8
+ from sacremoses import MosesTokenizer, MosesDetokenizer
9
+ from indicnlp.transliterate import unicode_transliterate
10
+ from indicnlp.tokenize import indic_tokenize, indic_detokenize
11
+ from indicnlp.tokenize.sentence_tokenize import sentence_split, DELIM_PAT_NO_DANDA
12
+
13
+ en_tok = MosesTokenizer(lang="en")
14
+ en_normalizer = MosesPunctNormalizer()
15
+ en_detok = MosesDetokenizer(lang="en")
16
+ xliterator = unicode_transliterate.UnicodeIndicTransliterator()
17
+
18
+
19
+ flores_codes = {
20
+ "asm_Beng": "as",
21
+ "awa_Deva": "hi",
22
+ "ben_Beng": "bn",
23
+ "bho_Deva": "hi",
24
+ "brx_Deva": "hi",
25
+ "doi_Deva": "hi",
26
+ "eng_Latn": "en",
27
+ "gom_Deva": "kK",
28
+ "guj_Gujr": "gu",
29
+ "hin_Deva": "hi",
30
+ "hne_Deva": "hi",
31
+ "kan_Knda": "kn",
32
+ "kas_Arab": "ur",
33
+ "kas_Deva": "hi",
34
+ "kha_Latn": "en",
35
+ "lus_Latn": "en",
36
+ "mag_Deva": "hi",
37
+ "mai_Deva": "hi",
38
+ "mal_Mlym": "ml",
39
+ "mar_Deva": "mr",
40
+ "mni_Beng": "bn",
41
+ "mni_Mtei": "hi",
42
+ "npi_Deva": "ne",
43
+ "ory_Orya": "or",
44
+ "pan_Guru": "pa",
45
+ "san_Deva": "hi",
46
+ "sat_Olck": "or",
47
+ "snd_Arab": "ur",
48
+ "snd_Deva": "hi",
49
+ "tam_Taml": "ta",
50
+ "tel_Telu": "te",
51
+ "urd_Arab": "ur",
52
+ }
53
+
54
+
55
+ flores_to_iso = {
56
+ "asm_Beng": "as",
57
+ "awa_Deva": "awa",
58
+ "ben_Beng": "bn",
59
+ "bho_Deva": "bho",
60
+ "brx_Deva": "brx",
61
+ "doi_Deva": "doi",
62
+ "eng_Latn": "en",
63
+ "gom_Deva": "gom",
64
+ "guj_Gujr": "gu",
65
+ "hin_Deva": "hi",
66
+ "hne_Deva": "hne",
67
+ "kan_Knda": "kn",
68
+ "kas_Arab": "ksa",
69
+ "kas_Deva": "ksd",
70
+ "kha_Latn": "kha",
71
+ "lus_Latn": "lus",
72
+ "mag_Deva": "mag",
73
+ "mai_Deva": "mai",
74
+ "mal_Mlym": "ml",
75
+ "mar_Deva": "mr",
76
+ "mni_Beng": "mnib",
77
+ "mni_Mtei": "mnim",
78
+ "npi_Deva": "ne",
79
+ "ory_Orya": "or",
80
+ "pan_Guru": "pa",
81
+ "san_Deva": "sa",
82
+ "sat_Olck": "sat",
83
+ "snd_Arab": "sda",
84
+ "snd_Deva": "sdd",
85
+ "tam_Taml": "ta",
86
+ "tel_Telu": "te",
87
+ "urd_Arab": "ur",
88
+ }
89
+
90
+
91
+ INDIC_NUM_MAP = {
92
+ "\u09e6": "0",
93
+ "0": "0",
94
+ "\u0ae6": "0",
95
+ "\u0ce6": "0",
96
+ "\u0966": "0",
97
+ "\u0660": "0",
98
+ "\uabf0": "0",
99
+ "\u0b66": "0",
100
+ "\u0a66": "0",
101
+ "\u1c50": "0",
102
+ "\u06f0": "0",
103
+ "\u09e7": "1",
104
+ "1": "1",
105
+ "\u0ae7": "1",
106
+ "\u0967": "1",
107
+ "\u0ce7": "1",
108
+ "\u06f1": "1",
109
+ "\uabf1": "1",
110
+ "\u0b67": "1",
111
+ "\u0a67": "1",
112
+ "\u1c51": "1",
113
+ "\u0c67": "1",
114
+ "\u09e8": "2",
115
+ "2": "2",
116
+ "\u0ae8": "2",
117
+ "\u0968": "2",
118
+ "\u0ce8": "2",
119
+ "\u06f2": "2",
120
+ "\uabf2": "2",
121
+ "\u0b68": "2",
122
+ "\u0a68": "2",
123
+ "\u1c52": "2",
124
+ "\u0c68": "2",
125
+ "\u09e9": "3",
126
+ "3": "3",
127
+ "\u0ae9": "3",
128
+ "\u0969": "3",
129
+ "\u0ce9": "3",
130
+ "\u06f3": "3",
131
+ "\uabf3": "3",
132
+ "\u0b69": "3",
133
+ "\u0a69": "3",
134
+ "\u1c53": "3",
135
+ "\u0c69": "3",
136
+ "\u09ea": "4",
137
+ "4": "4",
138
+ "\u0aea": "4",
139
+ "\u096a": "4",
140
+ "\u0cea": "4",
141
+ "\u06f4": "4",
142
+ "\uabf4": "4",
143
+ "\u0b6a": "4",
144
+ "\u0a6a": "4",
145
+ "\u1c54": "4",
146
+ "\u0c6a": "4",
147
+ "\u09eb": "5",
148
+ "5": "5",
149
+ "\u0aeb": "5",
150
+ "\u096b": "5",
151
+ "\u0ceb": "5",
152
+ "\u06f5": "5",
153
+ "\uabf5": "5",
154
+ "\u0b6b": "5",
155
+ "\u0a6b": "5",
156
+ "\u1c55": "5",
157
+ "\u0c6b": "5",
158
+ "\u09ec": "6",
159
+ "6": "6",
160
+ "\u0aec": "6",
161
+ "\u096c": "6",
162
+ "\u0cec": "6",
163
+ "\u06f6": "6",
164
+ "\uabf6": "6",
165
+ "\u0b6c": "6",
166
+ "\u0a6c": "6",
167
+ "\u1c56": "6",
168
+ "\u0c6c": "6",
169
+ "\u09ed": "7",
170
+ "7": "7",
171
+ "\u0aed": "7",
172
+ "\u096d": "7",
173
+ "\u0ced": "7",
174
+ "\u06f7": "7",
175
+ "\uabf7": "7",
176
+ "\u0b6d": "7",
177
+ "\u0a6d": "7",
178
+ "\u1c57": "7",
179
+ "\u0c6d": "7",
180
+ "\u09ee": "8",
181
+ "8": "8",
182
+ "\u0aee": "8",
183
+ "\u096e": "8",
184
+ "\u0cee": "8",
185
+ "\u06f8": "8",
186
+ "\uabf8": "8",
187
+ "\u0b6e": "8",
188
+ "\u0a6e": "8",
189
+ "\u1c58": "8",
190
+ "\u0c6e": "8",
191
+ "\u09ef": "9",
192
+ "9": "9",
193
+ "\u0aef": "9",
194
+ "\u096f": "9",
195
+ "\u0cef": "9",
196
+ "\u06f9": "9",
197
+ "\uabf9": "9",
198
+ "\u0b6f": "9",
199
+ "\u0a6f": "9",
200
+ "\u1c59": "9",
201
+ "\u0c6f": "9",
202
+ }
203
+
204
+
205
+ multispace_regex = re.compile("[ ]{2,}")
206
+ end_bracket_space_punc_regex = re.compile(r"\) ([\.!:?;,])")
207
+ digit_space_percent = re.compile(r"(\d) %")
208
+ double_quot_punc = re.compile(r"\"([,\.]+)")
209
+ digit_nbsp_digit = re.compile(r"(\d) (\d)")
210
+
211
+
212
+ def punc_norm(text, lang="en"):
213
+ text = (
214
+ text.replace("\r", "")
215
+ .replace("(", " (")
216
+ .replace(")", ") ")
217
+ .replace("( ", "(")
218
+ .replace(" )", ")")
219
+ .replace(" :", ":")
220
+ .replace(" ;", ";")
221
+ .replace("`", "'")
222
+ .replace("„", '"')
223
+ .replace("“", '"')
224
+ .replace("”", '"')
225
+ .replace("–", "-")
226
+ .replace("—", " - ")
227
+ .replace("´", "'")
228
+ .replace("‘", "'")
229
+ .replace("‚", "'")
230
+ .replace("’", "'")
231
+ .replace("''", '"')
232
+ .replace("´´", '"')
233
+ .replace("…", "...")
234
+ .replace(" « ", ' "')
235
+ .replace("« ", '"')
236
+ .replace("«", '"')
237
+ .replace(" » ", '" ')
238
+ .replace(" »", '"')
239
+ .replace("»", '"')
240
+ .replace(" %", "%")
241
+ .replace("nº ", "nº ")
242
+ .replace(" :", ":")
243
+ .replace(" ºC", " ºC")
244
+ .replace(" cm", " cm")
245
+ .replace(" ?", "?")
246
+ .replace(" !", "!")
247
+ .replace(" ;", ";")
248
+ .replace(", ", ", ")
249
+ )
250
+
251
+ text = multispace_regex.sub(" ", text)
252
+ text = end_bracket_space_punc_regex.sub(r")\1", text)
253
+ text = digit_space_percent.sub(r"\1%", text)
254
+ text = double_quot_punc.sub(
255
+ r'\1"', text
256
+ ) # English "quotation," followed by comma, style
257
+ text = digit_nbsp_digit.sub(r"\1.\2", text) # What does it mean?
258
+ return text.strip(" ")
259
+
260
+
261
+ URL_PATTERN = r"\b(?<![\w/.])(?:(?:https?|ftp)://)?(?:(?:[\w-]+\.)+(?!\.))(?:[\w/\-?#&=%.]+)+(?!\.\w+)\b"
262
+ EMAIL_PATTERN = r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"
263
+ # handles dates, time, percentages, proportion, ratio, etc
264
+ NUMERAL_PATTERN = r"(~?\d+\.?\d*\s?%?\s?-?\s?~?\d+\.?\d*\s?%|~?\d+%|\d+[-\/.,:']\d+[-\/.,:'+]\d+(?:\.\d+)?|\d+[-\/.:'+]\d+(?:\.\d+)?)"
265
+ # handles upi, social media handles and hashtags
266
+ OTHER_PATTERN = r"[A-Za-z0-9]*[#|@]\w+"
267
+
268
+
269
+ def normalize_indic_numerals(line: str):
270
+ """
271
+ Normalize the numerals in Indic languages from native script to Roman script (if present).
272
+
273
+ Args:
274
+ line (str): an input string with Indic numerals to be normalized.
275
+
276
+ Returns:
277
+ str: an input string with the all Indic numerals normalized to Roman script.
278
+ """
279
+ return "".join([INDIC_NUM_MAP.get(c, c) for c in line])
280
+
281
+
282
+ def wrap_with_placeholders(text: str, patterns: list) -> Tuple[str, dict]:
283
+ """
284
+ Wraps substrings with matched patterns in the given text with placeholders and returns
285
+ the modified text along with a mapping of the placeholders to their original value.
286
+
287
+ Args:
288
+ text (str): an input string which needs to be wrapped with the placeholders.
289
+ pattern (list): list of patterns to search for in the input string.
290
+
291
+ Returns:
292
+ Tuple[str, dict]: a tuple containing the modified text and a dictionary mapping
293
+ placeholders to their original values.
294
+ """
295
+ serial_no = 1
296
+
297
+ placeholder_entity_map = dict()
298
+
299
+ for pattern in patterns:
300
+ matches = set(re.findall(pattern, text))
301
+
302
+ # wrap common match with placeholder tags
303
+ for match in matches:
304
+ if pattern == URL_PATTERN:
305
+ # Avoids false positive URL matches for names with initials.
306
+ temp = match.replace(".", "")
307
+ if len(temp) < 4:
308
+ continue
309
+ if pattern == NUMERAL_PATTERN:
310
+ # Short numeral patterns do not need placeholder based handling.
311
+ temp = match.replace(" ", "").replace(".", "").replace(":", "")
312
+ if len(temp) < 4:
313
+ continue
314
+
315
+ # A set of translations of "ID" in all the supported languages has been collated.
316
+ # This has been added to deal with edge cases where placeholders might get translated.
317
+ indic_failure_cases = [
318
+ "آی ڈی ",
319
+ "ꯑꯥꯏꯗꯤ",
320
+ "आईडी",
321
+ "आई . डी . ",
322
+ "आई . डी .",
323
+ "आई. डी. ",
324
+ "आई. डी.",
325
+ "ऐटि",
326
+ "آئی ڈی ",
327
+ "ᱟᱭᱰᱤ ᱾",
328
+ "आयडी",
329
+ "ऐडि",
330
+ "आइडि",
331
+ ]
332
+ placeholder = "<ID{}>".format(serial_no)
333
+ alternate_placeholder = "< ID{} >".format(serial_no)
334
+ placeholder_entity_map[placeholder] = match
335
+ placeholder_entity_map[alternate_placeholder] = match
336
+ placeholder = "<ID{}]".format(serial_no)
337
+ alternate_placeholder = "< ID{} ]".format(serial_no)
338
+ placeholder_entity_map[placeholder] = match
339
+ placeholder_entity_map[alternate_placeholder] = match
340
+
341
+ for i in indic_failure_cases:
342
+ placeholder_temp = "<{}{}>".format(i, serial_no)
343
+ placeholder_entity_map[placeholder_temp] = match
344
+ placeholder_temp = "< {}{} >".format(i, serial_no)
345
+ placeholder_entity_map[placeholder_temp] = match
346
+ placeholder_temp = "< {} {} >".format(i, serial_no)
347
+ placeholder_entity_map[placeholder_temp] = match
348
+ placeholder_temp = "<{} {}]".format(i, serial_no)
349
+ placeholder_entity_map[placeholder_temp] = match
350
+ placeholder_temp = "< {} {} ]".format(i, serial_no)
351
+ placeholder_entity_map[placeholder_temp] = match
352
+ placeholder_temp = "[{} {}]".format(i, serial_no)
353
+ placeholder_entity_map[placeholder_temp] = match
354
+ placeholder_temp = "[ {} {} ]".format(i, serial_no)
355
+ placeholder_entity_map[placeholder_temp] = match
356
+
357
+ text = text.replace(match, placeholder)
358
+ serial_no += 1
359
+
360
+ text = re.sub(r"\s+", " ", text)
361
+
362
+ # Regex has failure cases in trailing "/" in URLs, so this is a workaround.
363
+ text = text.replace(">/", ">")
364
+ text = text.replace("]/", "]")
365
+
366
+ return text, placeholder_entity_map
367
+
368
+
369
+ def normalize(
370
+ text: str,
371
+ patterns: list = [EMAIL_PATTERN, URL_PATTERN, NUMERAL_PATTERN, OTHER_PATTERN],
372
+ ) -> Tuple[str, dict]:
373
+ """
374
+ Normalizes and wraps the spans of input string with placeholder tags. It first normalizes
375
+ the Indic numerals in the input string to Roman script. Later, it uses the input string with normalized
376
+ Indic numerals to wrap the spans of text matching the pattern with placeholder tags.
377
+
378
+ Args:
379
+ text (str): input string.
380
+ pattern (list): list of patterns to search for in the input string.
381
+
382
+ Returns:
383
+ Tuple[str, dict]: a tuple containing the modified text and a dictionary mapping
384
+ placeholders to their original values.
385
+ """
386
+ text = normalize_indic_numerals(text.strip("\n"))
387
+ text, placeholder_entity_map = wrap_with_placeholders(text, patterns)
388
+ return text, placeholder_entity_map
389
+
390
+
391
+ def split_sentences(paragraph: str, lang: str) -> List[str]:
392
+ """
393
+ Splits the input text paragraph into sentences. It uses `moses` for English and
394
+ `indic-nlp` for Indic languages.
395
+
396
+ Args:
397
+ paragraph (str): input text paragraph.
398
+ lang (str): flores language code.
399
+
400
+ Returns:
401
+ List[str] -> list of sentences.
402
+ """
403
+ # fails to handle sentence splitting in case of
404
+ # with MosesSentenceSplitter(lang) as splitter:
405
+ # return splitter([paragraph])
406
+ return (
407
+ sent_tokenize(paragraph)
408
+ if lang == "eng_Latn"
409
+ else sentence_split(
410
+ paragraph, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA
411
+ )
412
+ )
413
+
414
+
415
+ def apply_lang_tags(sents: List[str], src_lang: str, tgt_lang: str) -> List[str]:
416
+ """
417
+ Add special tokens indicating source and target language to the start of each input sentence.
418
+ Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".
419
+
420
+ Args:
421
+ sents (List[str]): list of input sentences to be translated.
422
+ src_lang (str): flores lang code of the input sentence.
423
+ tgt_lang (str): flores lang code in which the input sentence will be translated.
424
+
425
+ Returns:
426
+ List[str]: list of input sentences with the special tokens added to the start.
427
+ """
428
+ return Parallel(n_jobs=-1)(
429
+ delayed(lambda x: f"{src_lang} {tgt_lang} {x.strip()}")(sent) for sent in sents
430
+ )
431
+
432
+
433
+ def preprocess_sent(
434
+ sent: str,
435
+ normalizer: Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory],
436
+ lang: str,
437
+ ) -> Tuple[str, dict]:
438
+ """
439
+ Preprocess an input text sentence by normalizing, tokenizing, and possibly transliterating it.
440
+
441
+ Args:
442
+ sent (str): input text sentence to preprocess.
443
+ normalizer (Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]): an object that performs normalization on the text.
444
+ lang (str): flores language code of the input text sentence.
445
+
446
+ Returns:
447
+ Tuple[str, dict]: a tuple of preprocessed input text sentence and also a corresponding dictionary
448
+ mapping placeholders to their original values.
449
+ """
450
+ iso_lang = flores_codes[lang]
451
+ sent = punc_norm(sent, iso_lang)
452
+ sent, placeholder_entity_map = normalize(sent)
453
+
454
+ transliterate = True
455
+ if lang.split("_")[1] in ["Arab", "Aran", "Olck", "Mtei", "Latn"]:
456
+ transliterate = False
457
+
458
+ if iso_lang == "en":
459
+ processed_sent = " ".join(
460
+ en_tok.tokenize(en_normalizer.normalize(sent.strip()), escape=False)
461
+ )
462
+ elif transliterate:
463
+ # transliterates from any specific language to Devanagari,
464
+ # which is why we specify lang2_code as "hi".
465
+ processed_sent = xliterator.transliterate(
466
+ " ".join(
467
+ indic_tokenize.trivial_tokenize(
468
+ normalizer.normalize(sent.strip()), iso_lang
469
+ )
470
+ ),
471
+ iso_lang,
472
+ "hi",
473
+ ).replace(" ् ", "्")
474
+ else:
475
+ # we only need to transliterate for joint training
476
+ processed_sent = " ".join(
477
+ indic_tokenize.trivial_tokenize(
478
+ normalizer.normalize(sent.strip()), iso_lang
479
+ )
480
+ )
481
+
482
+ return processed_sent, placeholder_entity_map
483
+
484
+
485
+ def preprocess(sents: List[str], lang: str):
486
+ """
487
+ Preprocess an array of sentences by normalizing, tokenizing, and possibly transliterating them.
488
+
489
+ Args:
490
+ batch (List[str]): input list of sentences to preprocess.
491
+ lang (str): flores language code of the input text sentences.
492
+
493
+ Returns:
494
+ Tuple[List[str], List[dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
495
+ mapping placeholders to their original values.
496
+ """
497
+
498
+ normalizer = (
499
+ indic_normalize.IndicNormalizerFactory().get_normalizer(flores_codes[lang])
500
+ if lang != "eng_Latn"
501
+ else None
502
+ )
503
+
504
+ processed_sents, placeholder_entity_map_sents = zip(
505
+ *[preprocess_sent(sent, normalizer, lang) for sent in sents]
506
+ )
507
+
508
+ return processed_sents, placeholder_entity_map_sents
509
+
510
+
511
+ def preprocess_batch(batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
512
+ """
513
+ Preprocess an array of sentences by normalizing, tokenizing, and possibly transliterating them. It also adds
514
+ source and target language tags to the normalized text sequences.
515
+
516
+ Args:
517
+ batch (List[str]): input list of sentences to preprocess.
518
+ src_lang (str): flores language code of the input text sentences.
519
+ tgt_lang (str): flores language code of the output text sentences.
520
+
521
+ Returns:
522
+ Tuple[List[str], List[dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
523
+ mapping placeholders to their original values.
524
+ """
525
+ preprocessed_sents, placeholder_entity_map_sents = preprocess(batch, lang=src_lang)
526
+ tagged_sents = apply_lang_tags(preprocessed_sents, src_lang, tgt_lang)
527
+ return tagged_sents, placeholder_entity_map_sents
528
+
529
+
530
+ def postprocess_batch(
531
+ sents: List[str],
532
+ placeholder_entity_map: List[dict],
533
+ lang: str,
534
+ common_lang: str = "hin_Deva",
535
+ ) -> List[str]:
536
+ """
537
+ Postprocesses a batch of input sentences after the translation generations.
538
+
539
+ Args:
540
+ sents (List[str]): batch of translated sentences to postprocess.
541
+ placeholder_entity_map (List[dict]): dictionary mapping placeholders to the original entity values.
542
+ lang (str): flores language code of the input sentences.
543
+ common_lang (str, optional): flores language code of the transliterated language (defaults: hin_Deva).
544
+
545
+ Returns:
546
+ List[str]: postprocessed batch of input sentences.
547
+ """
548
+
549
+ lang_code, script_code = lang.split("_")
550
+
551
+ for i in range(len(sents)):
552
+ sents[i] = sents[i].replace(" ", "").replace("▁", " ").strip()
553
+
554
+ # Fixes for Perso-Arabic scripts
555
+ # TODO: Move these normalizations inside indic-nlp-library
556
+ if script_code in {"Arab", "Aran"}:
557
+ # UrduHack adds space before punctuations. Since the model was trained without fixing this issue, let's fix it now
558
+ sents[i] = sents[i].replace(" ؟", "؟").replace(" ۔", "۔").replace(" ،", "،")
559
+ # Kashmiri bugfix for palatalization: https://github.com/AI4Bharat/IndicTrans2/issues/11
560
+ sents[i] = sents[i].replace("ٮ۪", "ؠ")
561
+
562
+ # Oriya bug: indic-nlp-library produces ଯ଼ instead of ୟ when converting from Devanagari to Odia
563
+ # TODO: Find out what's the issue with unicode transliterator for Oriya and fix it
564
+ if lang_code == "or":
565
+ sents[i] = sents[i].replace("ଯ଼", "ୟ")
566
+
567
+ assert len(sents) == len(placeholder_entity_map)
568
+
569
+ # Replace the placeholders entity
570
+ for i in range(0, len(sents)):
571
+ for key in placeholder_entity_map[i].keys():
572
+ sents[i] = sents[i].replace(key, placeholder_entity_map[i][key])
573
+
574
+ # Detokenize and transliterate to native scripts if applicable
575
+
576
+ if lang == "eng_Latn":
577
+ postprocessed_sents = [en_detok.detokenize(sent.split(" ")) for sent in sents]
578
+ else:
579
+ postprocessed_sents = [
580
+ indic_detokenize.trivial_detokenize(
581
+ xliterator.transliterate(
582
+ s, flores_codes[common_lang], flores_codes[lang]
583
+ ),
584
+ flores_codes[lang],
585
+ )
586
+ for s in sents
587
+ ]
588
+
589
+ assert len(postprocessed_sents) == len(placeholder_entity_map)
590
+
591
+ return postprocessed_sents
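To make the placeholder mechanism above concrete, a small illustrative sketch (not part of the commit; it assumes the dependencies installed by `install.sh`). The exact placeholder strings depend on the patterns defined in this file:

```python
# Illustrative only: normalize() romanizes Indic numerals and wraps emails/URLs
# in <IDn> placeholders so they survive translation unchanged.
from IndicTransTokenizer.utils import normalize

text = "मेरा ईमेल user@example.com है और १२३४ लोग https://ai4bharat.org पर आए।"
wrapped, placeholder_map = normalize(text)

print(wrapped)          # numerals appear as 1234; email/URL replaced by <ID1>, <ID2>, ...
print(placeholder_map)  # maps each placeholder variant (e.g. "<ID1>", "< ID1 >") back to its original span

# postprocess_batch() later substitutes the originals back, essentially:
restored = wrapped
for placeholder, original in placeholder_map.items():
    restored = restored.replace(placeholder, original)
```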
README.md ADDED
@@ -0,0 +1,62 @@
1
+ # IndicTrans2 HF Compatible Models
2
+
3
+ In this section, we provide details on how to convert our [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2) models, which were originally trained with [fairseq](https://github.com/facebookresearch/fairseq), to [HuggingFace transformers](https://huggingface.co/docs/transformers/index) for inference purposes. Our scripts for HuggingFace compatible models are adapted from the [M2M100 repository](https://github.com/huggingface/transformers/tree/main/src/transformers/models/m2m_100).
4
+
5
+
6
+ ### Setup
7
+
8
+ To get started, follow these steps to set up the environment:
9
+
10
+ ```
11
+ # Clone the github repository and navigate to the project directory.
12
+ git clone https://github.com/AI4Bharat/IndicTrans2
13
+ cd IndicTrans2
14
+
15
+ # Install all the dependencies and requirements associated with the project for running HF compatible models.
16
+ source install.sh
17
+ ```
18
+
19
+ > Note: The `install.sh` script in this directory is specifically for running HF compatible models for inference.
20
+
21
+
22
+ ### Converting
23
+
24
+ In order to convert the fairseq checkpoint to a PyTorch checkpoint that is compatible with HuggingFace Transformers, use the following command:
25
+
26
+ ```bash
27
+ python3 convert_indictrans_checkpoint_to_pytorch.py --fairseq_path <fairseq_checkpoint_best.pt> --pytorch_dump_folder_path <hf_output_dir>
28
+ ```
29
+ - `<fairseq_checkpoint_best.pt>`: path to the fairseq `checkpoint_best.pt` that needs to be converted to HF compatible models
30
+ - `<hf_output_dir>`: path to the output directory where the HF compatible models will be saved
31
+
32
+
33
+ ### Models
34
+
35
+ | Model | 🤗 HuggingFace Checkpoints |
36
+ |----------|-----------------------------------|
37
+ | Preprint En-Indic | [ai4bharat/indictrans2-en-indic-1B](https://huggingface.co/ai4bharat/indictrans2-en-indic-1B) |
38
+ | Preprint Indic-En | [ai4bharat/indictrans2-indic-en-1B](https://huggingface.co/ai4bharat/indictrans2-indic-en-1B) |
39
+
40
+
41
+ ### Inference
42
+
43
+ With the conversion complete, you can now perform inference using HuggingFace Transformers.
44
+
45
+ You can start with the provided `example.py` script and customize it for your specific translation use case:
46
+
47
+ ```bash
48
+ python3 example.py
49
+ ```
50
+
51
+ Feel free to modify the `example.py` script to suit your translation needs.
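For reference, the flow in `example.py` condenses to the sketch below (single batch, no quantization). It assumes the dependencies from `install.sh` and the tokenizer assets shipped in this repository are available:

```python
# Condensed sketch of example.py: preprocess -> tokenize -> generate -> decode -> postprocess.
import torch
from transformers import AutoModelForSeq2SeqLM
from IndicTransTokenizer.tokenizer import IndicTransTokenizer
from IndicTransTokenizer.utils import preprocess_batch, postprocess_batch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = IndicTransTokenizer(direction="en-indic")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
).to(DEVICE).eval()

sents = ["When I was young, I used to go to the park every day."]

# Normalize, wrap entities in placeholders, and add "src_lang tgt_lang" tags.
batch, entity_map = preprocess_batch(sents, src_lang="eng_Latn", tgt_lang="hin_Deva")

inputs = tokenizer(batch, src=True, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    generated = model.generate(**inputs, use_cache=True, num_beams=5, max_length=256)

# Map ids back to text, then restore entities and the native target script.
decoded = tokenizer.batch_decode(generated.detach().cpu().tolist(), src=False)
translations = postprocess_batch(decoded, placeholder_entity_map=entity_map, lang="hin_Deva")
print(translations)
```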
52
+
53
+ ### Citation
54
+
55
+ ```
56
+ @article{ai4bharat2023indictrans2,
57
+ title = {IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
58
+ author = {AI4Bharat and Jay Gala and Pranjal A. Chitale and Raghavan AK and Sumanth Doddapaneni and Varun Gumma and Aswanth Kumar and Janki Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M. Khapra and Raj Dabre and Anoop Kunchukuttan},
59
+ year = {2023},
60
+ journal = {arXiv preprint arXiv: 2305.16307}
61
+ }
62
+ ```
configuration_indictrans.py ADDED
@@ -0,0 +1,307 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans config."""
16
+
17
+
18
+ from collections import OrderedDict
19
+ from typing import Any, Mapping, Optional
20
+
21
+ from transformers import PreTrainedTokenizer
22
+ from transformers.configuration_utils import PretrainedConfig
23
+ from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
24
+ from transformers.onnx.utils import compute_effective_axis_dimension
25
+ from transformers.utils import TensorType, is_torch_available
26
+
27
+
28
+ # Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
29
+ class IndicTransConfig(PretrainedConfig):
30
+ r"""
31
+ This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
32
+ IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
+ with the defaults will yield a similar configuration to that of the IT2
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 50265):
41
+ Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
42
+ `inputs_ids` passed when calling [`IT2Model`] or
43
+ d_model (`int`, *optional*, defaults to 1024):
44
+ Dimensionality of the layers and the pooler layer.
45
+ encoder_layers (`int`, *optional*, defaults to 12):
46
+ Number of encoder layers.
47
+ decoder_layers (`int`, *optional*, defaults to 12):
48
+ Number of decoder layers.
49
+ encoder_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ decoder_attention_heads (`int`, *optional*, defaults to 16):
52
+ Number of attention heads for each attention layer in the Transformer decoder.
53
+ decoder_ffn_dim (`int`, *optional*, defaults to 4096):
54
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
55
+ encoder_ffn_dim (`int`, *optional*, defaults to 4096):
56
+ Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
57
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
58
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
59
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
60
+ dropout (`float`, *optional*, defaults to 0.1):
61
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
62
+ attention_dropout (`float`, *optional*, defaults to 0.0):
63
+ The dropout ratio for the attention probabilities.
64
+ activation_dropout (`float`, *optional*, defaults to 0.0):
65
+ The dropout ratio for activations inside the fully connected layer.
66
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
67
+ The dropout ratio for classifier.
68
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
69
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
70
+ just in case (e.g., 512 or 1024 or 2048).
71
+ init_std (`float`, *optional*, defaults to 0.02):
72
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
73
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
74
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
75
+ for more details.
76
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
77
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
78
+ for more details.
79
+ use_cache (`bool`, *optional*, defaults to `True`):
80
+ Whether or not the model should return the last key/values attentions (not used by all models).
81
+ ```"""
82
+ model_type = "IndicTrans"
83
+ keys_to_ignore_at_inference = ["past_key_values"]
84
+ attribute_map = {
85
+ "num_attention_heads": "encoder_attention_heads",
86
+ "hidden_size": "d_model",
87
+ }
88
+
89
+ def __init__(
90
+ self,
91
+ encoder_vocab_size=None,
92
+ decoder_vocab_size=None,
93
+ encoder_embed_dim=512,
94
+ decoder_embed_dim=512,
95
+ max_source_positions=210,
96
+ max_target_positions=210,
97
+ encoder_layers=6,
98
+ encoder_ffn_dim=2048,
99
+ encoder_attention_heads=8,
100
+ decoder_layers=6,
101
+ decoder_ffn_dim=2048,
102
+ decoder_attention_heads=8,
103
+ encoder_layerdrop=0.00,
104
+ decoder_layerdrop=0.00,
105
+ use_cache=True,
106
+ is_encoder_decoder=True,
107
+ activation_function="relu",
108
+ encoder_normalize_before=False,
109
+ decoder_normalize_before=False,
110
+ layernorm_embedding=False,
111
+ share_decoder_input_output_embed=False,
112
+ dropout=0.1,
113
+ attention_dropout=0.0,
114
+ activation_dropout=0.0,
115
+ init_std=0.02,
116
+ scale_embedding=True,
117
+ decoder_start_token_id=2,
118
+ pad_token_id=1,
119
+ bos_token_id=0,
120
+ eos_token_id=2,
121
+ **kwargs,
122
+ ):
123
+ self.encoder_vocab_size = encoder_vocab_size
124
+ self.decoder_vocab_size = decoder_vocab_size
125
+ self.encoder_normalize_before = encoder_normalize_before
126
+ self.decoder_normalize_before = decoder_normalize_before
127
+ self.layernorm_embedding = layernorm_embedding
128
+ self.max_source_positions = max_source_positions
129
+ self.max_target_positions = max_target_positions
130
+ self.encoder_embed_dim = encoder_embed_dim
131
+ self.decoder_embed_dim = decoder_embed_dim
132
+ self.encoder_ffn_dim = encoder_ffn_dim
133
+ self.encoder_layers = encoder_layers
134
+ self.encoder_attention_heads = encoder_attention_heads
135
+ self.decoder_ffn_dim = decoder_ffn_dim
136
+ self.decoder_layers = decoder_layers
137
+ self.decoder_attention_heads = decoder_attention_heads
138
+ self.dropout = dropout
139
+ self.attention_dropout = attention_dropout
140
+ self.activation_dropout = activation_dropout
141
+ self.activation_function = activation_function
142
+ self.init_std = init_std
143
+ self.encoder_layerdrop = encoder_layerdrop
144
+ self.decoder_layerdrop = decoder_layerdrop
145
+ self.use_cache = use_cache
146
+ self.num_hidden_layers = encoder_layers
147
+ self.scale_embedding = scale_embedding
148
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
149
+
150
+ super().__init__(
151
+ pad_token_id=pad_token_id,
152
+ bos_token_id=bos_token_id,
153
+ eos_token_id=eos_token_id,
154
+ is_encoder_decoder=is_encoder_decoder,
155
+ decoder_start_token_id=decoder_start_token_id,
156
+ **kwargs,
157
+ )
158
+
159
+
160
+ class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
161
+ @property
162
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
163
+ common_inputs = OrderedDict(
164
+ [
165
+ ("input_ids", {0: "batch", 1: "encoder_sequence"}),
166
+ ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
167
+ ]
168
+ )
169
+
170
+ if self.use_past:
171
+ common_inputs["decoder_input_ids"] = {0: "batch"}
172
+ common_inputs["decoder_attention_mask"] = {
173
+ 0: "batch",
174
+ 1: "past_decoder_sequence + sequence",
175
+ }
176
+ else:
177
+ common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
178
+ common_inputs["decoder_attention_mask"] = {
179
+ 0: "batch",
180
+ 1: "decoder_sequence",
181
+ }
182
+
183
+ if self.use_past:
184
+ self.fill_with_past_key_values_(common_inputs, direction="inputs")
185
+ return common_inputs
186
+
187
+ # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
188
+ # A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
189
+ # answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
190
+ # was done for BART so that it can be updated if need be.
191
+ def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
192
+ self,
193
+ tokenizer: PreTrainedTokenizer,
194
+ batch_size: int = -1,
195
+ seq_length: int = -1,
196
+ is_pair: bool = False,
197
+ framework: Optional[TensorType] = None,
198
+ ) -> Mapping[str, Any]:
199
+ # Copied from OnnxConfig.generate_dummy_inputs
200
+ # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
201
+ # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
202
+ batch_size = compute_effective_axis_dimension(
203
+ batch_size,
204
+ fixed_dimension=OnnxConfig.default_fixed_batch,
205
+ num_token_to_add=0,
206
+ )
207
+
208
+ # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
209
+ token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
210
+ seq_length = compute_effective_axis_dimension(
211
+ seq_length,
212
+ fixed_dimension=OnnxConfig.default_fixed_sequence,
213
+ num_token_to_add=token_to_add,
214
+ )
215
+
216
+ # Generate dummy inputs according to compute batch and sequence
217
+ dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
218
+ common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
219
+ return common_inputs
220
+
221
+ # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
222
+ def _generate_dummy_inputs_for_default_and_seq2seq_lm(
223
+ self,
224
+ tokenizer: PreTrainedTokenizer,
225
+ batch_size: int = -1,
226
+ seq_length: int = -1,
227
+ is_pair: bool = False,
228
+ framework: Optional[TensorType] = None,
229
+ ) -> Mapping[str, Any]:
230
+ encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
231
+ tokenizer, batch_size, seq_length, is_pair, framework
232
+ )
233
+
234
+ # Generate decoder inputs
235
+ decoder_seq_length = seq_length if not self.use_past else 1
236
+ decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
237
+ tokenizer, batch_size, decoder_seq_length, is_pair, framework
238
+ )
239
+ decoder_inputs = {
240
+ f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
241
+ }
242
+ common_inputs = dict(**encoder_inputs, **decoder_inputs)
243
+
244
+ if self.use_past:
245
+ if not is_torch_available():
246
+ raise ValueError(
247
+ "Cannot generate dummy past_keys inputs without PyTorch installed."
248
+ )
249
+ else:
250
+ import torch
251
+ batch, encoder_seq_length = common_inputs["input_ids"].shape
252
+ decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
253
+ (
254
+ num_encoder_attention_heads,
255
+ num_decoder_attention_heads,
256
+ ) = self.num_attention_heads
257
+ encoder_shape = (
258
+ batch,
259
+ num_encoder_attention_heads,
260
+ encoder_seq_length,
261
+ self._config.hidden_size // num_encoder_attention_heads,
262
+ )
263
+ decoder_past_length = decoder_seq_length + 3
264
+ decoder_shape = (
265
+ batch,
266
+ num_decoder_attention_heads,
267
+ decoder_past_length,
268
+ self._config.hidden_size // num_decoder_attention_heads,
269
+ )
270
+
271
+ common_inputs["decoder_attention_mask"] = torch.cat(
272
+ [
273
+ common_inputs["decoder_attention_mask"],
274
+ torch.ones(batch, decoder_past_length),
275
+ ],
276
+ dim=1,
277
+ )
278
+
279
+ common_inputs["past_key_values"] = []
280
+ # If the number of encoder and decoder layers are present in the model configuration, both are considered
281
+ num_encoder_layers, num_decoder_layers = self.num_layers
282
+ min_num_layers = min(num_encoder_layers, num_decoder_layers)
283
+ max_num_layers = (
284
+ max(num_encoder_layers, num_decoder_layers) - min_num_layers
285
+ )
286
+ remaining_side_name = (
287
+ "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
288
+ )
289
+
290
+ for _ in range(min_num_layers):
291
+ common_inputs["past_key_values"].append(
292
+ (
293
+ torch.zeros(decoder_shape),
294
+ torch.zeros(decoder_shape),
295
+ torch.zeros(encoder_shape),
296
+ torch.zeros(encoder_shape),
297
+ )
298
+ )
299
+ # TODO: test this.
300
+ shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
301
+ for _ in range(min_num_layers, max_num_layers):
302
+ common_inputs["past_key_values"].append(
303
+ (torch.zeros(shape), torch.zeros(shape))
304
+ )
305
+ return common_inputs
306
+
307
+ generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
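As a quick illustration of how this configuration class is consumed (the values below are hypothetical; the converter in the next file derives the real ones from the fairseq checkpoint):

```python
# Hypothetical sizes for illustration only; see convert_indictrans_checkpoint_to_pytorch.py
# for how the real values are read from the fairseq args and embedding shapes.
from configuration_indictrans import IndicTransConfig

config = IndicTransConfig(
    encoder_vocab_size=32768,
    decoder_vocab_size=65536,
    encoder_layers=18,
    decoder_layers=18,
    encoder_embed_dim=1024,
    decoder_embed_dim=1024,
    encoder_attention_heads=16,
    decoder_attention_heads=16,
)

print(config.model_type)           # "IndicTrans"
print(config.num_attention_heads)  # 16, aliased to encoder_attention_heads via attribute_map
```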
convert_indictrans_checkpoint_to_pytorch.py ADDED
@@ -0,0 +1,107 @@
1
+ # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import argparse
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from configuration_indictrans import IndicTransConfig
21
+ from modeling_indictrans import IndicTransForConditionalGeneration
22
+
23
+
24
+ def remove_ignore_keys_(state_dict):
25
+ ignore_keys = [
26
+ "encoder.version",
27
+ "decoder.version",
28
+ "model.encoder.version",
29
+ "model.decoder.version",
30
+ "_float_tensor",
31
+ "encoder.embed_positions._float_tensor",
32
+ "decoder.embed_positions._float_tensor",
33
+ ]
34
+ for k in ignore_keys:
35
+ state_dict.pop(k, None)
36
+
37
+
38
+ def make_linear_from_emb(emb):
39
+ vocab_size, emb_size = emb.shape
40
+ lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
41
+ lin_layer.weight.data = emb.data
42
+ return lin_layer
43
+
44
+
45
+ def convert_fairseq_IT2_checkpoint_from_disk(checkpoint_path):
46
+ model = torch.load(checkpoint_path, map_location="cpu")
47
+ args = model["args"] or model["cfg"]["model"]
48
+ state_dict = model["model"]
49
+ remove_ignore_keys_(state_dict)
50
+ encoder_vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
51
+ decoder_vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]
52
+
53
+ config = IndicTransConfig(
54
+ encoder_vocab_size=encoder_vocab_size,
55
+ decoder_vocab_size=decoder_vocab_size,
56
+ max_source_positions=args.max_source_positions,
57
+ max_target_positions=args.max_target_positions,
58
+ encoder_layers=args.encoder_layers,
59
+ decoder_layers=args.decoder_layers,
60
+ layernorm_embedding=args.layernorm_embedding,
61
+ encoder_normalize_before=args.encoder_normalize_before,
62
+ decoder_normalize_before=args.decoder_normalize_before,
63
+ encoder_attention_heads=args.encoder_attention_heads,
64
+ decoder_attention_heads=args.decoder_attention_heads,
65
+ encoder_ffn_dim=args.encoder_ffn_embed_dim,
66
+ decoder_ffn_dim=args.decoder_ffn_embed_dim,
67
+ encoder_embed_dim=args.encoder_embed_dim,
68
+ decoder_embed_dim=args.decoder_embed_dim,
69
+ encoder_layerdrop=args.encoder_layerdrop,
70
+ decoder_layerdrop=args.decoder_layerdrop,
71
+ dropout=args.dropout,
72
+ attention_dropout=args.attention_dropout,
73
+ activation_dropout=args.activation_dropout,
74
+ activation_function=args.activation_fn,
75
+ share_decoder_input_output_embed=args.share_decoder_input_output_embed,
76
+ scale_embedding=not args.no_scale_embedding,
77
+ )
78
+
79
+ model = IndicTransForConditionalGeneration(config)
80
+ model.model.load_state_dict(state_dict, strict=False)
81
+ if not args.share_decoder_input_output_embed:
82
+ model.lm_head = make_linear_from_emb(
83
+ state_dict["decoder.output_projection.weight"]
84
+ )
85
+ print(model)
86
+ return model
87
+
88
+
89
+ if __name__ == "__main__":
90
+ parser = argparse.ArgumentParser()
91
+ # Required parameters
92
+ parser.add_argument(
93
+ "--fairseq_path",
94
+ default="indic-en/model/checkpoint_best.pt",
95
+ type=str,
96
+ help="path to a model.pt on local filesystem.",
97
+ )
98
+ parser.add_argument(
99
+ "--pytorch_dump_folder_path",
100
+ default="indic-en/hf_model",
101
+ type=str,
102
+ help="Path to the output PyTorch model.",
103
+ )
104
+
105
+ args = parser.parse_args()
106
+ model = convert_fairseq_IT2_checkpoint_from_disk(args.fairseq_path)
107
+ model.save_pretrained(args.pytorch_dump_folder_path)
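The `__main__` block above can also be driven programmatically; a short sketch (paths are placeholders, and `modeling_indictrans.py` must be importable alongside this script):

```python
# Sketch: convert a fairseq checkpoint and save it as a HF-compatible folder.
from convert_indictrans_checkpoint_to_pytorch import convert_fairseq_IT2_checkpoint_from_disk

model = convert_fairseq_IT2_checkpoint_from_disk("indic-en/model/checkpoint_best.pt")
model.save_pretrained("indic-en/hf_model")
```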
example.py ADDED
@@ -0,0 +1,125 @@
1
+ import sys
2
+ import torch
3
+ from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
4
+ from IndicTransTokenizer.utils import preprocess_batch, postprocess_batch
5
+ from IndicTransTokenizer.tokenizer import IndicTransTokenizer
6
+
7
+ en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"
8
+
9
+ BATCH_SIZE = 16
10
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+ if len(sys.argv)>1:
13
+ quantization = sys.argv[1]
14
+ else:
15
+ quantization = ""
16
+
17
+
18
+ def initialize_model_and_tokenizer(ckpt_dir, direction, quantization):
19
+ if quantization == "4-bit":
20
+ qconfig = BitsAndBytesConfig(
21
+ load_in_4bit=True,
22
+ bnb_4bit_use_double_quant=True,
23
+ bnb_4bit_compute_dtype=torch.bfloat16,
24
+ )
25
+ elif quantization == "8-bit":
26
+ qconfig = BitsAndBytesConfig(
27
+ load_in_8bit=True,
28
+ bnb_8bit_use_double_quant=True,
29
+ bnb_8bit_compute_dtype=torch.bfloat16,
30
+ )
31
+ else:
32
+ qconfig = None
33
+
34
+ tokenizer = IndicTransTokenizer(direction=direction)
35
+ model = AutoModelForSeq2SeqLM.from_pretrained(
36
+ ckpt_dir,
37
+ trust_remote_code=True,
38
+ low_cpu_mem_usage=True,
39
+ quantization_config=qconfig
40
+ )
41
+
42
+ if qconfig is None:
43
+ model = model.to(DEVICE)
44
+ model.half()
45
+
46
+ model.eval()
47
+
48
+ return tokenizer, model
49
+
50
+
51
+ def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer):
52
+ translations = []
53
+ for i in range(0, len(input_sentences), BATCH_SIZE):
54
+ batch = input_sentences[i : i + BATCH_SIZE]
55
+
56
+ # Preprocess the batch and extract entity mappings
57
+ batch, entity_map = preprocess_batch(
58
+ batch, src_lang=src_lang, tgt_lang=tgt_lang
59
+ )
60
+
61
+ # Tokenize the batch and generate input encodings
62
+ inputs = tokenizer(
63
+ batch,
64
+ src=True,
65
+ truncation=True,
66
+ padding="longest",
67
+ return_tensors="pt",
68
+ return_attention_mask=True,
69
+ ).to(DEVICE)
70
+
71
+ # Generate translations using the model
72
+ with torch.no_grad():
73
+ generated_tokens = model.generate(
74
+ **inputs,
75
+ use_cache=True,
76
+ min_length=0,
77
+ max_length=256,
78
+ num_beams=5,
79
+ num_return_sequences=1,
80
+ )
81
+
82
+ # Decode the generated tokens into text
83
+ generated_tokens = tokenizer.batch_decode(
84
+ generated_tokens.detach().cpu().tolist(), src=False
85
+ )
86
+
87
+ # Postprocess the translations, including entity replacement
88
+ translations += postprocess_batch(
89
+ generated_tokens, lang=tgt_lang, placeholder_entity_map=entity_map
90
+ )
91
+
92
+ del inputs
93
+ torch.cuda.empty_cache()
94
+
95
+ return translations
96
+
97
+
98
+ en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(
99
+ en_indic_ckpt_dir, "en-indic", quantization
100
+ )
101
+
102
+ # ---------------------------------------------------------------------------
103
+ # English to Hindi
104
+ # ---------------------------------------------------------------------------
105
+ en_sents = [
106
+ "When I was young, I used to go to the park every day.",
107
+ "He has many old books, which he inherited from his ancestors.",
108
+ "I can't figure out how to solve my problem.",
109
+ "She is very hardworking and intelligent, which is why she got all the good marks.",
110
+ "We watched a new movie last week, which was very inspiring.",
111
+ "If you had met me at that time, we would have gone out to eat.",
112
+ "She went to the market with her sister to buy a new sari.",
113
+ "Raj told me that he is going to his grandmother's house next month.",
114
+ "All the kids were having fun at the party and were eating lots of sweets.",
115
+ "My friend has invited me to his birthday party, and I will give him a gift.",
116
+ ]
117
+ src_lang, tgt_lang = "eng_Latn", "hin_Deva"
118
+ hi_translations = batch_translate(
119
+ en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer
120
+ )
121
+
122
+ print(f"\n{src_lang} - {tgt_lang}")
123
+ for input_sentence, translation in zip(en_sents, hi_translations):
124
+ print(f"{src_lang}: {input_sentence}")
125
+ print(f"{tgt_lang}: {translation}")
handler.py ADDED
@@ -0,0 +1,194 @@
1
+ from typing import Dict, List, Any
2
+ import sys, os, re
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
7
+ from IndicTransTokenizer.utils import preprocess_batch, postprocess_batch
8
+ from IndicTransTokenizer.tokenizer import IndicTransTokenizer
9
+
10
+
11
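+ # Inference endpoint wrapper: loads the translation model once at startup and translates SRT transcripts per request.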
+ class EndpointHandler():
12
+ def __init__(self, direction = "en-indic", quantization = ""):
13
+ self.model_name = "ai4bharat/indictrans2-en-indic-1B"
14
+
15
+ self.utterance_pattern = re.compile(r"^\d+$")
16
+ self.timestamp_pattern = re.compile(r"(\d+:\d+:\d+,\d+)\s*-->\s*(\d+:\d+:\d+,\d+)")
17
+
18
+ self.BATCH_SIZE = 16
19
+ self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ self.model = None
22
+ self.tokenizer = None
23
+
24
+ if quantization == "4-bit":
25
+ qconfig = BitsAndBytesConfig(
26
+ load_in_4bit=True,
27
+ bnb_4bit_use_double_quant=True,
28
+ bnb_4bit_compute_dtype=torch.bfloat16,
29
+ )
30
+ elif quantization == "8-bit":
31
+ qconfig = BitsAndBytesConfig(
32
+ load_in_8bit=True,
33
+ bnb_8bit_use_double_quant=True,
34
+ bnb_8bit_compute_dtype=torch.bfloat16,
35
+ )
36
+ else:
37
+ qconfig = None
38
+
39
+ self.tokenizer = IndicTransTokenizer(direction=direction)
40
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(
41
+ self.model_name,
42
+ trust_remote_code=True,
43
+ low_cpu_mem_usage=True,
44
+ quantization_config=qconfig
45
+ )
46
+
47
+ if qconfig is None:
48
+ self.model = self.model.to(self.DEVICE)
49
+ self.model.half()
50
+
51
+ self.model.eval()
52
+
53
+
54
+ def batch_translate(self, input_sentences, src_lang, tgt_lang):
55
+ translations = []
56
+ for i in range(0, len(input_sentences), self.BATCH_SIZE):
57
+ batch = input_sentences[i : i + self.BATCH_SIZE]
58
+
59
+ # Preprocess the batch and extract entity mappings
60
+ batch, entity_map = preprocess_batch(
61
+ batch, src_lang=src_lang, tgt_lang=tgt_lang
62
+ )
63
+
64
+ # Tokenize the batch and generate input encodings
65
+ inputs = self.tokenizer(
66
+ batch,
67
+ src=True,
68
+ truncation=True,
69
+ padding="longest",
70
+ return_tensors="pt",
71
+ return_attention_mask=True,
72
+ ).to(self.DEVICE)
73
+
74
+ # Generate translations using the model
75
+ with torch.no_grad():
76
+ generated_tokens = self.model.generate(
77
+ **inputs,
78
+ use_cache=True,
79
+ min_length=0,
80
+ max_length=256,
81
+ num_beams=5,
82
+ num_return_sequences=1,
83
+ )
84
+
85
+ # Decode the generated tokens into text
86
+ generated_tokens = self.tokenizer.batch_decode(
87
+ generated_tokens.detach().cpu().tolist(), src=False
88
+ )
89
+
90
+ # Postprocess the translations, including entity replacement
91
+ translations += postprocess_batch(
92
+ generated_tokens, lang=tgt_lang, placeholder_entity_map=entity_map
93
+ )
94
+
95
+ del inputs
96
+ if torch.cuda.is_available():
97
+ torch.cuda.empty_cache()
98
+
99
+ return translations
100
+
101
+
102
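+ # Parse an SRT file into entries of the form {'utterance_ind', 'start_end', 'text'}, one per subtitle block.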
+ def read_srt(self, srt_path):
103
+ data = []
+ with open(srt_path, 'r', encoding='utf-8') as fp:
+ utterance_ind = ""
+ start_end = ""
+ text = ""
+ for line in fp.readlines():
+ line = line.strip()
+ if re.search(self.utterance_pattern, line) is not None:
+ utterance_ind = line
+ elif re.search(self.timestamp_pattern, line) is not None:
+ start_end = line
+ elif line != "":
+ # subtitle text may span several lines; join them with a space
+ text = f"{text} {line}".strip()
+
+ # a blank line marks the end of an SRT block, so flush the completed entry
+ if line == "" and utterance_ind != "" and start_end != "" and text != "":
+ data.append({'utterance_ind': utterance_ind, 'start_end': start_end, 'text': text})
+ utterance_ind = ""
+ start_end = ""
+ text = ""
+
+ # flush the final block if the file does not end with a blank line
+ if utterance_ind != "" and start_end != "" and text != "":
+ data.append({'utterance_ind': utterance_ind, 'start_end': start_end, 'text': text})
+
+ return data
124
+
125
+ def test(self, inputs) -> List[Dict[str, Any]]:
126
+ """
127
+ data args:
128
+ inputs (:obj: (transcript_path : 'str', src_lang : 'str', tgt_lang : 'str')
129
+ kwargs
130
+ Return:
131
+ A :obj:`list` | `dict`: will be serialized and returned
132
+ """
133
+
134
+ src_lang = inputs["src_lang"]
135
+ tgt_lang = inputs["tgt_lang"]
136
+ transcript_path = inputs["transcript_path"]
137
+
138
+ output_translations = []
139
+ if self.model is not None:
140
+ transcriptions = self.read_srt(transcript_path)
141
+ trans_sents = [entry['text'] for entry in transcriptions]
142
+ indic_translations = self.batch_translate(trans_sents, src_lang, tgt_lang)
143
+
144
+ for i in tqdm(range(len(transcriptions))):
145
+ entry = transcriptions[i]
146
+ entry['text'] = indic_translations[i]
147
+ output_translations.append(entry)
148
+
149
+ return output_translations
150
+ else:
151
+ return []
152
+
153
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
154
+ """
155
+ data args:
156
+ inputs (:obj: (transcript_path : 'str', src_lang : 'str', tgt_lang : 'str')
157
+ kwargs
158
+ Return:
159
+ A :obj:`list` | `dict`: will be serialized and returned
160
+ """
161
+
162
+ inputs = data.pop("inputs",data)
163
+
164
+ src_lang = inputs["src_lang"]
165
+ tgt_lang = inputs["tgt_lang"]
166
+ transcript_path = inputs["transcript_path"]
167
+
168
+ output_translations = []
169
+ if self.model is not None:
170
+ transcriptions = self.read_srt(transcript_path)
171
+ indic_translations = self.batch_translate(transcriptions, src_lang, tgt_lang)
172
+
173
+ for i in tqdm(range(len(transcriptions))):
174
+ entry = transcriptions[i]
175
+ entry['text'] = indic_translations[i]
176
+ output_translations.append(entry)
177
+
178
+ return output_translations
179
+ else:
180
+ return []
181
+
182
+
183
+ if __name__ == "__main__":
184
+ endpoint = EndpointHandler(quantization = "8-bit")
185
+ inputs = {}
186
+ inputs['src_lang'] = 'eng_Latn'
187
+ inputs['tgt_lang'] = 'tel_Telu'
188
+ inputs['transcript_path'] = './sample.srt'
189
+
190
+ outputs = endpoint.test(inputs)
191
+
192
+ print("Outputs: ")
193
+ for entry in outputs:
194
+ print(entry)
install.sh ADDED
@@ -0,0 +1,52 @@
1
+ #!/bin/bash
2
+
3
+ root_dir=$(pwd)
4
+ echo "Setting up the environment in the $root_dir"
5
+
6
+ # --------------------------------------------------------------
7
+ # create and activate the virtual environment
8
+ # --------------------------------------------------------------
9
+ echo "Creating a virtual environment with python3"
10
+ conda create -n itv2_hf python=3.9 -y
11
+ conda activate itv2_hf
12
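+ # NOTE: in a non-interactive script, `conda activate` may require sourcing conda's shell hook first (e.g. "$(conda info --base)/etc/profile.d/conda.sh").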
+
13
+ echo "Installing all the dependencies"
14
+ conda install pip
15
+ python3 -m pip install --upgrade pip
16
+
17
+
18
+ # --------------------------------------------------------------
19
+ # PyTorch Installation
20
+ # --------------------------------------------------------------
21
+ python3 -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
22
+
23
+
24
+ # --------------------------------------------------------------
25
+ # Install IndicNLP library and necessary resources
26
+ # --------------------------------------------------------------
27
+ git clone https://github.com/anoopkunchukuttan/indic_nlp_resources.git
28
+ export INDIC_RESOURCES_PATH=$root_dir/indic_nlp_resources
29
+
30
+ # we use version 0.92 which is the latest in the github repo
31
+ git clone https://github.com/anoopkunchukuttan/indic_nlp_library.git
32
+ cd indic_nlp_library
33
+ python3 -m pip install ./
34
+ cd $root_dir
35
+
36
+
37
+ # --------------------------------------------------------------
38
+ # Install additional utility packages
39
+ # --------------------------------------------------------------
40
+ python3 -m pip install sacremoses pandas regex mock transformers==4.33.2 urduhack[tf] mosestokenizer
41
+ python3 -c "import urduhack; urduhack.download()"
42
+ python3 -m pip install bitsandbytes scipy accelerate datasets
43
+
44
+
45
+ # --------------------------------------------------------------
46
+ # Sentencepiece for tokenization
47
+ # --------------------------------------------------------------
48
+ # build the cpp binaries from the source repo in order to use the command line utility
49
+ # source repo: https://github.com/google/sentencepiece
50
+ python3 -m pip install sentencepiece
51
+
52
+ echo "Setup completed!"
modeling_indictrans.py ADDED
@@ -0,0 +1,1449 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch IndicTrans model."""
16
+
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from torch.nn import functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutput,
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ Seq2SeqLMOutput,
31
+ Seq2SeqModelOutput,
32
+ )
33
+
34
+ from transformers.utils import logging
35
+ from transformers.modeling_utils import PreTrainedModel
36
+
37
+ from configuration_indictrans import IndicTransConfig
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CONFIG_FOR_DOC = "IndicTransConfig"
43
+
44
+ INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
45
+
46
+
47
+ # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
48
+ def shift_tokens_right(
49
+ input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
50
+ ):
51
+ """
52
+ Shift input ids one token to the right.
53
+ """
54
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
55
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
56
+ shifted_input_ids[:, 0] = decoder_start_token_id
57
+
58
+ if pad_token_id is None:
59
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
60
+ # replace possible -100 values in labels by `pad_token_id`
61
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
62
+
63
+ return shifted_input_ids
64
+
65
+
66
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
67
+ def _make_causal_mask(
68
+ input_ids_shape: torch.Size,
69
+ dtype: torch.dtype,
70
+ device: torch.device,
71
+ past_key_values_length: int = 0,
72
+ ):
73
+ """
74
+ Make causal mask used for bi-directional self-attention.
75
+ """
76
+ bsz, tgt_len = input_ids_shape
77
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
78
+ mask_cond = torch.arange(mask.size(-1), device=device)
79
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
80
+ mask = mask.to(dtype)
81
+
82
+ if past_key_values_length > 0:
83
+ mask = torch.cat(
84
+ [
85
+ torch.zeros(
86
+ tgt_len, past_key_values_length, dtype=dtype, device=device
87
+ ),
88
+ mask,
89
+ ],
90
+ dim=-1,
91
+ )
92
+ return mask[None, None, :, :].expand(
93
+ bsz, 1, tgt_len, tgt_len + past_key_values_length
94
+ )
95
+
96
+
97
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
98
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
99
+ """
100
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
101
+ """
102
+ bsz, src_len = mask.size()
103
+ tgt_len = tgt_len if tgt_len is not None else src_len
104
+
105
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
106
+
107
+ inverted_mask = 1.0 - expanded_mask
108
+
109
+ return inverted_mask.masked_fill(
110
+ inverted_mask.to(torch.bool), torch.finfo(dtype).min
111
+ )
112
+
113
+
114
+ def create_position_ids_from_input_ids(
115
+ input_ids, padding_idx, past_key_values_length=0
116
+ ):
117
+ """
118
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
119
+ are ignored. This is modified from fairseq's `utils.make_positions`.
120
+ """
121
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
122
+ mask = input_ids.ne(padding_idx).int()
123
+ incremental_indices = (
124
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
125
+ ) * mask
126
+ return incremental_indices.long() + padding_idx
127
+
128
+
129
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
130
+ class IndicTransSinusoidalPositionalEmbedding(nn.Module):
131
+ """This module produces sinusoidal positional embeddings of any length."""
132
+
133
+ def __init__(
134
+ self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
135
+ ):
136
+ super().__init__()
137
+ self.offset = 2
138
+ self.embedding_dim = embedding_dim
139
+ self.padding_idx = padding_idx
140
+ self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
141
+
142
+ def make_weights(
143
+ self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
144
+ ):
145
+ emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
146
+ if hasattr(self, "weights"):
147
+ # in forward put the weights on the correct dtype and device of the param
148
+ emb_weights = emb_weights.to(
149
+ dtype=self.weights.dtype, device=self.weights.device
150
+ )
151
+
152
+ self.register_buffer("weights", emb_weights, persistent=False)
153
+
154
+ @staticmethod
155
+ def get_embedding(
156
+ num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
157
+ ):
158
+ """
159
+ Build sinusoidal embeddings.
160
+
161
+ This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
162
+ "Attention Is All You Need".
163
+ """
164
+ half_dim = embedding_dim // 2
165
+ emb = math.log(10000) / (half_dim - 1)
166
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
167
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
168
+ 1
169
+ ) * emb.unsqueeze(0)
170
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
171
+ num_embeddings, -1
172
+ )
173
+ if embedding_dim % 2 == 1:
174
+ # zero pad
175
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
176
+ if padding_idx is not None:
177
+ emb[padding_idx, :] = 0
178
+
179
+ return emb.to(torch.get_default_dtype())
180
+
181
+ @torch.no_grad()
182
+ def forward(
183
+ self,
184
+ input_ids: torch.Tensor = None,
185
+ inputs_embeds: torch.Tensor = None,
186
+ past_key_values_length: int = 0,
187
+ ):
188
+ if input_ids is not None:
189
+ bsz, seq_len = input_ids.size()
190
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
191
+ position_ids = create_position_ids_from_input_ids(
192
+ input_ids, self.padding_idx, past_key_values_length
193
+ ).to(input_ids.device)
194
+ else:
195
+ bsz, seq_len = inputs_embeds.size()[:-1]
196
+ position_ids = self.create_position_ids_from_inputs_embeds(
197
+ inputs_embeds, past_key_values_length
198
+ )
199
+
200
+ # expand embeddings if needed
201
+ max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
202
+ if max_pos > self.weights.size(0):
203
+ self.make_weights(
204
+ max_pos + self.offset, self.embedding_dim, self.padding_idx
205
+ )
206
+
207
+ return (
208
+ self.weights.index_select(0, position_ids.view(-1))
209
+ .view(bsz, seq_len, self.weights.shape[-1])
210
+ .detach()
211
+ )
212
+
213
+ def create_position_ids_from_inputs_embeds(
214
+ self, inputs_embeds, past_key_values_length
215
+ ):
216
+ """
217
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
218
+
219
+ Args:
220
+ inputs_embeds: torch.Tensor
221
+
222
+ Returns: torch.Tensor
223
+ """
224
+ input_shape = inputs_embeds.size()[:-1]
225
+ sequence_length = input_shape[1]
226
+
227
+ position_ids = torch.arange(
228
+ self.padding_idx + 1,
229
+ sequence_length + self.padding_idx + 1,
230
+ dtype=torch.long,
231
+ device=inputs_embeds.device,
232
+ )
233
+ return (
234
+ position_ids.unsqueeze(0).expand(input_shape).contiguous()
235
+ + past_key_values_length
236
+ )
237
+
238
+
239
+ # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
240
+ class IndicTransAttention(nn.Module):
241
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
242
+
243
+ def __init__(
244
+ self,
245
+ embed_dim: int,
246
+ num_heads: int,
247
+ dropout: float = 0.0,
248
+ is_decoder: bool = False,
249
+ bias: bool = True,
250
+ ):
251
+ super().__init__()
252
+ self.embed_dim = embed_dim
253
+ self.num_heads = num_heads
254
+ self.dropout = dropout
255
+ self.head_dim = embed_dim // num_heads
256
+
257
+ if (self.head_dim * num_heads) != self.embed_dim:
258
+ raise ValueError(
259
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
260
+ f" and `num_heads`: {num_heads})."
261
+ )
262
+ self.scaling = self.head_dim**-0.5
263
+ self.is_decoder = is_decoder
264
+
265
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
266
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
267
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
268
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
269
+
270
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
271
+ return (
272
+ tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
273
+ .transpose(1, 2)
274
+ .contiguous()
275
+ )
276
+
277
+ def forward(
278
+ self,
279
+ hidden_states: torch.Tensor,
280
+ key_value_states: Optional[torch.Tensor] = None,
281
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
282
+ attention_mask: Optional[torch.Tensor] = None,
283
+ layer_head_mask: Optional[torch.Tensor] = None,
284
+ output_attentions: bool = False,
285
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
286
+ """Input shape: Batch x Time x Channel"""
287
+
288
+ # if key_value_states are provided this layer is used as a cross-attention layer
289
+ # for the decoder
290
+ is_cross_attention = key_value_states is not None
291
+
292
+ bsz, tgt_len, _ = hidden_states.size()
293
+
294
+ # get query proj
295
+ query_states = self.q_proj(hidden_states) * self.scaling
296
+ # get key, value proj
297
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
298
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
299
+ # the provided `key_value_states` to support prefix tuning
300
+ if (
301
+ is_cross_attention
302
+ and past_key_value is not None
303
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
304
+ ):
305
+ # reuse k,v, cross_attentions
306
+ key_states = past_key_value[0]
307
+ value_states = past_key_value[1]
308
+ elif is_cross_attention:
309
+ # cross_attentions
310
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
311
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
312
+ elif past_key_value is not None:
313
+ # reuse k, v, self_attention
314
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
315
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
316
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
317
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
318
+ else:
319
+ # self_attention
320
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
321
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
322
+
323
+ if self.is_decoder:
324
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
325
+ # Further calls to cross_attention layer can then reuse all cross-attention
326
+ # key/value_states (first "if" case)
327
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
328
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
329
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
330
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
331
+ past_key_value = (key_states, value_states)
332
+
333
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
334
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
335
+ key_states = key_states.reshape(*proj_shape)
336
+ value_states = value_states.reshape(*proj_shape)
337
+
338
+ src_len = key_states.size(1)
339
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
340
+
341
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
342
+ raise ValueError(
343
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
344
+ f" {attn_weights.size()}"
345
+ )
346
+
347
+ if attention_mask is not None:
348
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
349
+ raise ValueError(
350
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
351
+ )
352
+ attn_weights = (
353
+ attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
354
+ + attention_mask
355
+ )
356
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
357
+
358
+ attn_weights = F.softmax(attn_weights, dim=-1)
359
+
360
+ if layer_head_mask is not None:
361
+ if layer_head_mask.size() != (self.num_heads,):
362
+ raise ValueError(
363
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
364
+ f" {layer_head_mask.size()}"
365
+ )
366
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
367
+ bsz, self.num_heads, tgt_len, src_len
368
+ )
369
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
370
+
371
+ if output_attentions:
372
+ # this operation is a bit awkward, but it's required to
373
+ # make sure that attn_weights keeps its gradient.
374
+ # In order to do so, attn_weights have to be reshaped
375
+ # twice and have to be reused in the following
376
+ attn_weights_reshaped = attn_weights.view(
377
+ bsz, self.num_heads, tgt_len, src_len
378
+ )
379
+ attn_weights = attn_weights_reshaped.view(
380
+ bsz * self.num_heads, tgt_len, src_len
381
+ )
382
+ else:
383
+ attn_weights_reshaped = None
384
+
385
+ attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
386
+
387
+ attn_output = torch.bmm(attn_probs, value_states)
388
+
389
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
390
+ raise ValueError(
391
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
392
+ f" {attn_output.size()}"
393
+ )
394
+
395
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
396
+ attn_output = attn_output.transpose(1, 2)
397
+
398
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
399
+ # partitioned across GPUs when using tensor-parallelism.
400
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
401
+
402
+ attn_output = self.out_proj(attn_output)
403
+
404
+ return attn_output, attn_weights_reshaped, past_key_value
405
+
406
+
407
+ # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
408
+ class IndicTransEncoderLayer(nn.Module):
409
+ def __init__(self, config: IndicTransConfig):
410
+ super().__init__()
411
+ self.embed_dim = config.encoder_embed_dim
412
+ self.self_attn = IndicTransAttention(
413
+ embed_dim=self.embed_dim,
414
+ num_heads=config.encoder_attention_heads,
415
+ dropout=config.attention_dropout,
416
+ )
417
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
418
+ self.dropout = config.dropout
419
+ self.activation_fn = ACT2FN[config.activation_function]
420
+ self.activation_dropout = config.activation_dropout
421
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
422
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
423
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
424
+ self.normalize_before = config.encoder_normalize_before
425
+
426
+ def forward(
427
+ self,
428
+ hidden_states: torch.Tensor,
429
+ attention_mask: torch.Tensor,
430
+ layer_head_mask: torch.Tensor,
431
+ output_attentions: bool = False,
432
+ ) -> torch.Tensor:
433
+ """
434
+ Args:
435
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
436
+ attention_mask (`torch.FloatTensor`): attention mask of size
437
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
438
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
439
+ `(encoder_attention_heads,)`.
440
+ output_attentions (`bool`, *optional*):
441
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
442
+ returned tensors for more detail.
443
+ """
444
+ residual = hidden_states
445
+ if self.normalize_before:
446
+ hidden_states = self.self_attn_layer_norm(hidden_states)
447
+ hidden_states, attn_weights, _ = self.self_attn(
448
+ hidden_states=hidden_states,
449
+ attention_mask=attention_mask,
450
+ layer_head_mask=layer_head_mask,
451
+ output_attentions=output_attentions,
452
+ )
453
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
454
+ hidden_states = residual + hidden_states
455
+ if not self.normalize_before:
456
+ hidden_states = self.self_attn_layer_norm(hidden_states)
457
+
458
+ residual = hidden_states
459
+ if self.normalize_before:
460
+ hidden_states = self.final_layer_norm(hidden_states)
461
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
462
+ hidden_states = F.dropout(
463
+ hidden_states, p=self.activation_dropout, training=self.training
464
+ )
465
+ hidden_states = self.fc2(hidden_states)
466
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
467
+ hidden_states = residual + hidden_states
468
+ if not self.normalize_before:
469
+ hidden_states = self.final_layer_norm(hidden_states)
470
+
471
+ if hidden_states.dtype == torch.float16 and (
472
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
473
+ ):
474
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
475
+ hidden_states = torch.clamp(
476
+ hidden_states, min=-clamp_value, max=clamp_value
477
+ )
478
+
479
+ outputs = (hidden_states,)
480
+
481
+ if output_attentions:
482
+ outputs += (attn_weights,)
483
+
484
+ return outputs
485
+
486
+
487
+ # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
488
+ class IndicTransDecoderLayer(nn.Module):
489
+ def __init__(self, config: IndicTransConfig):
490
+ super().__init__()
491
+ self.embed_dim = config.decoder_embed_dim
492
+
493
+ self.self_attn = IndicTransAttention(
494
+ embed_dim=self.embed_dim,
495
+ num_heads=config.decoder_attention_heads,
496
+ dropout=config.attention_dropout,
497
+ is_decoder=True,
498
+ )
499
+ self.dropout = config.dropout
500
+ self.activation_fn = ACT2FN[config.activation_function]
501
+ self.activation_dropout = config.activation_dropout
502
+
503
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
504
+ self.encoder_attn = IndicTransAttention(
505
+ self.embed_dim,
506
+ config.decoder_attention_heads,
507
+ dropout=config.attention_dropout,
508
+ is_decoder=True,
509
+ )
510
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
511
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
512
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
513
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
514
+ self.normalize_before = config.decoder_normalize_before
515
+
516
+ def forward(
517
+ self,
518
+ hidden_states: torch.Tensor,
519
+ attention_mask: Optional[torch.Tensor] = None,
520
+ encoder_hidden_states: Optional[torch.Tensor] = None,
521
+ encoder_attention_mask: Optional[torch.Tensor] = None,
522
+ layer_head_mask: Optional[torch.Tensor] = None,
523
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
524
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
525
+ output_attentions: Optional[bool] = False,
526
+ use_cache: Optional[bool] = True,
527
+ ) -> torch.Tensor:
528
+ """
529
+ Args:
530
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
531
+ attention_mask (`torch.FloatTensor`): attention mask of size
532
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
533
+ encoder_hidden_states (`torch.FloatTensor`):
534
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
535
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
536
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
537
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
538
+ `(encoder_attention_heads,)`.
539
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
540
+ size `(decoder_attention_heads,)`.
541
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
542
+ output_attentions (`bool`, *optional*):
543
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
544
+ returned tensors for more detail.
545
+ """
546
+ residual = hidden_states
547
+ if self.normalize_before:
548
+ hidden_states = self.self_attn_layer_norm(hidden_states)
549
+
550
+ # Self Attention
551
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
552
+ self_attn_past_key_value = (
553
+ past_key_value[:2] if past_key_value is not None else None
554
+ )
555
+ # add present self-attn cache to positions 1,2 of present_key_value tuple
556
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
557
+ hidden_states=hidden_states,
558
+ past_key_value=self_attn_past_key_value,
559
+ attention_mask=attention_mask,
560
+ layer_head_mask=layer_head_mask,
561
+ output_attentions=output_attentions,
562
+ )
563
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
564
+ hidden_states = residual + hidden_states
565
+ if not self.normalize_before:
566
+ hidden_states = self.self_attn_layer_norm(hidden_states)
567
+
568
+ # Cross-Attention Block
569
+ cross_attn_present_key_value = None
570
+ cross_attn_weights = None
571
+ if encoder_hidden_states is not None:
572
+ residual = hidden_states
573
+ if self.normalize_before:
574
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
575
+
576
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
577
+ cross_attn_past_key_value = (
578
+ past_key_value[-2:] if past_key_value is not None else None
579
+ )
580
+ (
581
+ hidden_states,
582
+ cross_attn_weights,
583
+ cross_attn_present_key_value,
584
+ ) = self.encoder_attn(
585
+ hidden_states=hidden_states,
586
+ key_value_states=encoder_hidden_states,
587
+ attention_mask=encoder_attention_mask,
588
+ layer_head_mask=cross_attn_layer_head_mask,
589
+ past_key_value=cross_attn_past_key_value,
590
+ output_attentions=output_attentions,
591
+ )
592
+ hidden_states = F.dropout(
593
+ hidden_states, p=self.dropout, training=self.training
594
+ )
595
+ hidden_states = residual + hidden_states
596
+ if not self.normalize_before:
597
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
598
+
599
+ # add cross-attn to positions 3,4 of present_key_value tuple
600
+ present_key_value = present_key_value + cross_attn_present_key_value
601
+
602
+ # Fully Connected
603
+ residual = hidden_states
604
+ if self.normalize_before:
605
+ hidden_states = self.final_layer_norm(hidden_states)
606
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
607
+ hidden_states = F.dropout(
608
+ hidden_states, p=self.activation_dropout, training=self.training
609
+ )
610
+ hidden_states = self.fc2(hidden_states)
611
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
612
+ hidden_states = residual + hidden_states
613
+ if not self.normalize_before:
614
+ hidden_states = self.final_layer_norm(hidden_states)
615
+
616
+ outputs = (hidden_states,)
617
+
618
+ if output_attentions:
619
+ outputs += (self_attn_weights, cross_attn_weights)
620
+
621
+ if use_cache:
622
+ outputs += (present_key_value,)
623
+
624
+ return outputs
625
+
626
+
627
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->IndicTrans
628
+ class IndicTransPreTrainedModel(PreTrainedModel):
629
+ config_class = IndicTransConfig
630
+ base_model_prefix = "model"
631
+ supports_gradient_checkpointing = True
632
+ _no_split_modules = ["IndicTransAttention"]
633
+
634
+ def _init_weights(self, module):
635
+ std = self.config.init_std
636
+ if isinstance(module, nn.Linear):
637
+ module.weight.data.normal_(mean=0.0, std=std)
638
+ if module.bias is not None:
639
+ module.bias.data.zero_()
640
+ elif isinstance(module, nn.Embedding):
641
+ module.weight.data.normal_(mean=0.0, std=std)
642
+ if module.padding_idx is not None:
643
+ module.weight.data[module.padding_idx].zero_()
644
+
645
+ def _set_gradient_checkpointing(self, module, value=False):
646
+ if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
647
+ module.gradient_checkpointing = value
648
+
649
+
650
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->IndicTrans
651
+ class IndicTransEncoder(IndicTransPreTrainedModel):
652
+ """
653
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
654
+ [`IndicTransEncoderLayer`].
655
+
656
+ Args:
657
+ config: IndicTransConfig
658
+ embed_tokens (nn.Embedding): output embedding
659
+ """
660
+
661
+ def __init__(
662
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
663
+ ):
664
+ super().__init__(config)
665
+
666
+ self.dropout = config.dropout
667
+ self.layerdrop = config.encoder_layerdrop
668
+
669
+ embed_dim = config.encoder_embed_dim
670
+ self.padding_idx = config.pad_token_id
671
+ self.max_source_positions = config.max_source_positions
672
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
673
+
674
+ self.embed_tokens = nn.Embedding(
675
+ config.encoder_vocab_size, embed_dim, self.padding_idx
676
+ )
677
+
678
+ if embed_tokens is not None:
679
+ self.embed_tokens.weight = embed_tokens.weight
680
+
681
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
682
+ config.max_source_positions,
683
+ embed_dim,
684
+ self.padding_idx,
685
+ )
686
+ self.layers = nn.ModuleList(
687
+ [IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
688
+ )
689
+ self.layer_norm = (
690
+ nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
691
+ )
692
+ self.layernorm_embedding = (
693
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
694
+ )
695
+
696
+ self.gradient_checkpointing = False
697
+ # Initialize weights and apply final processing
698
+ self.post_init()
699
+
700
+ def forward(
701
+ self,
702
+ input_ids: Optional[torch.Tensor] = None,
703
+ attention_mask: Optional[torch.Tensor] = None,
704
+ head_mask: Optional[torch.Tensor] = None,
705
+ inputs_embeds: Optional[torch.Tensor] = None,
706
+ output_attentions: Optional[bool] = None,
707
+ output_hidden_states: Optional[bool] = None,
708
+ return_dict: Optional[bool] = None,
709
+ ):
710
+ r"""
711
+ Args:
712
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
713
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
714
+ provide it.
715
+
716
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
717
+ [`PreTrainedTokenizer.__call__`] for details.
718
+
719
+ [What are input IDs?](../glossary#input-ids)
720
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
721
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
722
+
723
+ - 1 for tokens that are **not masked**,
724
+ - 0 for tokens that are **masked**.
725
+
726
+ [What are attention masks?](../glossary#attention-mask)
727
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
728
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
729
+
730
+ - 1 indicates the head is **not masked**,
731
+ - 0 indicates the head is **masked**.
732
+
733
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
734
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
735
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
736
+ than the model's internal embedding lookup matrix.
737
+ output_attentions (`bool`, *optional*):
738
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
739
+ returned tensors for more detail.
740
+ output_hidden_states (`bool`, *optional*):
741
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
742
+ for more detail.
743
+ return_dict (`bool`, *optional*):
744
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
745
+ """
746
+ output_attentions = (
747
+ output_attentions
748
+ if output_attentions is not None
749
+ else self.config.output_attentions
750
+ )
751
+ output_hidden_states = (
752
+ output_hidden_states
753
+ if output_hidden_states is not None
754
+ else self.config.output_hidden_states
755
+ )
756
+ return_dict = (
757
+ return_dict if return_dict is not None else self.config.use_return_dict
758
+ )
759
+
760
+ # retrieve input_ids and inputs_embeds
761
+ if input_ids is not None and inputs_embeds is not None:
762
+ raise ValueError(
763
+ "You cannot specify both input_ids and inputs_embeds at the same time"
764
+ )
765
+ elif input_ids is not None:
766
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
767
+ input_shape = input_ids.size()
768
+ input_ids = input_ids.view(-1, input_shape[-1])
769
+ elif inputs_embeds is not None:
770
+ input_shape = inputs_embeds.size()[:-1]
771
+ else:
772
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
773
+
774
+ if inputs_embeds is None:
775
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
776
+
777
+ embed_pos = self.embed_positions(input_ids, inputs_embeds)
778
+ embed_pos = embed_pos.to(inputs_embeds.device)
779
+
780
+ hidden_states = inputs_embeds + embed_pos
781
+ if self.layernorm_embedding is not None:
782
+ hidden_states = self.layernorm_embedding(hidden_states)
783
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
784
+
785
+ # expand attention_mask
786
+ if attention_mask is not None:
787
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
788
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
789
+
790
+ encoder_states = () if output_hidden_states else None
791
+ all_attentions = () if output_attentions else None
792
+
793
+ # check if head_mask has a correct number of layers specified if desired
794
+ if head_mask is not None:
795
+ if head_mask.size()[0] != len(self.layers):
796
+ raise ValueError(
797
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
798
+ f" {head_mask.size()[0]}."
799
+ )
800
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
801
+
802
+ for idx, encoder_layer in enumerate(self.layers):
803
+ if output_hidden_states:
804
+ encoder_states = encoder_states + (hidden_states,)
805
+
806
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
807
+ dropout_probability = torch.rand([])
808
+
809
+ skip_the_layer = (
810
+ True
811
+ if self.training and (dropout_probability < self.layerdrop)
812
+ else False
813
+ )
814
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
815
+ # under deepspeed zero3 all gpus must run in sync
816
+
817
+ if self.gradient_checkpointing and self.training:
818
+ # create gradient checkpointing function
819
+ def create_custom_forward(module):
820
+ def custom_forward(*inputs):
821
+ return module(*inputs, output_attentions)
822
+
823
+ return custom_forward
824
+
825
+ layer_outputs = torch.utils.checkpoint.checkpoint(
826
+ create_custom_forward(encoder_layer),
827
+ hidden_states,
828
+ attention_mask,
829
+ (head_mask[idx] if head_mask is not None else None),
830
+ )
831
+ else:
832
+ layer_outputs = encoder_layer(
833
+ hidden_states,
834
+ attention_mask,
835
+ layer_head_mask=(
836
+ head_mask[idx] if head_mask is not None else None
837
+ ),
838
+ output_attentions=output_attentions,
839
+ )
840
+
841
+ hidden_states = layer_outputs[0]
842
+
843
+ if skip_the_layer:
844
+ layer_outputs = (None, None)
845
+
846
+ if output_attentions:
847
+ all_attentions = all_attentions + (layer_outputs[1],)
848
+
849
+ if self.layer_norm is not None:
850
+ hidden_states = self.layer_norm(hidden_states)
851
+
852
+ if output_hidden_states:
853
+ encoder_states = encoder_states + (hidden_states,)
854
+
855
+ if not return_dict:
856
+ return tuple(
857
+ v
858
+ for v in [hidden_states, encoder_states, all_attentions]
859
+ if v is not None
860
+ )
861
+ return BaseModelOutput(
862
+ last_hidden_state=hidden_states,
863
+ hidden_states=encoder_states,
864
+ attentions=all_attentions,
865
+ )
866
+
867
+
868
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->IndicTrans
869
+ class IndicTransDecoder(IndicTransPreTrainedModel):
870
+ """
871
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
872
+
873
+ Args:
874
+ config: IndicTransConfig
875
+ embed_tokens (nn.Embedding): output embedding
876
+ """
877
+
878
+ def __init__(
879
+ self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
880
+ ):
881
+ super().__init__(config)
882
+ self.dropout = config.dropout
883
+ self.layerdrop = config.decoder_layerdrop
884
+
885
+ embed_dim = config.encoder_embed_dim
886
+ self.padding_idx = config.pad_token_id
887
+ self.max_target_positions = config.max_target_positions
888
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
889
+
890
+ self.embed_tokens = nn.Embedding(
891
+ config.decoder_vocab_size, embed_dim, self.padding_idx
892
+ )
893
+
894
+ if embed_tokens is not None:
895
+ self.embed_tokens.weight = embed_tokens.weight
896
+
897
+ self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
898
+ config.max_target_positions,
899
+ embed_dim,
900
+ self.padding_idx,
901
+ )
902
+ self.layers = nn.ModuleList(
903
+ [IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
904
+ )
905
+ self.layer_norm = (
906
+ nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
907
+ )
908
+ self.layernorm_embedding = (
909
+ nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
910
+ )
911
+
912
+ self.gradient_checkpointing = False
913
+ # Initialize weights and apply final processing
914
+ self.post_init()
915
+
916
+ def forward(
917
+ self,
918
+ input_ids: Optional[torch.Tensor] = None,
919
+ attention_mask: Optional[torch.Tensor] = None,
920
+ encoder_hidden_states: Optional[torch.Tensor] = None,
921
+ encoder_attention_mask: Optional[torch.Tensor] = None,
922
+ head_mask: Optional[torch.Tensor] = None,
923
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
924
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
925
+ inputs_embeds: Optional[torch.Tensor] = None,
926
+ use_cache: Optional[bool] = None,
927
+ output_attentions: Optional[bool] = None,
928
+ output_hidden_states: Optional[bool] = None,
929
+ return_dict: Optional[bool] = None,
930
+ ):
931
+ r"""
932
+ Args:
933
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
934
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
935
+ provide it.
936
+
937
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
938
+ [`PreTrainedTokenizer.__call__`] for details.
939
+
940
+ [What are input IDs?](../glossary#input-ids)
941
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
942
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
943
+
944
+ - 1 for tokens that are **not masked**,
945
+ - 0 for tokens that are **masked**.
946
+
947
+ [What are attention masks?](../glossary#attention-mask)
948
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
949
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
950
+ of the decoder.
951
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
952
+ Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
953
+ selected in `[0, 1]`:
954
+
955
+ - 1 for tokens that are **not masked**,
956
+ - 0 for tokens that are **masked**.
957
+
958
+ [What are attention masks?](../glossary#attention-mask)
959
+ head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
960
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
961
+
962
+ - 1 indicates the head is **not masked**,
963
+ - 0 indicates the head is **masked**.
964
+
965
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
966
+ Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
967
+ cross-attention on hidden heads. Mask values selected in `[0, 1]`:
968
+
969
+ - 1 indicates the head is **not masked**,
970
+ - 0 indicates the head is **masked**.
971
+
972
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
973
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
974
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
975
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
976
+
977
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
978
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
979
+
980
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
981
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
982
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
983
+ shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
984
+ `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
985
+ control over how to convert `input_ids` indices into associated vectors than the model's internal
986
+ embedding lookup matrix.
987
+ output_attentions (`bool`, *optional*):
988
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
989
+ returned tensors for more detail.
990
+ output_hidden_states (`bool`, *optional*):
991
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
992
+ for more detail.
993
+ return_dict (`bool`, *optional*):
994
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
995
+ """
996
+ output_attentions = (
997
+ output_attentions
998
+ if output_attentions is not None
999
+ else self.config.output_attentions
1000
+ )
1001
+ output_hidden_states = (
1002
+ output_hidden_states
1003
+ if output_hidden_states is not None
1004
+ else self.config.output_hidden_states
1005
+ )
1006
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1007
+ return_dict = (
1008
+ return_dict if return_dict is not None else self.config.use_return_dict
1009
+ )
1010
+
1011
+ # retrieve input_ids and inputs_embeds
1012
+ if input_ids is not None and inputs_embeds is not None:
1013
+ raise ValueError(
1014
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
1015
+ )
1016
+ elif input_ids is not None:
1017
+ input_shape = input_ids.size()
1018
+ input_ids = input_ids.view(-1, input_shape[-1])
1019
+ elif inputs_embeds is not None:
1020
+ input_shape = inputs_embeds.size()[:-1]
1021
+ else:
1022
+ raise ValueError(
1023
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
1024
+ )
1025
+
1026
+ # past_key_values_length
1027
+ past_key_values_length = (
1028
+ past_key_values[0][0].shape[2] if past_key_values is not None else 0
1029
+ )
1030
+
1031
+ if inputs_embeds is None:
1032
+ inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
1033
+
1034
+ # create causal mask
1035
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1036
+ combined_attention_mask = None
1037
+ if input_shape[-1] > 1:
1038
+ combined_attention_mask = _make_causal_mask(
1039
+ input_shape,
1040
+ inputs_embeds.dtype,
1041
+ device=inputs_embeds.device,
1042
+ past_key_values_length=past_key_values_length,
1043
+ )
1044
+
1045
+ if attention_mask is not None and combined_attention_mask is not None:
1046
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1047
+ combined_attention_mask = combined_attention_mask + _expand_mask(
1048
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1049
+ )
1050
+
1051
+ # expand encoder attention mask
1052
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
1053
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1054
+ encoder_attention_mask = _expand_mask(
1055
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
1056
+ )
1057
+
1058
+ # embed positions
1059
+ positions = self.embed_positions(
1060
+ input_ids, inputs_embeds, past_key_values_length
1061
+ )
1062
+ positions = positions.to(inputs_embeds.device)
1063
+
1064
+ hidden_states = inputs_embeds + positions
1065
+ if self.layernorm_embedding is not None:
1066
+ hidden_states = self.layernorm_embedding(hidden_states)
1067
+
1068
+ hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
1069
+
1070
+ if self.gradient_checkpointing and self.training:
1071
+ if use_cache:
1072
+ logger.warning_once(
1073
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting"
1074
+ " `use_cache=False`..."
1075
+ )
1076
+ use_cache = False
1077
+
1078
+ # decoder layers
1079
+ all_hidden_states = () if output_hidden_states else None
1080
+ all_self_attns = () if output_attentions else None
1081
+ all_cross_attentions = () if output_attentions else None
1082
+ next_decoder_cache = () if use_cache else None
1083
+
1084
+ # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1085
+ for attn_mask, mask_name in zip(
1086
+ [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
1087
+ ):
1088
+ if attn_mask is not None:
1089
+ if attn_mask.size()[0] != len(self.layers):
1090
+ raise ValueError(
1091
+ f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
1092
+ f" {head_mask.size()[0]}."
1093
+ )
1094
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
1095
+
1096
+ for idx, decoder_layer in enumerate(self.layers):
1097
+ if output_hidden_states:
1098
+ all_hidden_states += (hidden_states,)
1099
+
1100
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1101
+ dropout_probability = torch.rand([])
1102
+
1103
+ skip_the_layer = (
1104
+ True
1105
+ if self.training and (dropout_probability < self.layerdrop)
1106
+ else False
1107
+ )
1108
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
1109
+ # under deepspeed zero3 all gpus must run in sync
1110
+
1111
+ past_key_value = (
1112
+ past_key_values[idx] if past_key_values is not None else None
1113
+ )
1114
+
1115
+ if self.gradient_checkpointing and self.training:
1116
+
1117
+ def create_custom_forward(module):
1118
+ def custom_forward(*inputs):
1119
+ # None for past_key_value
1120
+ return module(*inputs, output_attentions, use_cache)
1121
+
1122
+ return custom_forward
1123
+
1124
+ layer_outputs = torch.utils.checkpoint.checkpoint(
1125
+ create_custom_forward(decoder_layer),
1126
+ hidden_states,
1127
+ combined_attention_mask,
1128
+ encoder_hidden_states,
1129
+ encoder_attention_mask,
1130
+ head_mask[idx] if head_mask is not None else None,
1131
+ cross_attn_head_mask[idx]
1132
+ if cross_attn_head_mask is not None
1133
+ else None,
1134
+ None,
1135
+ )
1136
+ else:
1137
+ layer_outputs = decoder_layer(
1138
+ hidden_states,
1139
+ attention_mask=combined_attention_mask,
1140
+ encoder_hidden_states=encoder_hidden_states,
1141
+ encoder_attention_mask=encoder_attention_mask,
1142
+ layer_head_mask=(
1143
+ head_mask[idx] if head_mask is not None else None
1144
+ ),
1145
+ cross_attn_layer_head_mask=(
1146
+ cross_attn_head_mask[idx]
1147
+ if cross_attn_head_mask is not None
1148
+ else None
1149
+ ),
1150
+ past_key_value=past_key_value,
1151
+ output_attentions=output_attentions,
1152
+ use_cache=use_cache,
1153
+ )
1154
+
1155
+ hidden_states = layer_outputs[0]
1156
+
1157
+ if skip_the_layer:
1158
+ continue
1159
+
1160
+ if use_cache:
1161
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1162
+
1163
+ if output_attentions:
1164
+ all_self_attns += (layer_outputs[1],)
1165
+ all_cross_attentions += (layer_outputs[2],)
1166
+
1167
+ if self.layer_norm is not None:
1168
+ hidden_states = self.layer_norm(hidden_states)
1169
+
1170
+ # add hidden states from the last decoder layer
1171
+ if output_hidden_states:
1172
+ all_hidden_states += (hidden_states,)
1173
+
1174
+ next_cache = next_decoder_cache if use_cache else None
1175
+ if not return_dict:
1176
+ return tuple(
1177
+ v
1178
+ for v in [
1179
+ hidden_states,
1180
+ next_cache,
1181
+ all_hidden_states,
1182
+ all_self_attns,
1183
+ all_cross_attentions,
1184
+ ]
1185
+ if v is not None
1186
+ )
1187
+ return BaseModelOutputWithPastAndCrossAttentions(
1188
+ last_hidden_state=hidden_states,
1189
+ past_key_values=next_cache,
1190
+ hidden_states=all_hidden_states,
1191
+ attentions=all_self_attns,
1192
+ cross_attentions=all_cross_attentions,
1193
+ )
1194
+
1195
+
1196
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTrans
1197
+ class IndicTransModel(IndicTransPreTrainedModel):
1198
+ _tied_weights_keys = None
1199
+
1200
+ def __init__(self, config: IndicTransConfig):
1201
+ super().__init__(config)
1202
+
1203
+ self.encoder = IndicTransEncoder(config)
1204
+ self.decoder = IndicTransDecoder(config)
1205
+
1206
+ # Initialize weights and apply final processing
1207
+ self.post_init()
1208
+
1209
+ def get_encoder(self):
1210
+ return self.encoder
1211
+
1212
+ def get_decoder(self):
1213
+ return self.decoder
1214
+
1215
+ def forward(
1216
+ self,
1217
+ input_ids: Optional[torch.LongTensor] = None,
1218
+ attention_mask: Optional[torch.Tensor] = None,
1219
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1220
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1221
+ head_mask: Optional[torch.Tensor] = None,
1222
+ decoder_head_mask: Optional[torch.Tensor] = None,
1223
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1224
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1225
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1226
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1227
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1228
+ use_cache: Optional[bool] = None,
1229
+ output_attentions: Optional[bool] = None,
1230
+ output_hidden_states: Optional[bool] = None,
1231
+ return_dict: Optional[bool] = None,
1232
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
1233
+ output_attentions = (
1234
+ output_attentions
1235
+ if output_attentions is not None
1236
+ else self.config.output_attentions
1237
+ )
1238
+ output_hidden_states = (
1239
+ output_hidden_states
1240
+ if output_hidden_states is not None
1241
+ else self.config.output_hidden_states
1242
+ )
1243
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1244
+ return_dict = (
1245
+ return_dict if return_dict is not None else self.config.use_return_dict
1246
+ )
1247
+
1248
+ if encoder_outputs is None:
1249
+ encoder_outputs = self.encoder(
1250
+ input_ids=input_ids,
1251
+ attention_mask=attention_mask,
1252
+ head_mask=head_mask,
1253
+ inputs_embeds=inputs_embeds,
1254
+ output_attentions=output_attentions,
1255
+ output_hidden_states=output_hidden_states,
1256
+ return_dict=return_dict,
1257
+ )
1258
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1259
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1260
+ encoder_outputs = BaseModelOutput(
1261
+ last_hidden_state=encoder_outputs[0],
1262
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1263
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1264
+ )
1265
+
1266
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1267
+ decoder_outputs = self.decoder(
1268
+ input_ids=decoder_input_ids,
1269
+ attention_mask=decoder_attention_mask,
1270
+ encoder_hidden_states=encoder_outputs[0],
1271
+ encoder_attention_mask=attention_mask,
1272
+ head_mask=decoder_head_mask,
1273
+ cross_attn_head_mask=cross_attn_head_mask,
1274
+ past_key_values=past_key_values,
1275
+ inputs_embeds=decoder_inputs_embeds,
1276
+ use_cache=use_cache,
1277
+ output_attentions=output_attentions,
1278
+ output_hidden_states=output_hidden_states,
1279
+ return_dict=return_dict,
1280
+ )
1281
+
1282
+ if not return_dict:
1283
+ return decoder_outputs + encoder_outputs
1284
+
1285
+ return Seq2SeqModelOutput(
1286
+ last_hidden_state=decoder_outputs.last_hidden_state,
1287
+ past_key_values=decoder_outputs.past_key_values,
1288
+ decoder_hidden_states=decoder_outputs.hidden_states,
1289
+ decoder_attentions=decoder_outputs.attentions,
1290
+ cross_attentions=decoder_outputs.cross_attentions,
1291
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1292
+ encoder_hidden_states=encoder_outputs.hidden_states,
1293
+ encoder_attentions=encoder_outputs.attentions,
1294
+ )
1295
+
1296
+
1297
+ # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
1298
+ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
1299
+ base_model_prefix = "model"
1300
+ _tied_weights_keys = None
1301
+
1302
+ def __init__(self, config: IndicTransConfig):
1303
+ super().__init__(config)
1304
+ self.model = IndicTransModel(config)
1305
+ self.lm_head = nn.Linear(
1306
+ config.decoder_embed_dim, config.decoder_vocab_size, bias=False
1307
+ )
1308
+
1309
+ if config.share_decoder_input_output_embed:
1310
+ self.lm_head.weight = self.model.decoder.embed_tokens.weight
1311
+
1312
+ self.post_init()
1313
+
1314
+ def tie_weights(self):
1315
+ pass
1316
+
1317
+ def get_encoder(self):
1318
+ return self.model.get_encoder()
1319
+
1320
+ def get_decoder(self):
1321
+ return self.model.get_decoder()
1322
+
1323
+ def get_output_embeddings(self):
1324
+ return self.lm_head
1325
+
1326
+ def set_output_embeddings(self, new_embeddings):
1327
+ self.lm_head = new_embeddings
1328
+
1329
+ def forward(
1330
+ self,
1331
+ input_ids: Optional[torch.LongTensor] = None,
1332
+ attention_mask: Optional[torch.Tensor] = None,
1333
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1334
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
1335
+ head_mask: Optional[torch.Tensor] = None,
1336
+ decoder_head_mask: Optional[torch.Tensor] = None,
1337
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1338
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1339
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1340
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1341
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1342
+ labels: Optional[torch.LongTensor] = None,
1343
+ use_cache: Optional[bool] = None,
1344
+ output_attentions: Optional[bool] = None,
1345
+ output_hidden_states: Optional[bool] = None,
1346
+ return_dict: Optional[bool] = None,
1347
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
1348
+ r"""
1349
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1350
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1351
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1352
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1353
+
1354
+ Returns:
1355
+ """
1356
+ return_dict = (
1357
+ return_dict if return_dict is not None else self.config.use_return_dict
1358
+ )
1359
+
1360
+ if labels is not None:
1361
+ if decoder_input_ids is None:
1362
+ decoder_input_ids = shift_tokens_right(
1363
+ labels, self.config.pad_token_id, self.config.decoder_start_token_id
1364
+ )
1365
+
1366
+ outputs = self.model(
1367
+ input_ids,
1368
+ attention_mask=attention_mask,
1369
+ decoder_input_ids=decoder_input_ids,
1370
+ encoder_outputs=encoder_outputs,
1371
+ decoder_attention_mask=decoder_attention_mask,
1372
+ head_mask=head_mask,
1373
+ decoder_head_mask=decoder_head_mask,
1374
+ cross_attn_head_mask=cross_attn_head_mask,
1375
+ past_key_values=past_key_values,
1376
+ inputs_embeds=inputs_embeds,
1377
+ decoder_inputs_embeds=decoder_inputs_embeds,
1378
+ use_cache=use_cache,
1379
+ output_attentions=output_attentions,
1380
+ output_hidden_states=output_hidden_states,
1381
+ return_dict=return_dict,
1382
+ )
1383
+ lm_logits = self.lm_head(outputs[0])
1384
+
1385
+ masked_lm_loss = None
1386
+ if labels is not None:
1387
+ # move labels to the correct device to enable PP
1388
+ labels = labels.to(lm_logits.device)
1389
+ loss_fct = nn.CrossEntropyLoss()
1390
+ masked_lm_loss = loss_fct(
1391
+ lm_logits.view(-1, self.config.decoder_vocab_size), labels.view(-1)
1392
+ )
1393
+
1394
+ if not return_dict:
1395
+ output = (lm_logits,) + outputs[1:]
1396
+ return (
1397
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1398
+ )
1399
+
1400
+ return Seq2SeqLMOutput(
1401
+ loss=masked_lm_loss,
1402
+ logits=lm_logits,
1403
+ past_key_values=outputs.past_key_values,
1404
+ decoder_hidden_states=outputs.decoder_hidden_states,
1405
+ decoder_attentions=outputs.decoder_attentions,
1406
+ cross_attentions=outputs.cross_attentions,
1407
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1408
+ encoder_hidden_states=outputs.encoder_hidden_states,
1409
+ encoder_attentions=outputs.encoder_attentions,
1410
+ )
1411
+
1412
+ def prepare_inputs_for_generation(
1413
+ self,
1414
+ decoder_input_ids,
1415
+ past_key_values=None,
1416
+ attention_mask=None,
1417
+ head_mask=None,
1418
+ decoder_head_mask=None,
1419
+ cross_attn_head_mask=None,
1420
+ use_cache=None,
1421
+ encoder_outputs=None,
1422
+ **kwargs,
1423
+ ):
1424
+ # cut decoder_input_ids if past is used
1425
+ if past_key_values is not None:
1426
+ decoder_input_ids = decoder_input_ids[:, -1:]
1427
+
1428
+ return {
1429
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
1430
+ "encoder_outputs": encoder_outputs,
1431
+ "past_key_values": past_key_values,
1432
+ "decoder_input_ids": decoder_input_ids,
1433
+ "attention_mask": attention_mask,
1434
+ "head_mask": head_mask,
1435
+ "decoder_head_mask": decoder_head_mask,
1436
+ "cross_attn_head_mask": cross_attn_head_mask,
1437
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1438
+ }
1439
+
1440
+ @staticmethod
1441
+ def _reorder_cache(past_key_values, beam_idx):
1442
+ reordered_past = ()
1443
+ for layer_past in past_key_values:
1444
+ reordered_past += (
1445
+ tuple(
1446
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1447
+ ),
1448
+ )
1449
+ return reordered_past
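
A minimal inference sketch, not part of the committed file, showing how the classes above are typically driven once weights are available. The checkpoint path is a placeholder and the `batch` dict is assumed to come from the tokenizer in IndicTransTokenizer/tokenizer.py; both are assumptions rather than an API this repository guarantees.

import torch

# Hypothetical local path; point it at wherever this repository's en-indic
# weights actually live.
model = IndicTransForConditionalGeneration.from_pretrained("path/to/en-indic-checkpoint")
model.eval()

# `batch` is assumed to be a dict holding `input_ids` and `attention_mask`
# tensors of shape (batch_size, seq_len) produced by the repository tokenizer.
with torch.no_grad():
    generated_ids = model.generate(
        **batch,
        num_beams=5,      # beam search exercises _reorder_cache defined above
        max_length=256,
        use_cache=True,   # cached key/values are trimmed in prepare_inputs_for_generation
    )
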
sample.srt ADDED
@@ -0,0 +1,699 @@
1
+ 1
2
+ 00:00:00,000 --> 00:00:01,845
3
+ Sadhguru: If you activate this dimension of energy,
4
+ 2
5
+ 00:00:01,845 --> 00:00:05,013
6
+ other dimensions of life will open up.
7
+ 3
8
+ 00:00:05,013 --> 00:00:07,545
9
+ One thing is,
10
+ 4
11
+ 00:00:07,545 --> 00:00:10,377
12
+ are you ready for those dimensions?
13
+ 5
14
+ 00:00:10,377 --> 00:00:11,280
15
+ Kundalini Yoga,
16
+ 6
17
+ 00:00:11,280 --> 00:00:12,318
18
+ in its essence,
19
+ 7
20
+ 00:00:12,318 --> 00:00:16,637
21
+ is the most dangerous form of yoga.
22
+ 8
23
+ 00:00:16,637 --> 00:00:18,008
24
+ I'm saying dangerous,
25
+ 9
26
+ 00:00:18,008 --> 00:00:20,538
27
+ because it's the most potent also.
28
+ 10
29
+ 00:00:20,538 --> 00:00:22,890
30
+ If you have to jump into an abyss,
31
+ 11
32
+ 00:00:22,890 --> 00:00:28,613
33
+ you should be insane or you should have enormous trust in somebody.
34
+ 12
35
+ 00:00:28,613 --> 00:00:37,465
36
+ So what is Kundalini?
37
+ 13
38
+ 00:00:37,465 --> 00:00:41,788
39
+ Right now,
40
+ 14
41
+ 00:00:41,788 --> 00:00:43,100
42
+ I'm speaking,
43
+ 15
44
+ 00:00:43,100 --> 00:00:46,495
45
+ this is Kundalini.
46
+ 16
47
+ 00:00:46,495 --> 00:00:47,843
48
+ You are alert and listening.
49
+ 17
50
+ 00:00:47,843 --> 00:00:49,483
51
+ If you are alert and listening,
52
+ 18
53
+ 00:00:49,483 --> 00:00:51,278
54
+ that is Kundalini
55
+ 19
56
+ 00:00:51,278 --> 00:00:54,528
57
+ A flower is blossoming,
58
+ 20
59
+ 00:00:54,528 --> 00:00:55,680
60
+ that is Kundalini.
61
+ 21
62
+ 00:00:55,680 --> 00:00:57,645
63
+ A dog is barking,
64
+ 22
65
+ 00:00:57,645 --> 00:00:59,480
66
+ that is also Kundalini,
67
+ 23
68
+ 00:00:59,480 --> 00:01:00,553
69
+ or in other words,
70
+ 24
71
+ 00:01:00,553 --> 00:01:03,545
72
+ the fundamental life force in the existence,
73
+ 25
74
+ 00:01:03,545 --> 00:01:04,880
75
+ we call it Kundalini.
76
+ 26
77
+ 00:01:04,880 --> 00:01:07,513
78
+ Now,
79
+ 27
80
+ 00:01:07,513 --> 00:01:08,555
81
+ within the system,
82
+ 28
83
+ 00:01:08,555 --> 00:01:10,377
84
+ within the human system,
85
+ 29
86
+ 00:01:10,377 --> 00:01:14,065
87
+ if you look at this as a kind of a life package,
88
+ 30
89
+ 00:01:14,065 --> 00:01:15,287
90
+ it's a piece of life.
91
+ 31
92
+ 00:01:15,287 --> 00:01:18,540
93
+ This piece of life is packed in a certain way,
94
+ 32
95
+ 00:01:18,540 --> 00:01:20,725
96
+ with layers of this energy.
97
+ 33
98
+ 00:01:20,725 --> 00:01:24,887
99
+ One dimension of energy comes alive immediately,
100
+ 34
101
+ 00:01:24,887 --> 00:01:28,667
102
+ because that is necessary for your survival process.
103
+ 35
104
+ 00:01:28,667 --> 00:01:37,369
105
+ The other dimensions of energy will not come alive unless you do something about it.
106
+ 36
107
+ 00:01:37,369 --> 00:01:40,203
108
+ Unless you're aware of it and activate it in a certain way,
109
+ 37
110
+ 00:01:40,203 --> 00:01:42,093
111
+ they do not come into existence.
112
+ 38
113
+ 00:01:42,093 --> 00:01:43,980
114
+ They remain dormant.
115
+ 39
116
+ 00:01:43,980 --> 00:01:49,932
117
+ The dormant energy is way bigger than the energy that is in use right now.
118
+ 40
119
+ 00:01:49,932 --> 00:01:53,582
120
+ To take care of your survival process,
121
+ 41
122
+ 00:01:53,582 --> 00:01:57,169
123
+ to live a physical life completely,
124
+ 42
125
+ 00:01:57,169 --> 00:02:00,902
126
+ to live a complete physical and intellectual life,
127
+ 43
128
+ 00:02:00,902 --> 00:02:08,451
129
+ you need to activate only about twenty-one of your chakras.
130
+ 44
131
+ 00:02:08,451 --> 00:02:12,246
132
+ Out of this one hundred and fourteen,
133
+ 45
134
+ 00:02:12,246 --> 00:02:14,444
135
+ if about twenty-one of them are on,
136
+ 46
137
+ 00:02:14,444 --> 00:02:16,248
138
+ you will live a complete life,
139
+ 47
140
+ 00:02:16,248 --> 00:02:18,322
141
+ you will not feel any inadequacy.
142
+ 48
143
+ 00:02:18,322 --> 00:02:21,267
144
+ You will live a complete physical life.
145
+ 49
146
+ 00:02:21,267 --> 00:02:23,112
147
+ There'll be no problem with your life,
148
+ 50
149
+ 00:02:23,112 --> 00:02:25,249
150
+ you will think you're a great success,
151
+ 51
152
+ 00:02:25,249 --> 00:02:29,458
153
+ but you're only twenty-one,
154
+ 52
155
+ 00:02:29,458 --> 00:02:34,000
156
+ that is less than twenty-one percent out of one hundred and fourteen,
157
+ 53
158
+ 00:02:34,000 --> 00:02:35,949
159
+ less than twenty percent.
160
+ 54
161
+ 00:02:35,949 --> 00:02:37,733
162
+ At let… less than twenty percent,
163
+ 55
164
+ 00:02:37,733 --> 00:02:41,984
165
+ you will feel like a complete life without any inadequacies.
166
+ 56
167
+ 00:02:41,984 --> 00:02:44,618
168
+ The remaining percentage of life,
169
+ 57
170
+ 00:02:44,618 --> 00:02:46,339
171
+ what is it about?
172
+ 58
173
+ 00:02:46,339 --> 00:02:50,632
174
+ It is not even needed if your intention is just to live well.
175
+ 59
176
+ 00:02:50,632 --> 00:02:53,846
177
+ If you activate this dimension of energy,
178
+ 60
179
+ 00:02:53,846 --> 00:02:57,143
180
+ other dimensions of life will open up.
181
+ 61
182
+ 00:02:57,143 --> 00:02:58,865
183
+ One thing is,
184
+ 62
185
+ 00:02:58,865 --> 00:03:00,980
186
+ “Are you ready for those dimensions?”
187
+ 63
188
+ 00:03:00,980 --> 00:03:05,003
189
+ The question is not about whether it's good or bad.
190
+ 64
191
+ 00:03:05,003 --> 00:03:06,890
192
+ The question is just about,
193
+ 65
194
+ 00:03:06,890 --> 00:03:08,902
195
+ “Are you ready for it?”
196
+ 66
197
+ 00:03:08,902 --> 00:03:13,920
198
+ Because even if the best things in life come to your life,
199
+ 67
200
+ 00:03:13,920 --> 00:03:16,803
201
+ when you are not ready for it,
202
+ 68
203
+ 00:03:16,803 --> 00:03:19,084
204
+ it will not be a good thing for you.
205
+ 69
206
+ 00:03:19,084 --> 00:03:22,713
207
+ In your experience it will not be a good thing
208
+ 70
209
+ 00:03:22,713 --> 00:03:25,223
210
+ if something came to you when you're not ready for it,
211
+ 71
212
+ 00:03:25,223 --> 00:03:26,467
213
+ isn't it so?
214
+ 72
215
+ 00:03:26,467 --> 00:03:28,250
216
+ Even if it's a greatest thing,
217
+ 73
218
+ 00:03:28,250 --> 00:03:29,702
219
+ it may be the greatest thing,
220
+ 74
221
+ 00:03:29,702 --> 00:03:33,539
222
+ but it came to you when you are not ready for it.
223
+ 75
224
+ 00:03:33,539 --> 00:03:34,969
225
+ Then it is not a good thing,
226
+ 76
227
+ 00:03:34,969 --> 00:03:36,525
228
+ isn't it?
229
+ 77
230
+ 00:03:36,525 --> 00:03:38,018
231
+ So are you ready for it,
232
+ 78
233
+ 00:03:38,018 --> 00:03:39,034
234
+ is the first question.
235
+ 79
236
+ 00:03:39,034 --> 00:03:40,071
237
+ If you're ready for it,
238
+ 80
239
+ 00:03:40,071 --> 00:03:41,709
240
+ what can we do for it?
241
+ 81
242
+ 00:03:41,709 --> 00:03:44,073
243
+ What can we do to activate it?
244
+ 82
245
+ 00:03:44,073 --> 00:03:46,541
246
+ The various ways of doing this,
247
+ 83
248
+ 00:03:46,541 --> 00:03:47,143
249
+ many,
250
+ 84
251
+ 00:03:47,143 --> 00:03:48,200
252
+ many ways,
253
+ 85
254
+ 00:03:48,200 --> 00:03:50,938
255
+ but the Kundalini Yoga…
256
+ 86
257
+ 00:03:50,938 --> 00:03:56,558
258
+ are people familiar with Kundalini Yoga practicing or…?
259
+ 87
260
+ 00:03:56,558 --> 00:04:00,083
261
+ Okay.
262
+ 88
263
+ 00:04:00,083 --> 00:04:01,535
264
+ Kundalini Yoga.
265
+ 89
266
+ 00:04:01,535 --> 00:04:04,708
267
+ I…I'm not making a comment about anybody,
268
+ 90
269
+ 00:04:04,708 --> 00:04:07,279
270
+ okay?
271
+ 91
272
+ 00:04:07,279 --> 00:04:08,378
273
+ Kundalini Yoga,
274
+ 92
275
+ 00:04:08,378 --> 00:04:13,687
276
+ in its essence is the most dangerous form of yoga.
277
+ 93
278
+ 00:04:13,687 --> 00:04:15,471
279
+ I'm saying dangerous,
280
+ 94
281
+ 00:04:15,471 --> 00:04:19,203
282
+ because it's the most potent also.
283
+ 95
284
+ 00:04:19,203 --> 00:04:27,561
285
+ What is most potent is always the most dangerous if improperly handled.
286
+ 96
287
+ 00:04:27,561 --> 00:04:31,169
288
+ There are various kinds of energy in the world right now,
289
+ 97
290
+ 00:04:31,169 --> 00:04:33,886
291
+ even the electricity is being manufactured in… I mean,
292
+ 98
293
+ 00:04:33,886 --> 00:04:36,126
294
+ produced in so many different ways.
295
+ 99
296
+ 00:04:36,126 --> 00:04:40,170
297
+ One of the ways that we do it is through nuclear reactions,
298
+ 100
299
+ 00:04:40,170 --> 00:04:42,907
300
+ nuclear reactors rather..
301
+ 101
302
+ 00:04:42,907 --> 00:04:46,889
303
+ It is the most efficient way of producing energy that we know right now,
304
+ 102
305
+ 00:04:46,889 --> 00:04:49,211
306
+ but it is also the most dangerous way,
307
+ 103
308
+ 00:04:49,211 --> 00:04:51,223
309
+ isn't it?
310
+ 104
311
+ 00:04:51,223 --> 00:04:52,965
312
+ When things go wrong,
313
+ 105
314
+ 00:04:52,965 --> 00:04:55,744
315
+ they go seriously wrong.
316
+ 106
317
+ 00:04:55,744 --> 00:04:57,299
318
+ When they're going right,
319
+ 107
320
+ 00:04:57,299 --> 00:05:04,288
321
+ it is the easiest and the best way to produce energy on the planet is nuclear energy actually.
322
+ 108
323
+ 00:05:04,288 --> 00:05:06,258
324
+ But when it goes bad,
325
+ 109
326
+ 00:05:06,258 --> 00:05:07,502
327
+ it goes bad,
328
+ 110
329
+ 00:05:07,502 --> 00:05:08,539
330
+ really bad,
331
+ 111
332
+ 00:05:08,539 --> 00:05:11,256
333
+ like in ways that you can't fix it.
334
+ 112
335
+ 00:05:11,256 --> 00:05:11,691
336
+ So,
337
+ 113
338
+ 00:05:11,691 --> 00:05:13,454
339
+ similarly with Kundalini Yoga,
340
+ 114
341
+ 00:05:13,454 --> 00:05:17,353
342
+ it is the most potent and it is the most dangerous.
343
+ 115
344
+ 00:05:17,353 --> 00:05:21,106
345
+ Without the necessary preparation and guidance,
346
+ 116
347
+ 00:05:21,106 --> 00:05:23,387
348
+ without expert guidance,
349
+ 117
350
+ 00:05:23,387 --> 00:05:25,793
351
+ constant guidance and observation,
352
+ 118
353
+ 00:05:25,793 --> 00:05:28,219
354
+ nobody should ever attempt it.
355
+ 119
356
+ 00:05:28,219 --> 00:05:29,588
357
+ But the problem is,
358
+ 120
359
+ 00:05:29,588 --> 00:05:33,611
360
+ books have been written about it and everybody wants to do the highest yoga.
361
+ 121
362
+ 00:05:33,611 --> 00:05:35,830
363
+ Nobody wants to start with A,
364
+ 122
365
+ 00:05:35,830 --> 00:05:38,650
366
+ everybody wants to start the alphabet with Z.
367
+ 123
368
+ 00:05:38,650 --> 00:05:42,383
369
+ This attitude itself is dangerous.
370
+ 124
371
+ 00:05:42,383 --> 00:05:50,326
372
+ What can be a life-transforming force can become a life-destructive force
373
+ 125
374
+ 00:05:50,326 --> 00:05:55,780
375
+ simply because without the necessary commitment and dedication and focus and understanding
376
+ 126
377
+ 00:05:55,780 --> 00:05:57,294
378
+ it is being handled.
379
+ 127
380
+ 00:05:57,294 --> 00:05:58,849
381
+ Anyway,
382
+ 128
383
+ 00:05:58,849 --> 00:06:00,550
384
+ about raising the Kundalini,
385
+ 129
386
+ 00:06:00,550 --> 00:06:02,292
387
+ if the Kundalini rises,
388
+ 130
389
+ 00:06:02,292 --> 00:06:08,534
390
+ the dimensions of your life will change so rapidly that you must be willing
391
+ 131
392
+ 00:06:08,534 --> 00:06:11,562
393
+ to make the outside admin… adjustments equally,
394
+ 132
395
+ 00:06:11,562 --> 00:06:12,806
396
+ quick.
397
+ 133
398
+ 00:06:12,806 --> 00:06:14,320
399
+ Otherwise,
400
+ 134
401
+ 00:06:14,320 --> 00:06:18,426
402
+ things will fall apart in a big way.
403
+ 135
404
+ 00:06:18,426 --> 00:06:20,624
405
+ In the classical yogic traditions,
406
+ 136
407
+ 00:06:20,624 --> 00:06:26,161
408
+ there is a certain type of yoga we teach for people who live in family situations.
409
+ 137
410
+ 00:06:26,161 --> 00:06:30,205
411
+ There is a certain other type of yoga we teach for ascetics.
412
+ 138
413
+ 00:06:30,205 --> 00:06:35,159
414
+ In Isha,
415
+ 139
416
+ 00:06:35,159 --> 00:06:32,486
417
+ we have both the forms,
418
+ 140
419
+ 00:06:32,486 --> 00:06:36,219
420
+ we have ascetic yoga and we have the general yoga.
421
+ 141
422
+ 00:06:36,219 --> 00:06:38,915
423
+ We will never teach you the ascetic form.
424
+ 142
425
+ 00:06:38,915 --> 00:06:41,590
426
+ That is the most pos… potent way to do it.
427
+ 143
428
+ 00:06:41,590 --> 00:06:46,733
429
+ But it will demand a certain dimension of discipline and focus,
430
+ 144
431
+ 00:06:46,733 --> 00:06:50,113
432
+ which your regular lives will not allow.
433
+ 145
434
+ 00:06:50,113 --> 00:06:52,498
435
+ If you do that kind of yoga,
436
+ 146
437
+ 00:06:52,498 --> 00:06:56,273
438
+ it will dismantle your outside life instantly.
439
+ 147
440
+ 00:06:56,273 --> 00:06:59,632
441
+ Now this Yoga is not designed to dismantle your life,
442
+ 148
443
+ 00:06:59,632 --> 00:07:05,874
444
+ this Yoga is designed to make your life happen better.
445
+ 149
446
+ 00:07:05,874 --> 00:07:07,907
447
+ When life happens better,
448
+ 150
449
+ 00:07:07,907 --> 00:07:09,835
450
+ when things happen better,
451
+ 151
452
+ 00:07:09,835 --> 00:07:11,328
453
+ you make more money,
454
+ 152
455
+ 00:07:11,328 --> 00:07:13,112
456
+ your business is going better,
457
+ 153
458
+ 00:07:13,112 --> 00:07:14,854
459
+ your profession is happening better.
460
+ 154
461
+ 00:07:14,854 --> 00:07:19,644
462
+ You're generally unfortunately,
463
+ 155
464
+ 00:07:19,644 --> 00:07:23,833
465
+ you are longing to seek the higher becomes slower.
466
+ 156
467
+ 00:07:23,833 --> 00:07:25,512
468
+ Yes.
469
+ 157
470
+ 00:07:25,512 --> 00:07:27,006
471
+ So in the real sense,
472
+ 158
473
+ 00:07:27,006 --> 00:07:28,873
474
+ it is not the good way to do it.
475
+ 159
476
+ 00:07:28,873 --> 00:07:31,714
477
+ But it's the only way it works in today's world.
478
+ 160
479
+ 00:07:31,714 --> 00:07:34,596
480
+ And it's the only way it works for majority of the people.
481
+ 161
482
+ 00:07:34,596 --> 00:07:36,255
483
+ For a small number of people,
484
+ 162
485
+ 00:07:36,255 --> 00:07:38,018
486
+ we can do it other ways.
487
+ 163
488
+ 00:07:38,018 --> 00:07:43,680
489
+ We can bypass all these things and just do very powerful ways of doing things.
490
+ 164
491
+ 00:07:43,680 --> 00:07:47,329
492
+ But it will dismantle all social structures around them,
493
+ 165
494
+ 00:07:47,329 --> 00:07:49,631
495
+ which is not good for everybody to do.
496
+ 166
497
+ 00:07:49,631 --> 00:07:51,809
498
+ So these are different dimensions.
499
+ 167
500
+ 00:07:51,809 --> 00:07:52,887
501
+ Kundalini Yoga,
502
+ 168
503
+ 00:07:52,887 --> 00:07:54,443
504
+ if it has to be practiced,
505
+ 169
506
+ 00:07:54,443 --> 00:07:57,263
507
+ you must be in a certain kind of atmosphere.
508
+ 170
509
+ 00:07:57,263 --> 00:08:01,618
510
+ You cannot live in social situations and do Kundalini Yoga
511
+ 171
512
+ 00:08:01,618 --> 00:08:02,468
513
+ Otherwise,
514
+ 172
515
+ 00:08:02,468 --> 00:08:03,878
516
+ in the name of Kundalini Yoga,
517
+ 173
518
+ 00:08:03,878 --> 00:08:06,263
519
+ you're doing something simplistic.
520
+ 174
521
+ 00:08:06,263 --> 00:08:07,176
522
+ Otherwise,
523
+ 175
524
+ 00:08:07,176 --> 00:08:10,971
525
+ Kundalini yoga can transform the way you are within days.
526
+ 176
527
+ 00:08:10,971 --> 00:08:16,674
528
+ Suddenly you find you're a stranger in your own home within two days of practice,
529
+ 177
530
+ 00:08:16,674 --> 00:08:20,469
531
+ because it will change everything about you.
532
+ 178
533
+ 00:08:20,469 --> 00:08:20,759
534
+ So,
535
+ 179
536
+ 00:08:20,759 --> 00:08:22,169
537
+ can we raise the Kundalini?
538
+ 180
539
+ 00:08:22,169 --> 00:08:22,480
540
+ Yes,
541
+ 181
542
+ 00:08:22,480 --> 00:08:25,135
543
+ we can.
544
+ 182
545
+ 00:08:25,135 --> 00:08:30,651
546
+ One way is to create a conducive atmosphere so that slowly it rises.
547
+ 183
548
+ 00:08:30,651 --> 00:08:35,110
549
+ The other way is to provoke it in such a way that it rises quickly.
550
+ 184
551
+ 00:08:35,110 --> 00:08:37,225
552
+ If it rises quickly,
553
+ 185
554
+ 00:08:37,225 --> 00:08:39,423
555
+ then everything changes dramatically.
556
+ 186
557
+ 00:08:39,423 --> 00:08:42,223
558
+ If it rises slowly over a period of time,
559
+ 187
560
+ 00:08:42,223 --> 00:08:44,234
561
+ changes will happen slowly,
562
+ 188
563
+ 00:08:44,234 --> 00:08:48,216
564
+ you will be capable of handling these changes over a period of time.
565
+ 189
566
+ 00:08:48,216 --> 00:08:49,958
567
+ But if it happens very quick,
568
+ 190
569
+ 00:08:49,958 --> 00:08:53,193
570
+ then you will not be able to handle the changes,
571
+ 191
572
+ 00:08:53,193 --> 00:08:56,677
573
+ things will look like things are falling apart.
574
+ 192
575
+ 00:08:56,677 --> 00:08:58,958
576
+ So there are different ways of doing this.
577
+ 193
578
+ 00:08:58,958 --> 00:08:59,705
579
+ How many ways?
580
+ 194
581
+ 00:08:59,705 --> 00:09:00,887
582
+ There are too many ways,
583
+ 195
584
+ 00:09:00,887 --> 00:09:03,023
585
+ I will not go into how many ways.
586
+ 196
587
+ 00:09:03,023 --> 00:09:04,972
588
+ There are so many ways of doing it.
589
+ 197
590
+ 00:09:04,972 --> 00:09:05,885
591
+ Essentially,
592
+ 198
593
+ 00:09:05,885 --> 00:09:10,903
594
+ there are one hundred and twelve ways of doing it.
595
+ 199
596
+ 00:09:10,903 --> 00:09:14,595
597
+ There are hundred and twelve ways in which you can take this up,
598
+ 200
599
+ 00:09:14,595 --> 00:09:17,602
600
+ from the base to…
601
+ 201
602
+ 00:09:17,602 --> 00:09:18,099
603
+ Oh!
604
+ 202
605
+ 00:09:18,099 --> 00:09:19,675
606
+ for this you have to know the structure,
607
+ 203
608
+ 00:09:19,675 --> 00:09:22,122
609
+ otherwise it will become very elaborate.
610
+ 204
611
+ 00:09:22,122 --> 00:09:25,606
612
+ Hmm,
613
+ 205
614
+ 00:09:25,606 --> 00:09:28,385
615
+ out of this one hundred and fourteen chakras,
616
+ 206
617
+ 00:09:28,385 --> 00:09:33,093
618
+ there are seven which we recognize as seven dimensions.
619
+ 207
620
+ 00:09:33,093 --> 00:09:34,462
621
+ Out of this,
622
+ 208
623
+ 00:09:34,462 --> 00:09:36,287
624
+ six are within the body,
625
+ 209
626
+ 00:09:36,287 --> 00:09:39,750
627
+ one just outside the physical body.
628
+ 210
629
+ 00:09:39,750 --> 00:09:40,144
630
+ So,
631
+ 211
632
+ 00:09:40,144 --> 00:09:43,192
633
+ if you employ this one hundred and twelve methods,
634
+ 212
635
+ 00:09:43,192 --> 00:09:45,639
636
+ you will handle the six chakras,
637
+ 213
638
+ 00:09:45,639 --> 00:09:48,190
639
+ the seventh one you cannot handle.
640
+ 214
641
+ 00:09:48,190 --> 00:09:52,151
642
+ There are hundred and twelve ways in which you can at… attain to a chakra
643
+ 215
644
+ 00:09:52,151 --> 00:09:53,976
645
+ which we refer to as Agna,
646
+ 216
647
+ 00:09:53,976 --> 00:09:55,448
648
+ but from Agna to Sahasrara,
649
+ 217
650
+ 00:09:55,448 --> 00:09:57,398
651
+ there is no way.
652
+ 218
653
+ 00:09:57,398 --> 00:09:59,264
654
+ There is no way to do it,
655
+ 219
656
+ 00:09:59,264 --> 00:10:02,520
657
+ you just have to jump into an abyss.
658
+ 220
659
+ 00:10:02,520 --> 00:10:04,843
660
+ If you have to jump into an abyss,
661
+ 221
662
+ 00:10:04,843 --> 00:10:11,417
663
+ you should be insane or you should have enormous trust in somebody.
664
+ 222
665
+ 00:10:11,417 --> 00:10:15,543
666
+ Somebody says jump and you are jumping because
667
+ 223
668
+ 00:10:15,543 --> 00:10:19,214
669
+ you have such a deep trust in somebody that when he says jump,
670
+ 224
671
+ 00:10:19,214 --> 00:10:21,226
672
+ it has to be good for you.
673
+ 225
674
+ 00:10:21,226 --> 00:10:26,078
675
+ You simply jump into a bottomless pit.
676
+ 226
677
+ 00:10:26,078 --> 00:10:27,011
678
+ So,
679
+ 227
680
+ 00:10:27,011 --> 00:10:33,088
681
+ the journey from the Mooladhara to Agna there are one hundred and twelve ways to get there.
682
+ 228
683
+ 00:10:33,088 --> 00:10:34,311
684
+ But from there to there,
685
+ 229
686
+ 00:10:34,311 --> 00:10:35,307
687
+ there is no way.
688
+ 230
689
+ 00:10:35,307 --> 00:10:36,841
690
+ It is just one jump,
691
+ 231
692
+ 00:10:36,841 --> 00:10:38,998
693
+ that can happen in trust,
694
+ 232
695
+ 00:10:38,998 --> 00:10:42,316
696
+ in devotion or in madness.
697
+ 233
698
+ 00:10:42,316 --> 00:10:57,780
699
+ Choice is yours.
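
A small helper sketch, not part of the commit, for loading sample.srt. This sample stores exactly three lines per cue (index, timing line, text) with no blank separator lines, 699 lines for 233 cues, so the sketch reads fixed-size chunks instead of splitting on empty lines the way a generic SRT parser would. The function name and return shape are illustrative assumptions.

def read_sample_srt(path="sample.srt"):
    """Return a list of cue dicts parsed from the 3-lines-per-cue sample above."""
    with open(path, encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    cues = []
    for i in range(0, len(lines), 3):
        index, timing, text = lines[i], lines[i + 1], lines[i + 2]
        start, end = [t.strip() for t in timing.split("-->")]
        cues.append({"index": int(index), "start": start, "end": end, "text": text})
    return cues

# Example: read_sample_srt()[0]["text"] would yield
# "Sadhguru: If you activate this dimension of energy,"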