Commit
•
9d452e1
1
Parent(s):
1521005
codebase added
Browse files- .gitattributes +12 -0
- IndicTransTokenizer/__init__.py +0 -0
- IndicTransTokenizer/__pycache__/__init__.cpython-39.pyc +0 -0
- IndicTransTokenizer/__pycache__/tokenizer.cpython-39.pyc +0 -0
- IndicTransTokenizer/__pycache__/utils.cpython-39.pyc +0 -0
- IndicTransTokenizer/en-indic/dict.SRC.json +3 -0
- IndicTransTokenizer/en-indic/dict.TGT.json +3 -0
- IndicTransTokenizer/en-indic/model.SRC +3 -0
- IndicTransTokenizer/en-indic/model.TGT +3 -0
- IndicTransTokenizer/indic-en/dict.SRC.json +3 -0
- IndicTransTokenizer/indic-en/dict.TGT.json +3 -0
- IndicTransTokenizer/indic-en/model.SRC +3 -0
- IndicTransTokenizer/indic-en/model.TGT +3 -0
- IndicTransTokenizer/tokenizer.py +259 -0
- IndicTransTokenizer/utils.py +591 -0
- README.md +62 -0
- configuration_indictrans.py +307 -0
- convert_indictrans_checkpoint_to_pytorch.py +107 -0
- example.py +125 -0
- handler.py +194 -0
- install.sh +52 -0
- modeling_indictrans.py +1449 -0
- sample.srt +699 -0
.gitattributes
CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
IndicTransTokenizer/en-indic/model.TGT filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.TGT filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.SRC filter=lfs diff=lfs merge=lfs -text
|
39 |
+
*.SRC.json filter=lfs diff=lfs merge=lfs -text
|
40 |
+
*.TGT.json filter=lfs diff=lfs merge=lfs -text
|
41 |
+
IndicTransTokenizer/en-indic/dict.SRC.json filter=lfs diff=lfs merge=lfs -text
|
42 |
+
IndicTransTokenizer/en-indic/dict.TGT.json filter=lfs diff=lfs merge=lfs -text
|
43 |
+
IndicTransTokenizer/en-indic/model.SRC filter=lfs diff=lfs merge=lfs -text
|
44 |
+
IndicTransTokenizer/indic-en/dict.SRC.json filter=lfs diff=lfs merge=lfs -text
|
45 |
+
IndicTransTokenizer/indic-en/dict.TGT.json filter=lfs diff=lfs merge=lfs -text
|
46 |
+
IndicTransTokenizer/indic-en/model.SRC filter=lfs diff=lfs merge=lfs -text
|
47 |
+
IndicTransTokenizer/indic-en/model.TGT filter=lfs diff=lfs merge=lfs -text
|
IndicTransTokenizer/__init__.py
ADDED
File without changes
|
IndicTransTokenizer/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (156 Bytes). View file
|
|
IndicTransTokenizer/__pycache__/tokenizer.cpython-39.pyc
ADDED
Binary file (10.1 kB). View file
|
|
IndicTransTokenizer/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (14.1 kB). View file
|
|
IndicTransTokenizer/en-indic/dict.SRC.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:99cabf338bf3db11eafae2769584b8b5d3aa579989feb7e9f72236bdf69810e1
|
3 |
+
size 645274
|
IndicTransTokenizer/en-indic/dict.TGT.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f7817850c9e4b99c59fad57d0611c7720f1921f215e6f247cf25d52eff7f1146
|
3 |
+
size 3390440
|
IndicTransTokenizer/en-indic/model.SRC
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3cedc5cbcc740369b76201942a0f096fec7287fee039b55bdb956f301235b914
|
3 |
+
size 759425
|
IndicTransTokenizer/en-indic/model.TGT
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
|
3 |
+
size 3256903
|
IndicTransTokenizer/indic-en/dict.SRC.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4f9cdb988b42c4e0f4fce5e44cc66975e5088a96d111a149b9ac7d55059b8ec1
|
3 |
+
size 3391183
|
IndicTransTokenizer/indic-en/dict.TGT.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:13c3a162fe655dbe99c790a413675c5d0634cd771fadcefe8d407676a7d1a311
|
3 |
+
size 644755
|
IndicTransTokenizer/indic-en/model.SRC
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
|
3 |
+
size 3256903
|
IndicTransTokenizer/indic-en/model.TGT
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3cedc5cbcc740369b76201942a0f096fec7287fee039b55bdb956f301235b914
|
3 |
+
size 759425
|
IndicTransTokenizer/tokenizer.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from transformers import BatchEncoding
|
6 |
+
from typing import Dict, List, Tuple, Union
|
7 |
+
from sentencepiece import SentencePieceProcessor
|
8 |
+
|
9 |
+
_PATH = os.path.dirname(os.path.realpath(__file__))
|
10 |
+
|
11 |
+
|
12 |
+
class IndicTransTokenizer:
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
src_vocab_fp=None,
|
16 |
+
tgt_vocab_fp=None,
|
17 |
+
src_spm_fp=None,
|
18 |
+
tgt_spm_fp=None,
|
19 |
+
unk_token="<unk>",
|
20 |
+
bos_token="<s>",
|
21 |
+
eos_token="</s>",
|
22 |
+
pad_token="<pad>",
|
23 |
+
direction="indic-en",
|
24 |
+
model_max_length=256,
|
25 |
+
):
|
26 |
+
self.model_max_length = model_max_length
|
27 |
+
|
28 |
+
self.supported_langs = [
|
29 |
+
"asm_Beng",
|
30 |
+
"ben_Beng",
|
31 |
+
"brx_Deva",
|
32 |
+
"doi_Deva",
|
33 |
+
"eng_Latn",
|
34 |
+
"gom_Deva",
|
35 |
+
"guj_Gujr",
|
36 |
+
"hin_Deva",
|
37 |
+
"kan_Knda",
|
38 |
+
"kas_Arab",
|
39 |
+
"kas_Deva",
|
40 |
+
"mai_Deva",
|
41 |
+
"mal_Mlym",
|
42 |
+
"mar_Deva",
|
43 |
+
"mni_Beng",
|
44 |
+
"mni_Mtei",
|
45 |
+
"npi_Deva",
|
46 |
+
"ory_Orya",
|
47 |
+
"pan_Guru",
|
48 |
+
"san_Deva",
|
49 |
+
"sat_Olck",
|
50 |
+
"snd_Arab",
|
51 |
+
"snd_Deva",
|
52 |
+
"tam_Taml",
|
53 |
+
"tel_Telu",
|
54 |
+
"urd_Arab",
|
55 |
+
]
|
56 |
+
|
57 |
+
self.src_vocab_fp = (
|
58 |
+
src_vocab_fp
|
59 |
+
if (src_vocab_fp is not None)
|
60 |
+
else os.path.join(_PATH, direction, "dict.SRC.json")
|
61 |
+
)
|
62 |
+
self.tgt_vocab_fp = (
|
63 |
+
tgt_vocab_fp
|
64 |
+
if (tgt_vocab_fp is not None)
|
65 |
+
else os.path.join(_PATH, direction, "dict.TGT.json")
|
66 |
+
)
|
67 |
+
self.src_spm_fp = (
|
68 |
+
src_spm_fp
|
69 |
+
if (src_spm_fp is not None)
|
70 |
+
else os.path.join(_PATH, direction, "model.SRC")
|
71 |
+
)
|
72 |
+
self.tgt_spm_fp = (
|
73 |
+
tgt_spm_fp
|
74 |
+
if (tgt_spm_fp is not None)
|
75 |
+
else os.path.join(_PATH, direction, "model.TGT")
|
76 |
+
)
|
77 |
+
|
78 |
+
self.unk_token = unk_token
|
79 |
+
self.pad_token = pad_token
|
80 |
+
self.eos_token = eos_token
|
81 |
+
self.bos_token = bos_token
|
82 |
+
|
83 |
+
self.encoder = self._load_json(self.src_vocab_fp)
|
84 |
+
if self.unk_token not in self.encoder:
|
85 |
+
raise KeyError("<unk> token must be in vocab")
|
86 |
+
assert self.pad_token in self.encoder
|
87 |
+
self.encoder_rev = {v: k for k, v in self.encoder.items()}
|
88 |
+
|
89 |
+
self.decoder = self._load_json(self.tgt_vocab_fp)
|
90 |
+
if self.unk_token not in self.encoder:
|
91 |
+
raise KeyError("<unk> token must be in vocab")
|
92 |
+
assert self.pad_token in self.encoder
|
93 |
+
self.decoder_rev = {v: k for k, v in self.decoder.items()}
|
94 |
+
|
95 |
+
# load SentencePiece model for pre-processing
|
96 |
+
self.src_spm = self._load_spm(self.src_spm_fp)
|
97 |
+
self.tgt_spm = self._load_spm(self.tgt_spm_fp)
|
98 |
+
|
99 |
+
def is_special_token(self, x: str):
|
100 |
+
return (x == self.pad_token) or (x == self.bos_token) or (x == self.eos_token)
|
101 |
+
|
102 |
+
def get_vocab_size(self, src: bool) -> int:
|
103 |
+
"""Returns the size of the vocabulary"""
|
104 |
+
return len(self.encoder) if src else len(self.decoder)
|
105 |
+
|
106 |
+
def _load_spm(self, path: str) -> SentencePieceProcessor:
|
107 |
+
return SentencePieceProcessor(model_file=path)
|
108 |
+
|
109 |
+
def _save_json(self, data, path: str) -> None:
|
110 |
+
with open(path, "w", encoding="utf-8") as f:
|
111 |
+
json.dump(data, f, indent=2)
|
112 |
+
|
113 |
+
def _load_json(self, path: str) -> Union[Dict, List]:
|
114 |
+
with open(path, "r", encoding="utf-8") as f:
|
115 |
+
return json.load(f)
|
116 |
+
|
117 |
+
def _convert_token_to_id(self, token: str, src: bool) -> int:
|
118 |
+
"""Converts an token (str) into an index (integer) using the source/target vocabulary map."""
|
119 |
+
return (
|
120 |
+
self.encoder.get(token, self.encoder[self.unk_token])
|
121 |
+
if src
|
122 |
+
else self.decoder.get(token, self.encoder[self.unk_token])
|
123 |
+
)
|
124 |
+
|
125 |
+
def _convert_id_to_token(self, index: int, src: bool) -> str:
|
126 |
+
"""Converts an index (integer) into a token (str) using the source/target vocabulary map."""
|
127 |
+
return (
|
128 |
+
self.encoder_rev.get(index, self.unk_token)
|
129 |
+
if src
|
130 |
+
else self.decoder_rev.get(index, self.unk_token)
|
131 |
+
)
|
132 |
+
|
133 |
+
def _convert_tokens_to_string(self, tokens: List[str], src: bool) -> str:
|
134 |
+
"""Uses sentencepiece model for detokenization"""
|
135 |
+
if src:
|
136 |
+
if tokens[0] in self.supported_langs and tokens[1] in self.supported_langs:
|
137 |
+
tokens = tokens[2:]
|
138 |
+
return " ".join(tokens)
|
139 |
+
else:
|
140 |
+
return " ".join(tokens)
|
141 |
+
|
142 |
+
def _remove_translation_tags(self, text: str) -> Tuple[List, str]:
|
143 |
+
"""Removes the translation tags before text normalization and tokenization."""
|
144 |
+
tokens = text.split(" ")
|
145 |
+
return tokens[:2], " ".join(tokens[2:])
|
146 |
+
|
147 |
+
def _tokenize_src_line(self, line: str) -> List[str]:
|
148 |
+
"""Tokenizes a source line."""
|
149 |
+
tags, text = self._remove_translation_tags(line)
|
150 |
+
tokens = self.src_spm.encode(text, out_type=str)
|
151 |
+
return tags + tokens
|
152 |
+
|
153 |
+
def _tokenize_tgt_line(self, line: str) -> List[str]:
|
154 |
+
"""Tokenizes a target line."""
|
155 |
+
return self.tgt_spm.encode(line, out_type=str)
|
156 |
+
|
157 |
+
def tokenize(self, text: str, src: bool) -> List[str]:
|
158 |
+
"""Tokenizes a string into tokens using the source/target vocabulary."""
|
159 |
+
return self._tokenize_src_line(text) if src else self._tokenize_tgt_line(text)
|
160 |
+
|
161 |
+
def batch_tokenize(self, batch: List[str], src: bool) -> List[List[str]]:
|
162 |
+
"""Tokenizes a list of strings into tokens using the source/target vocabulary."""
|
163 |
+
return [self.tokenize(line, src) for line in batch]
|
164 |
+
|
165 |
+
def _create_attention_mask(self, ids: List[int], max_seq_len: int) -> List[int]:
|
166 |
+
"""Creates a attention mask for the input sequence."""
|
167 |
+
return ([0] * (max_seq_len - len(ids))) + ([1] * (len(ids) + 1))
|
168 |
+
|
169 |
+
def _pad_batch(self, tokens: List[str], max_seq_len: int) -> List[str]:
|
170 |
+
"""Pads a batch of tokens and adds BOS/EOS tokens."""
|
171 |
+
return (
|
172 |
+
([self.pad_token] * (max_seq_len - len(tokens))) + tokens + [self.eos_token]
|
173 |
+
)
|
174 |
+
|
175 |
+
def _decode_line(self, ids: List[int], src: bool) -> List[str]:
|
176 |
+
return [self._convert_id_to_token(_id, src) for _id in ids]
|
177 |
+
|
178 |
+
def _encode_line(self, tokens: List[str], src: bool) -> List[int]:
|
179 |
+
return [self._convert_token_to_id(token, src) for token in tokens]
|
180 |
+
|
181 |
+
def _strip_special_tokens(self, tokens: List[str]) -> List[str]:
|
182 |
+
return [token for token in tokens if not self.is_special_token(token)]
|
183 |
+
|
184 |
+
def _single_input_preprocessing(
|
185 |
+
self, tokens: List[str], src: bool, max_seq_len: int
|
186 |
+
) -> Tuple[List[int], List[int], int]:
|
187 |
+
"""Tokenizes a string into tokens and also converts them into integers using source/target vocabulary map."""
|
188 |
+
attention_mask = self._create_attention_mask(tokens, max_seq_len)
|
189 |
+
padded_tokens = self._pad_batch(tokens, max_seq_len)
|
190 |
+
input_ids = self._encode_line(padded_tokens, src)
|
191 |
+
return input_ids, attention_mask
|
192 |
+
|
193 |
+
def _single_output_postprocessing(self, ids: List[int], src: bool) -> str:
|
194 |
+
"""Detokenizes a list of integer ids into a string using the source/target vocabulary."""
|
195 |
+
tokens = self._decode_line(ids, src)
|
196 |
+
tokens = self._strip_special_tokens(tokens)
|
197 |
+
return self._convert_tokens_to_string(tokens, src)
|
198 |
+
|
199 |
+
def __call__(
|
200 |
+
self,
|
201 |
+
batch: Union[list, str],
|
202 |
+
src: bool,
|
203 |
+
truncation: bool = False,
|
204 |
+
padding: str = "longest",
|
205 |
+
max_length: int = None,
|
206 |
+
return_tensors: str = "pt",
|
207 |
+
return_attention_mask: bool = True,
|
208 |
+
return_length: bool = False,
|
209 |
+
) -> BatchEncoding:
|
210 |
+
"""Tokenizes a string into tokens and also converts them into integers using source/target vocabulary map."""
|
211 |
+
assert padding in [
|
212 |
+
"longest",
|
213 |
+
"max_length",
|
214 |
+
], "padding should be either 'longest' or 'max_length'"
|
215 |
+
|
216 |
+
if not isinstance(batch, list):
|
217 |
+
raise TypeError(
|
218 |
+
f"batch must be a list, but current batch is of type {type(batch)}"
|
219 |
+
)
|
220 |
+
|
221 |
+
# tokenize the source sentences
|
222 |
+
batch = self.batch_tokenize(batch, src)
|
223 |
+
|
224 |
+
# truncate the sentences if needed
|
225 |
+
if truncation and max_length is not None:
|
226 |
+
batch = [ids[:max_length] for ids in batch]
|
227 |
+
|
228 |
+
lengths = [len(ids) for ids in batch]
|
229 |
+
|
230 |
+
max_seq_len = max(lengths) if padding == "longest" else max_length
|
231 |
+
|
232 |
+
input_ids, attention_mask = zip(
|
233 |
+
*[
|
234 |
+
self._single_input_preprocessing(
|
235 |
+
tokens=tokens, src=src, max_seq_len=max_seq_len
|
236 |
+
)
|
237 |
+
for tokens in batch
|
238 |
+
]
|
239 |
+
)
|
240 |
+
|
241 |
+
_data = {"input_ids": input_ids}
|
242 |
+
|
243 |
+
if return_attention_mask:
|
244 |
+
_data["attention_mask"] = attention_mask
|
245 |
+
|
246 |
+
if return_length:
|
247 |
+
_data["lengths"] = lengths
|
248 |
+
|
249 |
+
return BatchEncoding(_data, tensor_type=return_tensors)
|
250 |
+
|
251 |
+
def batch_decode(
|
252 |
+
self, batch: Union[list, torch.Tensor], src: bool
|
253 |
+
) -> List[List[str]]:
|
254 |
+
"""Detokenizes a list of integer ids or a tensor into a list of strings using the source/target vocabulary."""
|
255 |
+
|
256 |
+
if isinstance(batch, torch.Tensor):
|
257 |
+
batch = batch.detach().cpu().tolist()
|
258 |
+
|
259 |
+
return [self._single_output_postprocessing(ids=ids, src=src) for ids in batch]
|
IndicTransTokenizer/utils.py
ADDED
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import regex as re
|
2 |
+
from joblib import Parallel, delayed
|
3 |
+
from nltk.tokenize import sent_tokenize
|
4 |
+
from typing import List, Tuple, Union
|
5 |
+
|
6 |
+
from sacremoses import MosesPunctNormalizer
|
7 |
+
from indicnlp.normalize import indic_normalize
|
8 |
+
from sacremoses import MosesTokenizer, MosesDetokenizer
|
9 |
+
from indicnlp.transliterate import unicode_transliterate
|
10 |
+
from indicnlp.tokenize import indic_tokenize, indic_detokenize
|
11 |
+
from indicnlp.tokenize.sentence_tokenize import sentence_split, DELIM_PAT_NO_DANDA
|
12 |
+
|
13 |
+
en_tok = MosesTokenizer(lang="en")
|
14 |
+
en_normalizer = MosesPunctNormalizer()
|
15 |
+
en_detok = MosesDetokenizer(lang="en")
|
16 |
+
xliterator = unicode_transliterate.UnicodeIndicTransliterator()
|
17 |
+
|
18 |
+
|
19 |
+
flores_codes = {
|
20 |
+
"asm_Beng": "as",
|
21 |
+
"awa_Deva": "hi",
|
22 |
+
"ben_Beng": "bn",
|
23 |
+
"bho_Deva": "hi",
|
24 |
+
"brx_Deva": "hi",
|
25 |
+
"doi_Deva": "hi",
|
26 |
+
"eng_Latn": "en",
|
27 |
+
"gom_Deva": "kK",
|
28 |
+
"guj_Gujr": "gu",
|
29 |
+
"hin_Deva": "hi",
|
30 |
+
"hne_Deva": "hi",
|
31 |
+
"kan_Knda": "kn",
|
32 |
+
"kas_Arab": "ur",
|
33 |
+
"kas_Deva": "hi",
|
34 |
+
"kha_Latn": "en",
|
35 |
+
"lus_Latn": "en",
|
36 |
+
"mag_Deva": "hi",
|
37 |
+
"mai_Deva": "hi",
|
38 |
+
"mal_Mlym": "ml",
|
39 |
+
"mar_Deva": "mr",
|
40 |
+
"mni_Beng": "bn",
|
41 |
+
"mni_Mtei": "hi",
|
42 |
+
"npi_Deva": "ne",
|
43 |
+
"ory_Orya": "or",
|
44 |
+
"pan_Guru": "pa",
|
45 |
+
"san_Deva": "hi",
|
46 |
+
"sat_Olck": "or",
|
47 |
+
"snd_Arab": "ur",
|
48 |
+
"snd_Deva": "hi",
|
49 |
+
"tam_Taml": "ta",
|
50 |
+
"tel_Telu": "te",
|
51 |
+
"urd_Arab": "ur",
|
52 |
+
}
|
53 |
+
|
54 |
+
|
55 |
+
flores_to_iso = {
|
56 |
+
"asm_Beng": "as",
|
57 |
+
"awa_Deva": "awa",
|
58 |
+
"ben_Beng": "bn",
|
59 |
+
"bho_Deva": "bho",
|
60 |
+
"brx_Deva": "brx",
|
61 |
+
"doi_Deva": "doi",
|
62 |
+
"eng_Latn": "en",
|
63 |
+
"gom_Deva": "gom",
|
64 |
+
"guj_Gujr": "gu",
|
65 |
+
"hin_Deva": "hi",
|
66 |
+
"hne_Deva": "hne",
|
67 |
+
"kan_Knda": "kn",
|
68 |
+
"kas_Arab": "ksa",
|
69 |
+
"kas_Deva": "ksd",
|
70 |
+
"kha_Latn": "kha",
|
71 |
+
"lus_Latn": "lus",
|
72 |
+
"mag_Deva": "mag",
|
73 |
+
"mai_Deva": "mai",
|
74 |
+
"mal_Mlym": "ml",
|
75 |
+
"mar_Deva": "mr",
|
76 |
+
"mni_Beng": "mnib",
|
77 |
+
"mni_Mtei": "mnim",
|
78 |
+
"npi_Deva": "ne",
|
79 |
+
"ory_Orya": "or",
|
80 |
+
"pan_Guru": "pa",
|
81 |
+
"san_Deva": "sa",
|
82 |
+
"sat_Olck": "sat",
|
83 |
+
"snd_Arab": "sda",
|
84 |
+
"snd_Deva": "sdd",
|
85 |
+
"tam_Taml": "ta",
|
86 |
+
"tel_Telu": "te",
|
87 |
+
"urd_Arab": "ur",
|
88 |
+
}
|
89 |
+
|
90 |
+
|
91 |
+
INDIC_NUM_MAP = {
|
92 |
+
"\u09e6": "0",
|
93 |
+
"0": "0",
|
94 |
+
"\u0ae6": "0",
|
95 |
+
"\u0ce6": "0",
|
96 |
+
"\u0966": "0",
|
97 |
+
"\u0660": "0",
|
98 |
+
"\uabf0": "0",
|
99 |
+
"\u0b66": "0",
|
100 |
+
"\u0a66": "0",
|
101 |
+
"\u1c50": "0",
|
102 |
+
"\u06f0": "0",
|
103 |
+
"\u09e7": "1",
|
104 |
+
"1": "1",
|
105 |
+
"\u0ae7": "1",
|
106 |
+
"\u0967": "1",
|
107 |
+
"\u0ce7": "1",
|
108 |
+
"\u06f1": "1",
|
109 |
+
"\uabf1": "1",
|
110 |
+
"\u0b67": "1",
|
111 |
+
"\u0a67": "1",
|
112 |
+
"\u1c51": "1",
|
113 |
+
"\u0c67": "1",
|
114 |
+
"\u09e8": "2",
|
115 |
+
"2": "2",
|
116 |
+
"\u0ae8": "2",
|
117 |
+
"\u0968": "2",
|
118 |
+
"\u0ce8": "2",
|
119 |
+
"\u06f2": "2",
|
120 |
+
"\uabf2": "2",
|
121 |
+
"\u0b68": "2",
|
122 |
+
"\u0a68": "2",
|
123 |
+
"\u1c52": "2",
|
124 |
+
"\u0c68": "2",
|
125 |
+
"\u09e9": "3",
|
126 |
+
"3": "3",
|
127 |
+
"\u0ae9": "3",
|
128 |
+
"\u0969": "3",
|
129 |
+
"\u0ce9": "3",
|
130 |
+
"\u06f3": "3",
|
131 |
+
"\uabf3": "3",
|
132 |
+
"\u0b69": "3",
|
133 |
+
"\u0a69": "3",
|
134 |
+
"\u1c53": "3",
|
135 |
+
"\u0c69": "3",
|
136 |
+
"\u09ea": "4",
|
137 |
+
"4": "4",
|
138 |
+
"\u0aea": "4",
|
139 |
+
"\u096a": "4",
|
140 |
+
"\u0cea": "4",
|
141 |
+
"\u06f4": "4",
|
142 |
+
"\uabf4": "4",
|
143 |
+
"\u0b6a": "4",
|
144 |
+
"\u0a6a": "4",
|
145 |
+
"\u1c54": "4",
|
146 |
+
"\u0c6a": "4",
|
147 |
+
"\u09eb": "5",
|
148 |
+
"5": "5",
|
149 |
+
"\u0aeb": "5",
|
150 |
+
"\u096b": "5",
|
151 |
+
"\u0ceb": "5",
|
152 |
+
"\u06f5": "5",
|
153 |
+
"\uabf5": "5",
|
154 |
+
"\u0b6b": "5",
|
155 |
+
"\u0a6b": "5",
|
156 |
+
"\u1c55": "5",
|
157 |
+
"\u0c6b": "5",
|
158 |
+
"\u09ec": "6",
|
159 |
+
"6": "6",
|
160 |
+
"\u0aec": "6",
|
161 |
+
"\u096c": "6",
|
162 |
+
"\u0cec": "6",
|
163 |
+
"\u06f6": "6",
|
164 |
+
"\uabf6": "6",
|
165 |
+
"\u0b6c": "6",
|
166 |
+
"\u0a6c": "6",
|
167 |
+
"\u1c56": "6",
|
168 |
+
"\u0c6c": "6",
|
169 |
+
"\u09ed": "7",
|
170 |
+
"7": "7",
|
171 |
+
"\u0aed": "7",
|
172 |
+
"\u096d": "7",
|
173 |
+
"\u0ced": "7",
|
174 |
+
"\u06f7": "7",
|
175 |
+
"\uabf7": "7",
|
176 |
+
"\u0b6d": "7",
|
177 |
+
"\u0a6d": "7",
|
178 |
+
"\u1c57": "7",
|
179 |
+
"\u0c6d": "7",
|
180 |
+
"\u09ee": "8",
|
181 |
+
"8": "8",
|
182 |
+
"\u0aee": "8",
|
183 |
+
"\u096e": "8",
|
184 |
+
"\u0cee": "8",
|
185 |
+
"\u06f8": "8",
|
186 |
+
"\uabf8": "8",
|
187 |
+
"\u0b6e": "8",
|
188 |
+
"\u0a6e": "8",
|
189 |
+
"\u1c58": "8",
|
190 |
+
"\u0c6e": "8",
|
191 |
+
"\u09ef": "9",
|
192 |
+
"9": "9",
|
193 |
+
"\u0aef": "9",
|
194 |
+
"\u096f": "9",
|
195 |
+
"\u0cef": "9",
|
196 |
+
"\u06f9": "9",
|
197 |
+
"\uabf9": "9",
|
198 |
+
"\u0b6f": "9",
|
199 |
+
"\u0a6f": "9",
|
200 |
+
"\u1c59": "9",
|
201 |
+
"\u0c6f": "9",
|
202 |
+
}
|
203 |
+
|
204 |
+
|
205 |
+
multispace_regex = re.compile("[ ]{2,}")
|
206 |
+
end_bracket_space_punc_regex = re.compile(r"\) ([\.!:?;,])")
|
207 |
+
digit_space_percent = re.compile(r"(\d) %")
|
208 |
+
double_quot_punc = re.compile(r"\"([,\.]+)")
|
209 |
+
digit_nbsp_digit = re.compile(r"(\d) (\d)")
|
210 |
+
|
211 |
+
|
212 |
+
def punc_norm(text, lang="en"):
|
213 |
+
text = (
|
214 |
+
text.replace("\r", "")
|
215 |
+
.replace("(", " (")
|
216 |
+
.replace(")", ") ")
|
217 |
+
.replace("( ", "(")
|
218 |
+
.replace(" )", ")")
|
219 |
+
.replace(" :", ":")
|
220 |
+
.replace(" ;", ";")
|
221 |
+
.replace("`", "'")
|
222 |
+
.replace("„", '"')
|
223 |
+
.replace("“", '"')
|
224 |
+
.replace("”", '"')
|
225 |
+
.replace("–", "-")
|
226 |
+
.replace("—", " - ")
|
227 |
+
.replace("´", "'")
|
228 |
+
.replace("‘", "'")
|
229 |
+
.replace("‚", "'")
|
230 |
+
.replace("’", "'")
|
231 |
+
.replace("''", '"')
|
232 |
+
.replace("´´", '"')
|
233 |
+
.replace("…", "...")
|
234 |
+
.replace(" « ", ' "')
|
235 |
+
.replace("« ", '"')
|
236 |
+
.replace("«", '"')
|
237 |
+
.replace(" » ", '" ')
|
238 |
+
.replace(" »", '"')
|
239 |
+
.replace("»", '"')
|
240 |
+
.replace(" %", "%")
|
241 |
+
.replace("nº ", "nº ")
|
242 |
+
.replace(" :", ":")
|
243 |
+
.replace(" ºC", " ºC")
|
244 |
+
.replace(" cm", " cm")
|
245 |
+
.replace(" ?", "?")
|
246 |
+
.replace(" !", "!")
|
247 |
+
.replace(" ;", ";")
|
248 |
+
.replace(", ", ", ")
|
249 |
+
)
|
250 |
+
|
251 |
+
text = multispace_regex.sub(" ", text)
|
252 |
+
text = end_bracket_space_punc_regex.sub(r")\1", text)
|
253 |
+
text = digit_space_percent.sub(r"\1%", text)
|
254 |
+
text = double_quot_punc.sub(
|
255 |
+
r'\1"', text
|
256 |
+
) # English "quotation," followed by comma, style
|
257 |
+
text = digit_nbsp_digit.sub(r"\1.\2", text) # What does it mean?
|
258 |
+
return text.strip(" ")
|
259 |
+
|
260 |
+
|
261 |
+
URL_PATTERN = r"\b(?<![\w/.])(?:(?:https?|ftp)://)?(?:(?:[\w-]+\.)+(?!\.))(?:[\w/\-?#&=%.]+)+(?!\.\w+)\b"
|
262 |
+
EMAIL_PATTERN = r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}"
|
263 |
+
# handles dates, time, percentages, proportion, ratio, etc
|
264 |
+
NUMERAL_PATTERN = r"(~?\d+\.?\d*\s?%?\s?-?\s?~?\d+\.?\d*\s?%|~?\d+%|\d+[-\/.,:']\d+[-\/.,:'+]\d+(?:\.\d+)?|\d+[-\/.:'+]\d+(?:\.\d+)?)"
|
265 |
+
# handles upi, social media handles and hashtags
|
266 |
+
OTHER_PATTERN = r"[A-Za-z0-9]*[#|@]\w+"
|
267 |
+
|
268 |
+
|
269 |
+
def normalize_indic_numerals(line: str):
|
270 |
+
"""
|
271 |
+
Normalize the numerals in Indic languages from native script to Roman script (if present).
|
272 |
+
|
273 |
+
Args:
|
274 |
+
line (str): an input string with Indic numerals to be normalized.
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
str: an input string with the all Indic numerals normalized to Roman script.
|
278 |
+
"""
|
279 |
+
return "".join([INDIC_NUM_MAP.get(c, c) for c in line])
|
280 |
+
|
281 |
+
|
282 |
+
def wrap_with_placeholders(text: str, patterns: list) -> Tuple[str, dict]:
|
283 |
+
"""
|
284 |
+
Wraps substrings with matched patterns in the given text with placeholders and returns
|
285 |
+
the modified text along with a mapping of the placeholders to their original value.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
text (str): an input string which needs to be wrapped with the placeholders.
|
289 |
+
pattern (list): list of patterns to search for in the input string.
|
290 |
+
|
291 |
+
Returns:
|
292 |
+
Tuple[str, dict]: a tuple containing the modified text and a dictionary mapping
|
293 |
+
placeholders to their original values.
|
294 |
+
"""
|
295 |
+
serial_no = 1
|
296 |
+
|
297 |
+
placeholder_entity_map = dict()
|
298 |
+
|
299 |
+
for pattern in patterns:
|
300 |
+
matches = set(re.findall(pattern, text))
|
301 |
+
|
302 |
+
# wrap common match with placeholder tags
|
303 |
+
for match in matches:
|
304 |
+
if pattern == URL_PATTERN:
|
305 |
+
# Avoids false positive URL matches for names with initials.
|
306 |
+
temp = match.replace(".", "")
|
307 |
+
if len(temp) < 4:
|
308 |
+
continue
|
309 |
+
if pattern == NUMERAL_PATTERN:
|
310 |
+
# Short numeral patterns do not need placeholder based handling.
|
311 |
+
temp = match.replace(" ", "").replace(".", "").replace(":", "")
|
312 |
+
if len(temp) < 4:
|
313 |
+
continue
|
314 |
+
|
315 |
+
# Set of Translations of "ID" in all the suppported languages have been collated.
|
316 |
+
# This has been added to deal with edge cases where placeholders might get translated.
|
317 |
+
indic_failure_cases = [
|
318 |
+
"آی ڈی ",
|
319 |
+
"ꯑꯥꯏꯗꯤ",
|
320 |
+
"आईडी",
|
321 |
+
"आई . डी . ",
|
322 |
+
"आई . डी .",
|
323 |
+
"आई. डी. ",
|
324 |
+
"आई. डी.",
|
325 |
+
"ऐटि",
|
326 |
+
"آئی ڈی ",
|
327 |
+
"ᱟᱭᱰᱤ ᱾",
|
328 |
+
"आयडी",
|
329 |
+
"ऐडि",
|
330 |
+
"आइडि",
|
331 |
+
]
|
332 |
+
placeholder = "<ID{}>".format(serial_no)
|
333 |
+
alternate_placeholder = "< ID{} >".format(serial_no)
|
334 |
+
placeholder_entity_map[placeholder] = match
|
335 |
+
placeholder_entity_map[alternate_placeholder] = match
|
336 |
+
placeholder = "<ID{}]".format(serial_no)
|
337 |
+
alternate_placeholder = "< ID{} ]".format(serial_no)
|
338 |
+
placeholder_entity_map[placeholder] = match
|
339 |
+
placeholder_entity_map[alternate_placeholder] = match
|
340 |
+
|
341 |
+
for i in indic_failure_cases:
|
342 |
+
placeholder_temp = "<{}{}>".format(i, serial_no)
|
343 |
+
placeholder_entity_map[placeholder_temp] = match
|
344 |
+
placeholder_temp = "< {}{} >".format(i, serial_no)
|
345 |
+
placeholder_entity_map[placeholder_temp] = match
|
346 |
+
placeholder_temp = "< {} {} >".format(i, serial_no)
|
347 |
+
placeholder_entity_map[placeholder_temp] = match
|
348 |
+
placeholder_temp = "<{} {}]".format(i, serial_no)
|
349 |
+
placeholder_entity_map[placeholder_temp] = match
|
350 |
+
placeholder_temp = "< {} {} ]".format(i, serial_no)
|
351 |
+
placeholder_entity_map[placeholder_temp] = match
|
352 |
+
placeholder_temp = "[{} {}]".format(i, serial_no)
|
353 |
+
placeholder_entity_map[placeholder_temp] = match
|
354 |
+
placeholder_temp = "[ {} {} ]".format(i, serial_no)
|
355 |
+
placeholder_entity_map[placeholder_temp] = match
|
356 |
+
|
357 |
+
text = text.replace(match, placeholder)
|
358 |
+
serial_no += 1
|
359 |
+
|
360 |
+
text = re.sub("\s+", " ", text)
|
361 |
+
|
362 |
+
# Regex has failure cases in trailing "/" in URLs, so this is a workaround.
|
363 |
+
text = text.replace(">/", ">")
|
364 |
+
text = text.replace("]/", "]")
|
365 |
+
|
366 |
+
return text, placeholder_entity_map
|
367 |
+
|
368 |
+
|
369 |
+
def normalize(
|
370 |
+
text: str,
|
371 |
+
patterns: list = [EMAIL_PATTERN, URL_PATTERN, NUMERAL_PATTERN, OTHER_PATTERN],
|
372 |
+
) -> Tuple[str, dict]:
|
373 |
+
"""
|
374 |
+
Normalizes and wraps the spans of input string with placeholder tags. It first normalizes
|
375 |
+
the Indic numerals in the input string to Roman script. Later, it uses the input string with normalized
|
376 |
+
Indic numerals to wrap the spans of text matching the pattern with placeholder tags.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
text (str): input string.
|
380 |
+
pattern (list): list of patterns to search for in the input string.
|
381 |
+
|
382 |
+
Returns:
|
383 |
+
Tuple[str, dict]: a tuple containing the modified text and a dictionary mapping
|
384 |
+
placeholders to their original values.
|
385 |
+
"""
|
386 |
+
text = normalize_indic_numerals(text.strip("\n"))
|
387 |
+
text, placeholder_entity_map = wrap_with_placeholders(text, patterns)
|
388 |
+
return text, placeholder_entity_map
|
389 |
+
|
390 |
+
|
391 |
+
def split_sentences(paragraph: str, lang: str) -> List[str]:
|
392 |
+
"""
|
393 |
+
Splits the input text paragraph into sentences. It uses `moses` for English and
|
394 |
+
`indic-nlp` for Indic languages.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
paragraph (str): input text paragraph.
|
398 |
+
lang (str): flores language code.
|
399 |
+
|
400 |
+
Returns:
|
401 |
+
List[str] -> list of sentences.
|
402 |
+
"""
|
403 |
+
# fails to handle sentence splitting in case of
|
404 |
+
# with MosesSentenceSplitter(lang) as splitter:
|
405 |
+
# return splitter([paragraph])
|
406 |
+
return (
|
407 |
+
sent_tokenize(paragraph)
|
408 |
+
if lang == "eng_Latn"
|
409 |
+
else sentence_split(
|
410 |
+
paragraph, lang=flores_codes[lang], delim_pat=DELIM_PAT_NO_DANDA
|
411 |
+
)
|
412 |
+
)
|
413 |
+
|
414 |
+
|
415 |
+
def apply_lang_tags(sents: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
416 |
+
"""
|
417 |
+
Add special tokens indicating source and target language to the start of the each input sentence.
|
418 |
+
Each resulting input sentence will have the format: "`{src_lang} {tgt_lang} {input_sentence}`".
|
419 |
+
|
420 |
+
Args:
|
421 |
+
sent (str): input sentence to be translated.
|
422 |
+
src_lang (str): flores lang code of the input sentence.
|
423 |
+
tgt_lang (str): flores lang code in which the input sentence will be translated.
|
424 |
+
|
425 |
+
Returns:
|
426 |
+
List[str]: list of input sentences with the special tokens added to the start.
|
427 |
+
"""
|
428 |
+
return Parallel(n_jobs=-1)(
|
429 |
+
delayed(lambda x: f"{src_lang} {tgt_lang} {x.strip()}")(sent) for sent in sents
|
430 |
+
)
|
431 |
+
|
432 |
+
|
433 |
+
def preprocess_sent(
|
434 |
+
sent: str,
|
435 |
+
normalizer: Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory],
|
436 |
+
lang: str,
|
437 |
+
) -> str:
|
438 |
+
"""
|
439 |
+
Preprocess an input text sentence by normalizing, tokenization, and possibly transliterating it.
|
440 |
+
|
441 |
+
Args:
|
442 |
+
sent (str): input text sentence to preprocess.
|
443 |
+
normalizer (Union[MosesPunctNormalizer, indic_normalize.IndicNormalizerFactory]): an object that performs normalization on the text.
|
444 |
+
lang (str): flores language code of the input text sentence.
|
445 |
+
|
446 |
+
Returns:
|
447 |
+
Tuple[str, dict]: a tuple of preprocessed input text sentence and also a corresponding dictionary
|
448 |
+
mapping placeholders to their original values.
|
449 |
+
"""
|
450 |
+
iso_lang = flores_codes[lang]
|
451 |
+
sent = punc_norm(sent, iso_lang)
|
452 |
+
sent, placeholder_entity_map = normalize(sent)
|
453 |
+
|
454 |
+
transliterate = True
|
455 |
+
if lang.split("_")[1] in ["Arab", "Aran", "Olck", "Mtei", "Latn"]:
|
456 |
+
transliterate = False
|
457 |
+
|
458 |
+
if iso_lang == "en":
|
459 |
+
processed_sent = " ".join(
|
460 |
+
en_tok.tokenize(en_normalizer.normalize(sent.strip()), escape=False)
|
461 |
+
)
|
462 |
+
elif transliterate:
|
463 |
+
# transliterates from the any specific language to devanagari
|
464 |
+
# which is why we specify lang2_code as "hi".
|
465 |
+
processed_sent = xliterator.transliterate(
|
466 |
+
" ".join(
|
467 |
+
indic_tokenize.trivial_tokenize(
|
468 |
+
normalizer.normalize(sent.strip()), iso_lang
|
469 |
+
)
|
470 |
+
),
|
471 |
+
iso_lang,
|
472 |
+
"hi",
|
473 |
+
).replace(" ् ", "्")
|
474 |
+
else:
|
475 |
+
# we only need to transliterate for joint training
|
476 |
+
processed_sent = " ".join(
|
477 |
+
indic_tokenize.trivial_tokenize(
|
478 |
+
normalizer.normalize(sent.strip()), iso_lang
|
479 |
+
)
|
480 |
+
)
|
481 |
+
|
482 |
+
return processed_sent, placeholder_entity_map
|
483 |
+
|
484 |
+
|
485 |
+
def preprocess(sents: List[str], lang: str):
|
486 |
+
"""
|
487 |
+
Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it.
|
488 |
+
|
489 |
+
Args:
|
490 |
+
batch (List[str]): input list of sentences to preprocess.
|
491 |
+
lang (str): flores language code of the input text sentences.
|
492 |
+
|
493 |
+
Returns:
|
494 |
+
Tuple[List[str], List[dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
|
495 |
+
mapping placeholders to their original values.
|
496 |
+
"""
|
497 |
+
|
498 |
+
normalizer = (
|
499 |
+
indic_normalize.IndicNormalizerFactory().get_normalizer(flores_codes[lang])
|
500 |
+
if lang != "eng_Latn"
|
501 |
+
else None
|
502 |
+
)
|
503 |
+
|
504 |
+
processed_sents, placeholder_entity_map_sents = zip(
|
505 |
+
*[preprocess_sent(sent, normalizer, lang) for sent in sents]
|
506 |
+
)
|
507 |
+
|
508 |
+
return processed_sents, placeholder_entity_map_sents
|
509 |
+
|
510 |
+
|
511 |
+
def preprocess_batch(batch: List[str], src_lang: str, tgt_lang: str) -> List[str]:
|
512 |
+
"""
|
513 |
+
Preprocess an array of sentences by normalizing, tokenization, and possibly transliterating it. It also tokenizes the
|
514 |
+
normalized text sequences using sentence piece tokenizer and also adds language tags.
|
515 |
+
|
516 |
+
Args:
|
517 |
+
batch (List[str]): input list of sentences to preprocess.
|
518 |
+
src_lang (str): flores language code of the input text sentences.
|
519 |
+
tgt_lang (str): flores language code of the output text sentences.
|
520 |
+
|
521 |
+
Returns:
|
522 |
+
Tuple[List[str], List[dict]]: a tuple of list of preprocessed input text sentences and also a corresponding list of dictionary
|
523 |
+
mapping placeholders to their original values.
|
524 |
+
"""
|
525 |
+
preprocessed_sents, placeholder_entity_map_sents = preprocess(batch, lang=src_lang)
|
526 |
+
tagged_sents = apply_lang_tags(preprocessed_sents, src_lang, tgt_lang)
|
527 |
+
return tagged_sents, placeholder_entity_map_sents
|
528 |
+
|
529 |
+
|
530 |
+
def postprocess_batch(
|
531 |
+
sents: List[str],
|
532 |
+
placeholder_entity_map: List[dict],
|
533 |
+
lang: str,
|
534 |
+
common_lang: str = "hin_Deva",
|
535 |
+
) -> List[str]:
|
536 |
+
"""
|
537 |
+
Postprocesses a batch of input sentences after the translation generations.
|
538 |
+
|
539 |
+
Args:
|
540 |
+
sents (List[str]): batch of translated sentences to postprocess.
|
541 |
+
placeholder_entity_map (List[dict]): dictionary mapping placeholders to the original entity values.
|
542 |
+
lang (str): flores language code of the input sentences.
|
543 |
+
common_lang (str, optional): flores language code of the transliterated language (defaults: hin_Deva).
|
544 |
+
|
545 |
+
Returns:
|
546 |
+
List[str]: postprocessed batch of input sentences.
|
547 |
+
"""
|
548 |
+
|
549 |
+
lang_code, script_code = lang.split("_")
|
550 |
+
|
551 |
+
for i in range(len(sents)):
|
552 |
+
sents[i] = sents[i].replace(" ", "").replace("▁", " ").strip()
|
553 |
+
|
554 |
+
# Fixes for Perso-Arabic scripts
|
555 |
+
# TODO: Move these normalizations inside indic-nlp-library
|
556 |
+
if script_code in {"Arab", "Aran"}:
|
557 |
+
# UrduHack adds space before punctuations. Since the model was trained without fixing this issue, let's fix it now
|
558 |
+
sents[i] = sents[i].replace(" ؟", "؟").replace(" ۔", "۔").replace(" ،", "،")
|
559 |
+
# Kashmiri bugfix for palatalization: https://github.com/AI4Bharat/IndicTrans2/issues/11
|
560 |
+
sents[i] = sents[i].replace("ٮ۪", "ؠ")
|
561 |
+
|
562 |
+
# Oriya bug: indic-nlp-library produces ଯ଼ instead of ୟ when converting from Devanagari to Odia
|
563 |
+
# TODO: Find out what's the issue with unicode transliterator for Oriya and fix it
|
564 |
+
if lang_code == "or":
|
565 |
+
sents[i] = sents[i].replace("ଯ଼", "ୟ")
|
566 |
+
|
567 |
+
assert len(sents) == len(placeholder_entity_map)
|
568 |
+
|
569 |
+
# Replace the placeholders entity
|
570 |
+
for i in range(0, len(sents)):
|
571 |
+
for key in placeholder_entity_map[i].keys():
|
572 |
+
sents[i] = sents[i].replace(key, placeholder_entity_map[i][key])
|
573 |
+
|
574 |
+
# Detokenize and transliterate to native scripts if applicable
|
575 |
+
|
576 |
+
if lang == "eng_Latn":
|
577 |
+
postprocessed_sents = [en_detok.detokenize(sent.split(" ")) for sent in sents]
|
578 |
+
else:
|
579 |
+
postprocessed_sents = [
|
580 |
+
indic_detokenize.trivial_detokenize(
|
581 |
+
xliterator.transliterate(
|
582 |
+
s, flores_codes[common_lang], flores_codes[lang]
|
583 |
+
),
|
584 |
+
flores_codes[lang],
|
585 |
+
)
|
586 |
+
for s in sents
|
587 |
+
]
|
588 |
+
|
589 |
+
assert len(postprocessed_sents) == len(placeholder_entity_map)
|
590 |
+
|
591 |
+
return postprocessed_sents
|
README.md
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# IndicTrans2 HF Compatible Models
|
2 |
+
|
3 |
+
In this section, we provide details on how to use our [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2) models which were originally trained with the [fairseq](https://github.com/facebookresearch/fairseq) to [HuggingFace transformers](https://huggingface.co/docs/transformers/index) for inference purpose. Our scripts for HuggingFace compatible models are adapted from [M2M100 repository](https://github.com/huggingface/transformers/tree/main/src/transformers/models/m2m_100).
|
4 |
+
|
5 |
+
|
6 |
+
### Setup
|
7 |
+
|
8 |
+
To get started, follow these steps to set up the environment:
|
9 |
+
|
10 |
+
```
|
11 |
+
# Clone the github repository and navigate to the project directory.
|
12 |
+
git clone https://github.com/AI4Bharat/IndicTrans2
|
13 |
+
cd IndicTrans2
|
14 |
+
|
15 |
+
# Install all the dependencies and requirements associated with the project for running HF compatible models.
|
16 |
+
source install.sh
|
17 |
+
```
|
18 |
+
|
19 |
+
> Note: The `install.sh` script in this directory is specifically for running HF compatible models for inference.
|
20 |
+
|
21 |
+
|
22 |
+
### Converting
|
23 |
+
|
24 |
+
In order to convert the fairseq checkpoint to a PyTorch checkpoint that is compatible with HuggingFace Transformers, use the following command:
|
25 |
+
|
26 |
+
```bash
|
27 |
+
python3 convert_indictrans_checkpoint_to_pytorch.py --fairseq_path <fairseq_checkpoint_best.pt> --pytorch_dump_folder_path <hf_output_dir>
|
28 |
+
```
|
29 |
+
- `<fairseq_checkpoint_best.pt>`: path to the fairseq `checkpoint_best.pt` that needs to be converted to HF compatible models
|
30 |
+
- `<hf_output_dir>`: path to the output directory where the HF compatible models will be saved
|
31 |
+
|
32 |
+
|
33 |
+
### Models
|
34 |
+
|
35 |
+
| Model | 🤗 HuggingFace Checkpoints |
|
36 |
+
|----------|-----------------------------------|
|
37 |
+
| Preprint En-Indic | [ai4bharat/indictrans2-en-indic-1B](https://huggingface.co/ai4bharat/indictrans2-en-indic-1B) |
|
38 |
+
| Preprint Indic-En | [ai4bharat/indictrans2-indic-en-1B](https://huggingface.co/ai4bharat/indictrans2-indic-en-1B) |
|
39 |
+
|
40 |
+
|
41 |
+
### Inference
|
42 |
+
|
43 |
+
With the conversion complete, you can now perform inference using the HuggingFace Transformers.
|
44 |
+
|
45 |
+
You can start with the provided `example.py` script and customize it for your specific translation use case:
|
46 |
+
|
47 |
+
```bash
|
48 |
+
python3 example.py
|
49 |
+
```
|
50 |
+
|
51 |
+
Feel free to modify the `example.py` script to suit your translation needs.
|
52 |
+
|
53 |
+
### Citation
|
54 |
+
|
55 |
+
```
|
56 |
+
@article{ai4bharat2023indictrans2,
|
57 |
+
title = {IndicTrans2: Towards High-Quality and Accessible Machine Translation Models for all 22 Scheduled Indian Languages},
|
58 |
+
author = {AI4Bharat and Jay Gala and Pranjal A. Chitale and Raghavan AK and Sumanth Doddapaneni and Varun Gumma and Aswanth Kumar and Janki Nawale and Anupama Sujatha and Ratish Puduppully and Vivek Raghavan and Pratyush Kumar and Mitesh M. Khapra and Raj Dabre and Anoop Kunchukuttan},
|
59 |
+
year = {2023},
|
60 |
+
journal = {arXiv preprint arXiv: 2305.16307}
|
61 |
+
}
|
62 |
+
```
|
configuration_indictrans.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch IndicTrans config."""
|
16 |
+
|
17 |
+
|
18 |
+
from collections import OrderedDict
|
19 |
+
from typing import Any, Mapping, Optional
|
20 |
+
|
21 |
+
from transformers import PreTrainedTokenizer
|
22 |
+
from transformers.configuration_utils import PretrainedConfig
|
23 |
+
from transformers.onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
|
24 |
+
from transformers.onnx.utils import compute_effective_axis_dimension
|
25 |
+
from transformers.utils import TensorType, is_torch_available
|
26 |
+
|
27 |
+
|
28 |
+
# Copied from transformers.models.m2m_100.configuration_m2m_100.M2M100Config->IndicTrans
|
29 |
+
class IndicTransConfig(PretrainedConfig):
|
30 |
+
r"""
|
31 |
+
This is the configuration class to store the configuration of a [`IT2Model`]. It is used to instantiate an
|
32 |
+
IT2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
33 |
+
with the defaults will yield a similar configuration to that of the IT2
|
34 |
+
|
35 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
36 |
+
documentation from [`PretrainedConfig`] for more information.
|
37 |
+
|
38 |
+
|
39 |
+
Args:
|
40 |
+
vocab_size (`int`, *optional*, defaults to 50265):
|
41 |
+
Vocabulary size of the IT2 model. Defines the number of different tokens that can be represented by the
|
42 |
+
`inputs_ids` passed when calling [`IT2Model`] or
|
43 |
+
d_model (`int`, *optional*, defaults to 1024):
|
44 |
+
Dimensionality of the layers and the pooler layer.
|
45 |
+
encoder_layers (`int`, *optional*, defaults to 12):
|
46 |
+
Number of encoder layers.
|
47 |
+
decoder_layers (`int`, *optional*, defaults to 12):
|
48 |
+
Number of decoder layers.
|
49 |
+
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
50 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
51 |
+
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
52 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
53 |
+
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
54 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
55 |
+
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
56 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
57 |
+
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
58 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
59 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
60 |
+
dropout (`float`, *optional*, defaults to 0.1):
|
61 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
62 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
63 |
+
The dropout ratio for the attention probabilities.
|
64 |
+
activation_dropout (`float`, *optional*, defaults to 0.0):
|
65 |
+
The dropout ratio for activations inside the fully connected layer.
|
66 |
+
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
67 |
+
The dropout ratio for classifier.
|
68 |
+
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
69 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
70 |
+
just in case (e.g., 512 or 1024 or 2048).
|
71 |
+
init_std (`float`, *optional*, defaults to 0.02):
|
72 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
73 |
+
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
74 |
+
The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
75 |
+
for more details.
|
76 |
+
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
77 |
+
The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
|
78 |
+
for more details.
|
79 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
80 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
81 |
+
```"""
|
82 |
+
model_type = "IndicTrans"
|
83 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
84 |
+
attribute_map = {
|
85 |
+
"num_attention_heads": "encoder_attention_heads",
|
86 |
+
"hidden_size": "d_model",
|
87 |
+
}
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
encoder_vocab_size=None,
|
92 |
+
decoder_vocab_size=None,
|
93 |
+
encoder_embed_dim=512,
|
94 |
+
decoder_embed_dim=512,
|
95 |
+
max_source_positions=210,
|
96 |
+
max_target_positions=210,
|
97 |
+
encoder_layers=6,
|
98 |
+
encoder_ffn_dim=2048,
|
99 |
+
encoder_attention_heads=8,
|
100 |
+
decoder_layers=6,
|
101 |
+
decoder_ffn_dim=2048,
|
102 |
+
decoder_attention_heads=8,
|
103 |
+
encoder_layerdrop=0.00,
|
104 |
+
decoder_layerdrop=0.00,
|
105 |
+
use_cache=True,
|
106 |
+
is_encoder_decoder=True,
|
107 |
+
activation_function="relu",
|
108 |
+
encoder_normalize_before=False,
|
109 |
+
decoder_normalize_before=False,
|
110 |
+
layernorm_embedding=False,
|
111 |
+
share_decoder_input_output_embed=False,
|
112 |
+
dropout=0.1,
|
113 |
+
attention_dropout=0.0,
|
114 |
+
activation_dropout=0.0,
|
115 |
+
init_std=0.02,
|
116 |
+
scale_embedding=True,
|
117 |
+
decoder_start_token_id=2,
|
118 |
+
pad_token_id=1,
|
119 |
+
bos_token_id=0,
|
120 |
+
eos_token_id=2,
|
121 |
+
**kwargs,
|
122 |
+
):
|
123 |
+
self.encoder_vocab_size = encoder_vocab_size
|
124 |
+
self.decoder_vocab_size = decoder_vocab_size
|
125 |
+
self.encoder_normalize_before = encoder_normalize_before
|
126 |
+
self.decoder_normalize_before = decoder_normalize_before
|
127 |
+
self.layernorm_embedding = layernorm_embedding
|
128 |
+
self.max_source_positions = max_source_positions
|
129 |
+
self.max_target_positions = max_target_positions
|
130 |
+
self.encoder_embed_dim = encoder_embed_dim
|
131 |
+
self.decoder_embed_dim = decoder_embed_dim
|
132 |
+
self.encoder_ffn_dim = encoder_ffn_dim
|
133 |
+
self.encoder_layers = encoder_layers
|
134 |
+
self.encoder_attention_heads = encoder_attention_heads
|
135 |
+
self.decoder_ffn_dim = decoder_ffn_dim
|
136 |
+
self.decoder_layers = decoder_layers
|
137 |
+
self.decoder_attention_heads = decoder_attention_heads
|
138 |
+
self.dropout = dropout
|
139 |
+
self.attention_dropout = attention_dropout
|
140 |
+
self.activation_dropout = activation_dropout
|
141 |
+
self.activation_function = activation_function
|
142 |
+
self.init_std = init_std
|
143 |
+
self.encoder_layerdrop = encoder_layerdrop
|
144 |
+
self.decoder_layerdrop = decoder_layerdrop
|
145 |
+
self.use_cache = use_cache
|
146 |
+
self.num_hidden_layers = encoder_layers
|
147 |
+
self.scale_embedding = scale_embedding
|
148 |
+
self.share_decoder_input_output_embed = share_decoder_input_output_embed
|
149 |
+
|
150 |
+
super().__init__(
|
151 |
+
pad_token_id=pad_token_id,
|
152 |
+
bos_token_id=bos_token_id,
|
153 |
+
eos_token_id=eos_token_id,
|
154 |
+
is_encoder_decoder=is_encoder_decoder,
|
155 |
+
decoder_start_token_id=decoder_start_token_id,
|
156 |
+
**kwargs,
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
class IndicTransOnnxConfig(OnnxSeq2SeqConfigWithPast):
|
161 |
+
@property
|
162 |
+
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
163 |
+
common_inputs = OrderedDict(
|
164 |
+
[
|
165 |
+
("input_ids", {0: "batch", 1: "encoder_sequence"}),
|
166 |
+
("attention_mask", {0: "batch", 1: "encoder_sequence"}),
|
167 |
+
]
|
168 |
+
)
|
169 |
+
|
170 |
+
if self.use_past:
|
171 |
+
common_inputs["decoder_input_ids"] = {0: "batch"}
|
172 |
+
common_inputs["decoder_attention_mask"] = {
|
173 |
+
0: "batch",
|
174 |
+
1: "past_decoder_sequence + sequence",
|
175 |
+
}
|
176 |
+
else:
|
177 |
+
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
|
178 |
+
common_inputs["decoder_attention_mask"] = {
|
179 |
+
0: "batch",
|
180 |
+
1: "decoder_sequence",
|
181 |
+
}
|
182 |
+
|
183 |
+
if self.use_past:
|
184 |
+
self.fill_with_past_key_values_(common_inputs, direction="inputs")
|
185 |
+
return common_inputs
|
186 |
+
|
187 |
+
# Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
|
188 |
+
# A better name would be _generate_dummy_inputs_for_encoder_and_decoder because sequence classification and question
|
189 |
+
# answering are not supported for IT2, but this name is preserved to be able to check that the copy matches what
|
190 |
+
# was done for BART so that it can be updated if need be.
|
191 |
+
def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
192 |
+
self,
|
193 |
+
tokenizer: PreTrainedTokenizer,
|
194 |
+
batch_size: int = -1,
|
195 |
+
seq_length: int = -1,
|
196 |
+
is_pair: bool = False,
|
197 |
+
framework: Optional[TensorType] = None,
|
198 |
+
) -> Mapping[str, Any]:
|
199 |
+
# Copied from OnnxConfig.generate_dummy_inputs
|
200 |
+
# Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
|
201 |
+
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
202 |
+
batch_size = compute_effective_axis_dimension(
|
203 |
+
batch_size,
|
204 |
+
fixed_dimension=OnnxConfig.default_fixed_batch,
|
205 |
+
num_token_to_add=0,
|
206 |
+
)
|
207 |
+
|
208 |
+
# If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
209 |
+
token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
210 |
+
seq_length = compute_effective_axis_dimension(
|
211 |
+
seq_length,
|
212 |
+
fixed_dimension=OnnxConfig.default_fixed_sequence,
|
213 |
+
num_token_to_add=token_to_add,
|
214 |
+
)
|
215 |
+
|
216 |
+
# Generate dummy inputs according to compute batch and sequence
|
217 |
+
dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
218 |
+
common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
|
219 |
+
return common_inputs
|
220 |
+
|
221 |
+
# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._generate_dummy_inputs_for_default_and_seq2seq_lm
|
222 |
+
def _generate_dummy_inputs_for_default_and_seq2seq_lm(
|
223 |
+
self,
|
224 |
+
tokenizer: PreTrainedTokenizer,
|
225 |
+
batch_size: int = -1,
|
226 |
+
seq_length: int = -1,
|
227 |
+
is_pair: bool = False,
|
228 |
+
framework: Optional[TensorType] = None,
|
229 |
+
) -> Mapping[str, Any]:
|
230 |
+
encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
231 |
+
tokenizer, batch_size, seq_length, is_pair, framework
|
232 |
+
)
|
233 |
+
|
234 |
+
# Generate decoder inputs
|
235 |
+
decoder_seq_length = seq_length if not self.use_past else 1
|
236 |
+
decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
|
237 |
+
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
238 |
+
)
|
239 |
+
decoder_inputs = {
|
240 |
+
f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()
|
241 |
+
}
|
242 |
+
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
243 |
+
|
244 |
+
if self.use_past:
|
245 |
+
if not is_torch_available():
|
246 |
+
raise ValueError(
|
247 |
+
"Cannot generate dummy past_keys inputs without PyTorch installed."
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
import torch
|
251 |
+
batch, encoder_seq_length = common_inputs["input_ids"].shape
|
252 |
+
decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
|
253 |
+
(
|
254 |
+
num_encoder_attention_heads,
|
255 |
+
num_decoder_attention_heads,
|
256 |
+
) = self.num_attention_heads
|
257 |
+
encoder_shape = (
|
258 |
+
batch,
|
259 |
+
num_encoder_attention_heads,
|
260 |
+
encoder_seq_length,
|
261 |
+
self._config.hidden_size // num_encoder_attention_heads,
|
262 |
+
)
|
263 |
+
decoder_past_length = decoder_seq_length + 3
|
264 |
+
decoder_shape = (
|
265 |
+
batch,
|
266 |
+
num_decoder_attention_heads,
|
267 |
+
decoder_past_length,
|
268 |
+
self._config.hidden_size // num_decoder_attention_heads,
|
269 |
+
)
|
270 |
+
|
271 |
+
common_inputs["decoder_attention_mask"] = torch.cat(
|
272 |
+
[
|
273 |
+
common_inputs["decoder_attention_mask"],
|
274 |
+
torch.ones(batch, decoder_past_length),
|
275 |
+
],
|
276 |
+
dim=1,
|
277 |
+
)
|
278 |
+
|
279 |
+
common_inputs["past_key_values"] = []
|
280 |
+
# If the number of encoder and decoder layers are present in the model configuration, both are considered
|
281 |
+
num_encoder_layers, num_decoder_layers = self.num_layers
|
282 |
+
min_num_layers = min(num_encoder_layers, num_decoder_layers)
|
283 |
+
max_num_layers = (
|
284 |
+
max(num_encoder_layers, num_decoder_layers) - min_num_layers
|
285 |
+
)
|
286 |
+
remaining_side_name = (
|
287 |
+
"encoder" if num_encoder_layers > num_decoder_layers else "decoder"
|
288 |
+
)
|
289 |
+
|
290 |
+
for _ in range(min_num_layers):
|
291 |
+
common_inputs["past_key_values"].append(
|
292 |
+
(
|
293 |
+
torch.zeros(decoder_shape),
|
294 |
+
torch.zeros(decoder_shape),
|
295 |
+
torch.zeros(encoder_shape),
|
296 |
+
torch.zeros(encoder_shape),
|
297 |
+
)
|
298 |
+
)
|
299 |
+
# TODO: test this.
|
300 |
+
shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
|
301 |
+
for _ in range(min_num_layers, max_num_layers):
|
302 |
+
common_inputs["past_key_values"].append(
|
303 |
+
(torch.zeros(shape), torch.zeros(shape))
|
304 |
+
)
|
305 |
+
return common_inputs
|
306 |
+
|
307 |
+
generate_dummy_inputs = _generate_dummy_inputs_for_default_and_seq2seq_lm
|
convert_indictrans_checkpoint_to_pytorch.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import argparse
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from configuration_indictrans import IndicTransConfig
|
21 |
+
from modeling_indictrans import IndicTransForConditionalGeneration
|
22 |
+
|
23 |
+
|
24 |
+
def remove_ignore_keys_(state_dict):
|
25 |
+
ignore_keys = [
|
26 |
+
"encoder.version",
|
27 |
+
"decoder.version",
|
28 |
+
"model.encoder.version",
|
29 |
+
"model.decoder.version",
|
30 |
+
"_float_tensor",
|
31 |
+
"encoder.embed_positions._float_tensor",
|
32 |
+
"decoder.embed_positions._float_tensor",
|
33 |
+
]
|
34 |
+
for k in ignore_keys:
|
35 |
+
state_dict.pop(k, None)
|
36 |
+
|
37 |
+
|
38 |
+
def make_linear_from_emb(emb):
|
39 |
+
vocab_size, emb_size = emb.shape
|
40 |
+
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
41 |
+
lin_layer.weight.data = emb.data
|
42 |
+
return lin_layer
|
43 |
+
|
44 |
+
|
45 |
+
def convert_fairseq_IT2_checkpoint_from_disk(checkpoint_path):
|
46 |
+
model = torch.load(checkpoint_path, map_location="cpu")
|
47 |
+
args = model["args"] or model["cfg"]["model"]
|
48 |
+
state_dict = model["model"]
|
49 |
+
remove_ignore_keys_(state_dict)
|
50 |
+
encoder_vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
|
51 |
+
decoder_vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]
|
52 |
+
|
53 |
+
config = IndicTransConfig(
|
54 |
+
encoder_vocab_size=encoder_vocab_size,
|
55 |
+
decoder_vocab_size=decoder_vocab_size,
|
56 |
+
max_source_positions=args.max_source_positions,
|
57 |
+
max_target_positions=args.max_target_positions,
|
58 |
+
encoder_layers=args.encoder_layers,
|
59 |
+
decoder_layers=args.decoder_layers,
|
60 |
+
layernorm_embedding=args.layernorm_embedding,
|
61 |
+
encoder_normalize_before=args.encoder_normalize_before,
|
62 |
+
decoder_normalize_before=args.decoder_normalize_before,
|
63 |
+
encoder_attention_heads=args.encoder_attention_heads,
|
64 |
+
decoder_attention_heads=args.decoder_attention_heads,
|
65 |
+
encoder_ffn_dim=args.encoder_ffn_embed_dim,
|
66 |
+
decoder_ffn_dim=args.decoder_ffn_embed_dim,
|
67 |
+
encoder_embed_dim=args.encoder_embed_dim,
|
68 |
+
decoder_embed_dim=args.decoder_embed_dim,
|
69 |
+
encoder_layerdrop=args.encoder_layerdrop,
|
70 |
+
decoder_layerdrop=args.decoder_layerdrop,
|
71 |
+
dropout=args.dropout,
|
72 |
+
attention_dropout=args.attention_dropout,
|
73 |
+
activation_dropout=args.activation_dropout,
|
74 |
+
activation_function=args.activation_fn,
|
75 |
+
share_decoder_input_output_embed=args.share_decoder_input_output_embed,
|
76 |
+
scale_embedding=not args.no_scale_embedding,
|
77 |
+
)
|
78 |
+
|
79 |
+
model = IndicTransForConditionalGeneration(config)
|
80 |
+
model.model.load_state_dict(state_dict, strict=False)
|
81 |
+
if not args.share_decoder_input_output_embed:
|
82 |
+
model.lm_head = make_linear_from_emb(
|
83 |
+
state_dict["decoder.output_projection.weight"]
|
84 |
+
)
|
85 |
+
print(model)
|
86 |
+
return model
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == "__main__":
|
90 |
+
parser = argparse.ArgumentParser()
|
91 |
+
# Required parameters
|
92 |
+
parser.add_argument(
|
93 |
+
"--fairseq_path",
|
94 |
+
default="indic-en/model/checkpoint_best.pt",
|
95 |
+
type=str,
|
96 |
+
help="path to a model.pt on local filesystem.",
|
97 |
+
)
|
98 |
+
parser.add_argument(
|
99 |
+
"--pytorch_dump_folder_path",
|
100 |
+
default="indic-en/hf_model",
|
101 |
+
type=str,
|
102 |
+
help="Path to the output PyTorch model.",
|
103 |
+
)
|
104 |
+
|
105 |
+
args = parser.parse_args()
|
106 |
+
model = convert_fairseq_IT2_checkpoint_from_disk(args.fairseq_path)
|
107 |
+
model.save_pretrained(args.pytorch_dump_folder_path)
|
example.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
|
4 |
+
from IndicTransTokenizer.utils import preprocess_batch, postprocess_batch
|
5 |
+
from IndicTransTokenizer.tokenizer import IndicTransTokenizer
|
6 |
+
|
7 |
+
en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"
|
8 |
+
|
9 |
+
BATCH_SIZE = 16
|
10 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
+
|
12 |
+
if len(sys.argv)>1:
|
13 |
+
quantization = sys.argv[1]
|
14 |
+
else:
|
15 |
+
quantization = ""
|
16 |
+
|
17 |
+
|
18 |
+
def initialize_model_and_tokenizer(ckpt_dir, direction, quantization):
|
19 |
+
if quantization == "4-bit":
|
20 |
+
qconfig = BitsAndBytesConfig(
|
21 |
+
load_in_4bit=True,
|
22 |
+
bnb_4bit_use_double_quant=True,
|
23 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
24 |
+
)
|
25 |
+
elif quantization == "8-bit":
|
26 |
+
qconfig = BitsAndBytesConfig(
|
27 |
+
load_in_8bit=True,
|
28 |
+
bnb_8bit_use_double_quant=True,
|
29 |
+
bnb_8bit_compute_dtype=torch.bfloat16,
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
qconfig = None
|
33 |
+
|
34 |
+
tokenizer = IndicTransTokenizer(direction=direction)
|
35 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
36 |
+
ckpt_dir,
|
37 |
+
trust_remote_code=True,
|
38 |
+
low_cpu_mem_usage=True,
|
39 |
+
quantization_config=qconfig
|
40 |
+
)
|
41 |
+
|
42 |
+
if qconfig==None:
|
43 |
+
model = model.to(DEVICE)
|
44 |
+
model.half()
|
45 |
+
|
46 |
+
model.eval()
|
47 |
+
|
48 |
+
return tokenizer, model
|
49 |
+
|
50 |
+
|
51 |
+
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer):
|
52 |
+
translations = []
|
53 |
+
for i in range(0, len(input_sentences), BATCH_SIZE):
|
54 |
+
batch = input_sentences[i : i + BATCH_SIZE]
|
55 |
+
|
56 |
+
# Preprocess the batch and extract entity mappings
|
57 |
+
batch, entity_map = preprocess_batch(
|
58 |
+
batch, src_lang=src_lang, tgt_lang=tgt_lang
|
59 |
+
)
|
60 |
+
|
61 |
+
# Tokenize the batch and generate input encodings
|
62 |
+
inputs = tokenizer(
|
63 |
+
batch,
|
64 |
+
src=True,
|
65 |
+
truncation=True,
|
66 |
+
padding="longest",
|
67 |
+
return_tensors="pt",
|
68 |
+
return_attention_mask=True,
|
69 |
+
).to(DEVICE)
|
70 |
+
|
71 |
+
# Generate translations using the model
|
72 |
+
with torch.no_grad():
|
73 |
+
generated_tokens = model.generate(
|
74 |
+
**inputs,
|
75 |
+
use_cache=True,
|
76 |
+
min_length=0,
|
77 |
+
max_length=256,
|
78 |
+
num_beams=5,
|
79 |
+
num_return_sequences=1,
|
80 |
+
)
|
81 |
+
|
82 |
+
# Decode the generated tokens into text
|
83 |
+
generated_tokens = tokenizer.batch_decode(
|
84 |
+
generated_tokens.detach().cpu().tolist(), src=False
|
85 |
+
)
|
86 |
+
|
87 |
+
# Postprocess the translations, including entity replacement
|
88 |
+
translations += postprocess_batch(
|
89 |
+
generated_tokens, lang=tgt_lang, placeholder_entity_map=entity_map
|
90 |
+
)
|
91 |
+
|
92 |
+
del inputs
|
93 |
+
torch.cuda.empty_cache()
|
94 |
+
|
95 |
+
return translations
|
96 |
+
|
97 |
+
|
98 |
+
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(
|
99 |
+
en_indic_ckpt_dir, "en-indic", quantization
|
100 |
+
)
|
101 |
+
|
102 |
+
# ---------------------------------------------------------------------------
|
103 |
+
# English to Hindi
|
104 |
+
# ---------------------------------------------------------------------------
|
105 |
+
en_sents = [
|
106 |
+
"When I was young, I used to go to the park every day.",
|
107 |
+
"He has many old books, which he inherited from his ancestors.",
|
108 |
+
"I can't figure out how to solve my problem.",
|
109 |
+
"She is very hardworking and intelligent, which is why she got all the good marks.",
|
110 |
+
"We watched a new movie last week, which was very inspiring.",
|
111 |
+
"If you had met me at that time, we would have gone out to eat.",
|
112 |
+
"She went to the market with her sister to buy a new sari.",
|
113 |
+
"Raj told me that he is going to his grandmother's house next month.",
|
114 |
+
"All the kids were having fun at the party and were eating lots of sweets.",
|
115 |
+
"My friend has invited me to his birthday party, and I will give him a gift.",
|
116 |
+
]
|
117 |
+
src_lang, tgt_lang = "eng_Latn", "hin_Deva"
|
118 |
+
hi_translations = batch_translate(
|
119 |
+
en_sents, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer
|
120 |
+
)
|
121 |
+
|
122 |
+
print(f"\n{src_lang} - {tgt_lang}")
|
123 |
+
for input_sentence, translation in zip(en_sents, hi_translations):
|
124 |
+
print(f"{src_lang}: {input_sentence}")
|
125 |
+
print(f"{tgt_lang}: {translation}")
|
handler.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Any
|
2 |
+
import sys, os, re
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig
|
7 |
+
from IndicTransTokenizer.utils import preprocess_batch, postprocess_batch
|
8 |
+
from IndicTransTokenizer.tokenizer import IndicTransTokenizer
|
9 |
+
|
10 |
+
|
11 |
+
class EndpointHandler():
|
12 |
+
def __init__(self, direction = "en-indic", quantization = ""):
|
13 |
+
self.model_name = "ai4bharat/indictrans2-en-indic-1B"
|
14 |
+
|
15 |
+
self.utterance_pattern = re.compile(r"^\d+$")
|
16 |
+
self.timestamp_pattern = re.compile(r"(\d+:\d+:\d+,\d+)\s*-->\s*(\d+:\d+:\d+,\d+)")
|
17 |
+
|
18 |
+
self.BATCH_SIZE = 16
|
19 |
+
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
20 |
+
|
21 |
+
self.model = None
|
22 |
+
self.tokenizer = None
|
23 |
+
|
24 |
+
if quantization == "4-bit":
|
25 |
+
qconfig = BitsAndBytesConfig(
|
26 |
+
load_in_4bit=True,
|
27 |
+
bnb_4bit_use_double_quant=True,
|
28 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
29 |
+
)
|
30 |
+
elif quantization == "8-bit":
|
31 |
+
qconfig = BitsAndBytesConfig(
|
32 |
+
load_in_8bit=True,
|
33 |
+
bnb_8bit_use_double_quant=True,
|
34 |
+
bnb_8bit_compute_dtype=torch.bfloat16,
|
35 |
+
)
|
36 |
+
else:
|
37 |
+
qconfig = None
|
38 |
+
|
39 |
+
self.tokenizer = IndicTransTokenizer(direction=direction)
|
40 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained(
|
41 |
+
self.model_name,
|
42 |
+
trust_remote_code=True,
|
43 |
+
low_cpu_mem_usage=True,
|
44 |
+
quantization_config=qconfig
|
45 |
+
)
|
46 |
+
|
47 |
+
if qconfig==None:
|
48 |
+
self.model = self.model.to(self.DEVICE)
|
49 |
+
self.model.half()
|
50 |
+
|
51 |
+
self.model.eval()
|
52 |
+
|
53 |
+
|
54 |
+
def batch_translate(self, input_sentences, src_lang, tgt_lang):
|
55 |
+
translations = []
|
56 |
+
for i in range(0, len(input_sentences), self.BATCH_SIZE):
|
57 |
+
batch = input_sentences[i : i + self.BATCH_SIZE]
|
58 |
+
|
59 |
+
# Preprocess the batch and extract entity mappings
|
60 |
+
batch, entity_map = preprocess_batch(
|
61 |
+
batch, src_lang=src_lang, tgt_lang=tgt_lang
|
62 |
+
)
|
63 |
+
|
64 |
+
# Tokenize the batch and generate input encodings
|
65 |
+
inputs = self.tokenizer(
|
66 |
+
batch,
|
67 |
+
src=True,
|
68 |
+
truncation=True,
|
69 |
+
padding="longest",
|
70 |
+
return_tensors="pt",
|
71 |
+
return_attention_mask=True,
|
72 |
+
).to(self.DEVICE)
|
73 |
+
|
74 |
+
# Generate translations using the model
|
75 |
+
with torch.no_grad():
|
76 |
+
generated_tokens = self.model.generate(
|
77 |
+
**inputs,
|
78 |
+
use_cache=True,
|
79 |
+
min_length=0,
|
80 |
+
max_length=256,
|
81 |
+
num_beams=5,
|
82 |
+
num_return_sequences=1,
|
83 |
+
)
|
84 |
+
|
85 |
+
# Decode the generated tokens into text
|
86 |
+
generated_tokens = self.tokenizer.batch_decode(
|
87 |
+
generated_tokens.detach().cpu().tolist(), src=False
|
88 |
+
)
|
89 |
+
|
90 |
+
# Postprocess the translations, including entity replacement
|
91 |
+
translations += postprocess_batch(
|
92 |
+
generated_tokens, lang=tgt_lang, placeholder_entity_map=entity_map
|
93 |
+
)
|
94 |
+
|
95 |
+
del inputs
|
96 |
+
if torch.cuda.is_available():
|
97 |
+
torch.cuda.empty_cache()
|
98 |
+
|
99 |
+
return translations
|
100 |
+
|
101 |
+
|
102 |
+
def read_srt(self, srt_path):
|
103 |
+
data = []
|
104 |
+
with open(srt_path, 'r', encoding='utf-8') as fp:
|
105 |
+
utterance_ind = ""
|
106 |
+
start_end = ""
|
107 |
+
text = ""
|
108 |
+
for ind, line in enumerate(fp.readlines()):
|
109 |
+
line = line.strip()
|
110 |
+
if re.search(self.utterance_pattern, line) is not None:
|
111 |
+
utterance_ind = line
|
112 |
+
elif re.search(self.timestamp_pattern, line) is not None:
|
113 |
+
start_end = line
|
114 |
+
else:
|
115 |
+
text += line
|
116 |
+
|
117 |
+
if utterance_ind!='' and start_end!='' and text!='':
|
118 |
+
data.append({'utterance_ind': utterance_ind, 'start_end': start_end, 'text': text})
|
119 |
+
utterance_ind = ''
|
120 |
+
start_end = ''
|
121 |
+
text = ''
|
122 |
+
|
123 |
+
return data
|
124 |
+
|
125 |
+
def test(self, inputs) -> List[Dict[str, Any]]:
|
126 |
+
"""
|
127 |
+
data args:
|
128 |
+
inputs (:obj: (transcript_path : 'str', src_lang : 'str', tgt_lang : 'str')
|
129 |
+
kwargs
|
130 |
+
Return:
|
131 |
+
A :obj:`list` | `dict`: will be serialized and returned
|
132 |
+
"""
|
133 |
+
|
134 |
+
src_lang = inputs["src_lang"]
|
135 |
+
tgt_lang = inputs["tgt_lang"]
|
136 |
+
transcript_path = inputs["transcript_path"]
|
137 |
+
|
138 |
+
output_translations = []
|
139 |
+
if self.model is not None:
|
140 |
+
transcriptions = self.read_srt(transcript_path)
|
141 |
+
trans_sents = [entry['text'] for entry in transcriptions]
|
142 |
+
indic_translations = self.batch_translate(trans_sents, src_lang, tgt_lang)
|
143 |
+
|
144 |
+
for i in tqdm(range(len(transcriptions))):
|
145 |
+
entry = transcriptions[i]
|
146 |
+
entry['text'] = indic_translations[i]
|
147 |
+
output_translations.append(entry)
|
148 |
+
|
149 |
+
return output_translations
|
150 |
+
else:
|
151 |
+
return []
|
152 |
+
|
153 |
+
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
154 |
+
"""
|
155 |
+
data args:
|
156 |
+
inputs (:obj: (transcript_path : 'str', src_lang : 'str', tgt_lang : 'str')
|
157 |
+
kwargs
|
158 |
+
Return:
|
159 |
+
A :obj:`list` | `dict`: will be serialized and returned
|
160 |
+
"""
|
161 |
+
|
162 |
+
inputs = data.pop("inputs",data)
|
163 |
+
|
164 |
+
src_lang = inputs["src_lang"]
|
165 |
+
tgt_lang = inputs["tgt_lang"]
|
166 |
+
transcript_path = inputs["transcript_path"]
|
167 |
+
|
168 |
+
output_translations = []
|
169 |
+
if self.model is not None:
|
170 |
+
transcriptions = self.read_srt(transcript_path)
|
171 |
+
indic_translations = self.batch_translate(transcriptions, src_lang, tgt_lang)
|
172 |
+
|
173 |
+
for i in tqdm(range(len(transcriptions))):
|
174 |
+
entry = transcriptions[i]
|
175 |
+
entry['text'] = indic_translations[i]
|
176 |
+
output_translations.append(entry)
|
177 |
+
|
178 |
+
return output_translations
|
179 |
+
else:
|
180 |
+
return []
|
181 |
+
|
182 |
+
|
183 |
+
if __name__ == "__main__":
|
184 |
+
endpoint = EndpointHandler(quantization = "8-bit")
|
185 |
+
inputs = {}
|
186 |
+
inputs['src_lang'] = 'eng_Latn'
|
187 |
+
inputs['tgt_lang'] = 'tel_Telu'
|
188 |
+
inputs['transcript_path'] = './sample.srt'
|
189 |
+
|
190 |
+
outputs = endpoint.test(inputs)
|
191 |
+
|
192 |
+
print("Outputs: ")
|
193 |
+
for entry in outputs:
|
194 |
+
print(entry)
|
install.sh
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#/bin/bash
|
2 |
+
|
3 |
+
root_dir=$(pwd)
|
4 |
+
echo "Setting up the environment in the $root_dir"
|
5 |
+
|
6 |
+
# --------------------------------------------------------------
|
7 |
+
# create and activate the virtual environment
|
8 |
+
# --------------------------------------------------------------
|
9 |
+
echo "Creating a virtual environment with python3"
|
10 |
+
conda create -n itv2_hf python=3.9 -y
|
11 |
+
conda activate itv2_hf
|
12 |
+
|
13 |
+
echo "Installing all the dependencies"
|
14 |
+
conda install pip
|
15 |
+
python3 -m pip install --upgrade pip
|
16 |
+
|
17 |
+
|
18 |
+
# --------------------------------------------------------------
|
19 |
+
# PyTorch Installation
|
20 |
+
# --------------------------------------------------------------
|
21 |
+
python3 -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
|
22 |
+
|
23 |
+
|
24 |
+
# --------------------------------------------------------------
|
25 |
+
# Install IndicNLP library and necessary resources
|
26 |
+
# --------------------------------------------------------------
|
27 |
+
git clone https://github.com/anoopkunchukuttan/indic_nlp_resources.git
|
28 |
+
export INDIC_RESOURCES_PATH=$root_dir/indic_nlp_resources
|
29 |
+
|
30 |
+
# we use version 0.92 which is the latest in the github repo
|
31 |
+
git clone https://github.com/anoopkunchukuttan/indic_nlp_library.git
|
32 |
+
cd indic_nlp_library
|
33 |
+
python3 -m pip install ./
|
34 |
+
cd $root_dir
|
35 |
+
|
36 |
+
|
37 |
+
# --------------------------------------------------------------
|
38 |
+
# Install additional utility packages
|
39 |
+
# --------------------------------------------------------------
|
40 |
+
python3 -m pip install sacremoses pandas regex mock transformers==4.33.2 urduhack[tf] mosestokenizer
|
41 |
+
python3 -c "import urduhack; urduhack.download()"
|
42 |
+
python3 -m pip install bitsandbytes scipy accelerate datasets
|
43 |
+
|
44 |
+
|
45 |
+
# --------------------------------------------------------------
|
46 |
+
# Sentencepiece for tokenization
|
47 |
+
# --------------------------------------------------------------
|
48 |
+
# build the cpp binaries from the source repo in order to use the command line utility
|
49 |
+
# source repo: https://github.com/google/sentencepiece
|
50 |
+
python3 -m pip install sentencepiece
|
51 |
+
|
52 |
+
echo "Setup completed!"
|
modeling_indictrans.py
ADDED
@@ -0,0 +1,1449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The IndicTrans2 Authors and AI4Bharat team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch IndicTrans model."""
|
16 |
+
|
17 |
+
|
18 |
+
import math
|
19 |
+
from typing import List, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from torch.nn import functional as F
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
27 |
+
from transformers.modeling_outputs import (
|
28 |
+
BaseModelOutput,
|
29 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
30 |
+
Seq2SeqLMOutput,
|
31 |
+
Seq2SeqModelOutput,
|
32 |
+
)
|
33 |
+
|
34 |
+
from transformers.utils import logging
|
35 |
+
from transformers.modeling_utils import PreTrainedModel
|
36 |
+
|
37 |
+
from configuration_indictrans import IndicTransConfig
|
38 |
+
|
39 |
+
|
40 |
+
logger = logging.get_logger(__name__)
|
41 |
+
|
42 |
+
_CONFIG_FOR_DOC = "IndicTransConfig"
|
43 |
+
|
44 |
+
INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
|
45 |
+
|
46 |
+
|
47 |
+
# Copied from transformers.models.bart.modeling_bart.shift_tokens_right
|
48 |
+
def shift_tokens_right(
|
49 |
+
input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int
|
50 |
+
):
|
51 |
+
"""
|
52 |
+
Shift input ids one token to the right.
|
53 |
+
"""
|
54 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
55 |
+
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
56 |
+
shifted_input_ids[:, 0] = decoder_start_token_id
|
57 |
+
|
58 |
+
if pad_token_id is None:
|
59 |
+
raise ValueError("self.model.config.pad_token_id has to be defined.")
|
60 |
+
# replace possible -100 values in labels by `pad_token_id`
|
61 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
|
62 |
+
|
63 |
+
return shifted_input_ids
|
64 |
+
|
65 |
+
|
66 |
+
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
67 |
+
def _make_causal_mask(
|
68 |
+
input_ids_shape: torch.Size,
|
69 |
+
dtype: torch.dtype,
|
70 |
+
device: torch.device,
|
71 |
+
past_key_values_length: int = 0,
|
72 |
+
):
|
73 |
+
"""
|
74 |
+
Make causal mask used for bi-directional self-attention.
|
75 |
+
"""
|
76 |
+
bsz, tgt_len = input_ids_shape
|
77 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
78 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
79 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
80 |
+
mask = mask.to(dtype)
|
81 |
+
|
82 |
+
if past_key_values_length > 0:
|
83 |
+
mask = torch.cat(
|
84 |
+
[
|
85 |
+
torch.zeros(
|
86 |
+
tgt_len, past_key_values_length, dtype=dtype, device=device
|
87 |
+
),
|
88 |
+
mask,
|
89 |
+
],
|
90 |
+
dim=-1,
|
91 |
+
)
|
92 |
+
return mask[None, None, :, :].expand(
|
93 |
+
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
98 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
99 |
+
"""
|
100 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
101 |
+
"""
|
102 |
+
bsz, src_len = mask.size()
|
103 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
104 |
+
|
105 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
106 |
+
|
107 |
+
inverted_mask = 1.0 - expanded_mask
|
108 |
+
|
109 |
+
return inverted_mask.masked_fill(
|
110 |
+
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
111 |
+
)
|
112 |
+
|
113 |
+
|
114 |
+
def create_position_ids_from_input_ids(
|
115 |
+
input_ids, padding_idx, past_key_values_length=0
|
116 |
+
):
|
117 |
+
"""
|
118 |
+
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
|
119 |
+
are ignored. This is modified from fairseq's `utils.make_positions`.
|
120 |
+
"""
|
121 |
+
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
122 |
+
mask = input_ids.ne(padding_idx).int()
|
123 |
+
incremental_indices = (
|
124 |
+
torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
|
125 |
+
) * mask
|
126 |
+
return incremental_indices.long() + padding_idx
|
127 |
+
|
128 |
+
|
129 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding->IndicTrans
|
130 |
+
class IndicTransSinusoidalPositionalEmbedding(nn.Module):
|
131 |
+
"""This module produces sinusoidal positional embeddings of any length."""
|
132 |
+
|
133 |
+
def __init__(
|
134 |
+
self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
self.offset = 2
|
138 |
+
self.embedding_dim = embedding_dim
|
139 |
+
self.padding_idx = padding_idx
|
140 |
+
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
|
141 |
+
|
142 |
+
def make_weights(
|
143 |
+
self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
|
144 |
+
):
|
145 |
+
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
|
146 |
+
if hasattr(self, "weights"):
|
147 |
+
# in forward put the weights on the correct dtype and device of the param
|
148 |
+
emb_weights = emb_weights.to(
|
149 |
+
dtype=self.weights.dtype, device=self.weights.device
|
150 |
+
)
|
151 |
+
|
152 |
+
self.register_buffer("weights", emb_weights, persistent=False)
|
153 |
+
|
154 |
+
@staticmethod
|
155 |
+
def get_embedding(
|
156 |
+
num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
|
157 |
+
):
|
158 |
+
"""
|
159 |
+
Build sinusoidal embeddings.
|
160 |
+
|
161 |
+
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
|
162 |
+
"Attention Is All You Need".
|
163 |
+
"""
|
164 |
+
half_dim = embedding_dim // 2
|
165 |
+
emb = math.log(10000) / (half_dim - 1)
|
166 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
167 |
+
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
|
168 |
+
1
|
169 |
+
) * emb.unsqueeze(0)
|
170 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
|
171 |
+
num_embeddings, -1
|
172 |
+
)
|
173 |
+
if embedding_dim % 2 == 1:
|
174 |
+
# zero pad
|
175 |
+
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
|
176 |
+
if padding_idx is not None:
|
177 |
+
emb[padding_idx, :] = 0
|
178 |
+
|
179 |
+
return emb.to(torch.get_default_dtype())
|
180 |
+
|
181 |
+
@torch.no_grad()
|
182 |
+
def forward(
|
183 |
+
self,
|
184 |
+
input_ids: torch.Tensor = None,
|
185 |
+
inputs_embeds: torch.Tensor = None,
|
186 |
+
past_key_values_length: int = 0,
|
187 |
+
):
|
188 |
+
if input_ids is not None:
|
189 |
+
bsz, seq_len = input_ids.size()
|
190 |
+
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
191 |
+
position_ids = create_position_ids_from_input_ids(
|
192 |
+
input_ids, self.padding_idx, past_key_values_length
|
193 |
+
).to(input_ids.device)
|
194 |
+
else:
|
195 |
+
bsz, seq_len = inputs_embeds.size()[:-1]
|
196 |
+
position_ids = self.create_position_ids_from_inputs_embeds(
|
197 |
+
inputs_embeds, past_key_values_length
|
198 |
+
)
|
199 |
+
|
200 |
+
# expand embeddings if needed
|
201 |
+
max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
|
202 |
+
if max_pos > self.weights.size(0):
|
203 |
+
self.make_weights(
|
204 |
+
max_pos + self.offset, self.embedding_dim, self.padding_idx
|
205 |
+
)
|
206 |
+
|
207 |
+
return (
|
208 |
+
self.weights.index_select(0, position_ids.view(-1))
|
209 |
+
.view(bsz, seq_len, self.weights.shape[-1])
|
210 |
+
.detach()
|
211 |
+
)
|
212 |
+
|
213 |
+
def create_position_ids_from_inputs_embeds(
|
214 |
+
self, inputs_embeds, past_key_values_length
|
215 |
+
):
|
216 |
+
"""
|
217 |
+
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
inputs_embeds: torch.Tensor
|
221 |
+
|
222 |
+
Returns: torch.Tensor
|
223 |
+
"""
|
224 |
+
input_shape = inputs_embeds.size()[:-1]
|
225 |
+
sequence_length = input_shape[1]
|
226 |
+
|
227 |
+
position_ids = torch.arange(
|
228 |
+
self.padding_idx + 1,
|
229 |
+
sequence_length + self.padding_idx + 1,
|
230 |
+
dtype=torch.long,
|
231 |
+
device=inputs_embeds.device,
|
232 |
+
)
|
233 |
+
return (
|
234 |
+
position_ids.unsqueeze(0).expand(input_shape).contiguous()
|
235 |
+
+ past_key_values_length
|
236 |
+
)
|
237 |
+
|
238 |
+
|
239 |
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->IndicTrans
|
240 |
+
class IndicTransAttention(nn.Module):
|
241 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
242 |
+
|
243 |
+
def __init__(
|
244 |
+
self,
|
245 |
+
embed_dim: int,
|
246 |
+
num_heads: int,
|
247 |
+
dropout: float = 0.0,
|
248 |
+
is_decoder: bool = False,
|
249 |
+
bias: bool = True,
|
250 |
+
):
|
251 |
+
super().__init__()
|
252 |
+
self.embed_dim = embed_dim
|
253 |
+
self.num_heads = num_heads
|
254 |
+
self.dropout = dropout
|
255 |
+
self.head_dim = embed_dim // num_heads
|
256 |
+
|
257 |
+
if (self.head_dim * num_heads) != self.embed_dim:
|
258 |
+
raise ValueError(
|
259 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
260 |
+
f" and `num_heads`: {num_heads})."
|
261 |
+
)
|
262 |
+
self.scaling = self.head_dim**-0.5
|
263 |
+
self.is_decoder = is_decoder
|
264 |
+
|
265 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
266 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
267 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
268 |
+
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
269 |
+
|
270 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
271 |
+
return (
|
272 |
+
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
273 |
+
.transpose(1, 2)
|
274 |
+
.contiguous()
|
275 |
+
)
|
276 |
+
|
277 |
+
def forward(
|
278 |
+
self,
|
279 |
+
hidden_states: torch.Tensor,
|
280 |
+
key_value_states: Optional[torch.Tensor] = None,
|
281 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
282 |
+
attention_mask: Optional[torch.Tensor] = None,
|
283 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
284 |
+
output_attentions: bool = False,
|
285 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
286 |
+
"""Input shape: Batch x Time x Channel"""
|
287 |
+
|
288 |
+
# if key_value_states are provided this layer is used as a cross-attention layer
|
289 |
+
# for the decoder
|
290 |
+
is_cross_attention = key_value_states is not None
|
291 |
+
|
292 |
+
bsz, tgt_len, _ = hidden_states.size()
|
293 |
+
|
294 |
+
# get query proj
|
295 |
+
query_states = self.q_proj(hidden_states) * self.scaling
|
296 |
+
# get key, value proj
|
297 |
+
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
298 |
+
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
299 |
+
# the provided `key_value_states` to support prefix tuning
|
300 |
+
if (
|
301 |
+
is_cross_attention
|
302 |
+
and past_key_value is not None
|
303 |
+
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
304 |
+
):
|
305 |
+
# reuse k,v, cross_attentions
|
306 |
+
key_states = past_key_value[0]
|
307 |
+
value_states = past_key_value[1]
|
308 |
+
elif is_cross_attention:
|
309 |
+
# cross_attentions
|
310 |
+
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
311 |
+
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
312 |
+
elif past_key_value is not None:
|
313 |
+
# reuse k, v, self_attention
|
314 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
315 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
316 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
317 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
318 |
+
else:
|
319 |
+
# self_attention
|
320 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
321 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
322 |
+
|
323 |
+
if self.is_decoder:
|
324 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
325 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
326 |
+
# key/value_states (first "if" case)
|
327 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
328 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
329 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
330 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
331 |
+
past_key_value = (key_states, value_states)
|
332 |
+
|
333 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
334 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
335 |
+
key_states = key_states.reshape(*proj_shape)
|
336 |
+
value_states = value_states.reshape(*proj_shape)
|
337 |
+
|
338 |
+
src_len = key_states.size(1)
|
339 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
340 |
+
|
341 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
342 |
+
raise ValueError(
|
343 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
344 |
+
f" {attn_weights.size()}"
|
345 |
+
)
|
346 |
+
|
347 |
+
if attention_mask is not None:
|
348 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
349 |
+
raise ValueError(
|
350 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
351 |
+
)
|
352 |
+
attn_weights = (
|
353 |
+
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
354 |
+
+ attention_mask
|
355 |
+
)
|
356 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
357 |
+
|
358 |
+
attn_weights = F.softmax(attn_weights, dim=-1)
|
359 |
+
|
360 |
+
if layer_head_mask is not None:
|
361 |
+
if layer_head_mask.size() != (self.num_heads,):
|
362 |
+
raise ValueError(
|
363 |
+
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
364 |
+
f" {layer_head_mask.size()}"
|
365 |
+
)
|
366 |
+
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
|
367 |
+
bsz, self.num_heads, tgt_len, src_len
|
368 |
+
)
|
369 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
370 |
+
|
371 |
+
if output_attentions:
|
372 |
+
# this operation is a bit awkward, but it's required to
|
373 |
+
# make sure that attn_weights keeps its gradient.
|
374 |
+
# In order to do so, attn_weights have to be reshaped
|
375 |
+
# twice and have to be reused in the following
|
376 |
+
attn_weights_reshaped = attn_weights.view(
|
377 |
+
bsz, self.num_heads, tgt_len, src_len
|
378 |
+
)
|
379 |
+
attn_weights = attn_weights_reshaped.view(
|
380 |
+
bsz * self.num_heads, tgt_len, src_len
|
381 |
+
)
|
382 |
+
else:
|
383 |
+
attn_weights_reshaped = None
|
384 |
+
|
385 |
+
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
|
386 |
+
|
387 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
388 |
+
|
389 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
390 |
+
raise ValueError(
|
391 |
+
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
|
392 |
+
f" {attn_output.size()}"
|
393 |
+
)
|
394 |
+
|
395 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
396 |
+
attn_output = attn_output.transpose(1, 2)
|
397 |
+
|
398 |
+
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
399 |
+
# partitioned across GPUs when using tensor-parallelism.
|
400 |
+
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
401 |
+
|
402 |
+
attn_output = self.out_proj(attn_output)
|
403 |
+
|
404 |
+
return attn_output, attn_weights_reshaped, past_key_value
|
405 |
+
|
406 |
+
|
407 |
+
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
|
408 |
+
class IndicTransEncoderLayer(nn.Module):
|
409 |
+
def __init__(self, config: IndicTransConfig):
|
410 |
+
super().__init__()
|
411 |
+
self.embed_dim = config.encoder_embed_dim
|
412 |
+
self.self_attn = IndicTransAttention(
|
413 |
+
embed_dim=self.embed_dim,
|
414 |
+
num_heads=config.encoder_attention_heads,
|
415 |
+
dropout=config.attention_dropout,
|
416 |
+
)
|
417 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
418 |
+
self.dropout = config.dropout
|
419 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
420 |
+
self.activation_dropout = config.activation_dropout
|
421 |
+
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
|
422 |
+
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
|
423 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
424 |
+
self.normalize_before = config.encoder_normalize_before
|
425 |
+
|
426 |
+
def forward(
|
427 |
+
self,
|
428 |
+
hidden_states: torch.Tensor,
|
429 |
+
attention_mask: torch.Tensor,
|
430 |
+
layer_head_mask: torch.Tensor,
|
431 |
+
output_attentions: bool = False,
|
432 |
+
) -> torch.Tensor:
|
433 |
+
"""
|
434 |
+
Args:
|
435 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
436 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
437 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
438 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
439 |
+
`(encoder_attention_heads,)`.
|
440 |
+
output_attentions (`bool`, *optional*):
|
441 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
442 |
+
returned tensors for more detail.
|
443 |
+
"""
|
444 |
+
residual = hidden_states
|
445 |
+
if self.normalize_before:
|
446 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
447 |
+
hidden_states, attn_weights, _ = self.self_attn(
|
448 |
+
hidden_states=hidden_states,
|
449 |
+
attention_mask=attention_mask,
|
450 |
+
layer_head_mask=layer_head_mask,
|
451 |
+
output_attentions=output_attentions,
|
452 |
+
)
|
453 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
454 |
+
hidden_states = residual + hidden_states
|
455 |
+
if not self.normalize_before:
|
456 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
457 |
+
|
458 |
+
residual = hidden_states
|
459 |
+
if self.normalize_before:
|
460 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
461 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
462 |
+
hidden_states = F.dropout(
|
463 |
+
hidden_states, p=self.activation_dropout, training=self.training
|
464 |
+
)
|
465 |
+
hidden_states = self.fc2(hidden_states)
|
466 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
467 |
+
hidden_states = residual + hidden_states
|
468 |
+
if not self.normalize_before:
|
469 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
470 |
+
|
471 |
+
if hidden_states.dtype == torch.float16 and (
|
472 |
+
torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
|
473 |
+
):
|
474 |
+
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
475 |
+
hidden_states = torch.clamp(
|
476 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
477 |
+
)
|
478 |
+
|
479 |
+
outputs = (hidden_states,)
|
480 |
+
|
481 |
+
if output_attentions:
|
482 |
+
outputs += (attn_weights,)
|
483 |
+
|
484 |
+
return outputs
|
485 |
+
|
486 |
+
|
487 |
+
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->IndicTrans
|
488 |
+
class IndicTransDecoderLayer(nn.Module):
|
489 |
+
def __init__(self, config: IndicTransConfig):
|
490 |
+
super().__init__()
|
491 |
+
self.embed_dim = config.decoder_embed_dim
|
492 |
+
|
493 |
+
self.self_attn = IndicTransAttention(
|
494 |
+
embed_dim=self.embed_dim,
|
495 |
+
num_heads=config.decoder_attention_heads,
|
496 |
+
dropout=config.attention_dropout,
|
497 |
+
is_decoder=True,
|
498 |
+
)
|
499 |
+
self.dropout = config.dropout
|
500 |
+
self.activation_fn = ACT2FN[config.activation_function]
|
501 |
+
self.activation_dropout = config.activation_dropout
|
502 |
+
|
503 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
504 |
+
self.encoder_attn = IndicTransAttention(
|
505 |
+
self.embed_dim,
|
506 |
+
config.decoder_attention_heads,
|
507 |
+
dropout=config.attention_dropout,
|
508 |
+
is_decoder=True,
|
509 |
+
)
|
510 |
+
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
511 |
+
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
|
512 |
+
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
|
513 |
+
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
|
514 |
+
self.normalize_before = config.decoder_normalize_before
|
515 |
+
|
516 |
+
def forward(
|
517 |
+
self,
|
518 |
+
hidden_states: torch.Tensor,
|
519 |
+
attention_mask: Optional[torch.Tensor] = None,
|
520 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
521 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
522 |
+
layer_head_mask: Optional[torch.Tensor] = None,
|
523 |
+
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
524 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
525 |
+
output_attentions: Optional[bool] = False,
|
526 |
+
use_cache: Optional[bool] = True,
|
527 |
+
) -> torch.Tensor:
|
528 |
+
"""
|
529 |
+
Args:
|
530 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
531 |
+
attention_mask (`torch.FloatTensor`): attention mask of size
|
532 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
533 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
534 |
+
cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
|
535 |
+
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
|
536 |
+
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
537 |
+
layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
|
538 |
+
`(encoder_attention_heads,)`.
|
539 |
+
cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
|
540 |
+
size `(decoder_attention_heads,)`.
|
541 |
+
past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
|
542 |
+
output_attentions (`bool`, *optional*):
|
543 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
544 |
+
returned tensors for more detail.
|
545 |
+
"""
|
546 |
+
residual = hidden_states
|
547 |
+
if self.normalize_before:
|
548 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
549 |
+
|
550 |
+
# Self Attention
|
551 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
552 |
+
self_attn_past_key_value = (
|
553 |
+
past_key_value[:2] if past_key_value is not None else None
|
554 |
+
)
|
555 |
+
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
556 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
557 |
+
hidden_states=hidden_states,
|
558 |
+
past_key_value=self_attn_past_key_value,
|
559 |
+
attention_mask=attention_mask,
|
560 |
+
layer_head_mask=layer_head_mask,
|
561 |
+
output_attentions=output_attentions,
|
562 |
+
)
|
563 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
564 |
+
hidden_states = residual + hidden_states
|
565 |
+
if not self.normalize_before:
|
566 |
+
hidden_states = self.self_attn_layer_norm(hidden_states)
|
567 |
+
|
568 |
+
# Cross-Attention Block
|
569 |
+
cross_attn_present_key_value = None
|
570 |
+
cross_attn_weights = None
|
571 |
+
if encoder_hidden_states is not None:
|
572 |
+
residual = hidden_states
|
573 |
+
if self.normalize_before:
|
574 |
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
575 |
+
|
576 |
+
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
577 |
+
cross_attn_past_key_value = (
|
578 |
+
past_key_value[-2:] if past_key_value is not None else None
|
579 |
+
)
|
580 |
+
(
|
581 |
+
hidden_states,
|
582 |
+
cross_attn_weights,
|
583 |
+
cross_attn_present_key_value,
|
584 |
+
) = self.encoder_attn(
|
585 |
+
hidden_states=hidden_states,
|
586 |
+
key_value_states=encoder_hidden_states,
|
587 |
+
attention_mask=encoder_attention_mask,
|
588 |
+
layer_head_mask=cross_attn_layer_head_mask,
|
589 |
+
past_key_value=cross_attn_past_key_value,
|
590 |
+
output_attentions=output_attentions,
|
591 |
+
)
|
592 |
+
hidden_states = F.dropout(
|
593 |
+
hidden_states, p=self.dropout, training=self.training
|
594 |
+
)
|
595 |
+
hidden_states = residual + hidden_states
|
596 |
+
if not self.normalize_before:
|
597 |
+
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
598 |
+
|
599 |
+
# add cross-attn to positions 3,4 of present_key_value tuple
|
600 |
+
present_key_value = present_key_value + cross_attn_present_key_value
|
601 |
+
|
602 |
+
# Fully Connected
|
603 |
+
residual = hidden_states
|
604 |
+
if self.normalize_before:
|
605 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
606 |
+
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
607 |
+
hidden_states = F.dropout(
|
608 |
+
hidden_states, p=self.activation_dropout, training=self.training
|
609 |
+
)
|
610 |
+
hidden_states = self.fc2(hidden_states)
|
611 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
612 |
+
hidden_states = residual + hidden_states
|
613 |
+
if not self.normalize_before:
|
614 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
615 |
+
|
616 |
+
outputs = (hidden_states,)
|
617 |
+
|
618 |
+
if output_attentions:
|
619 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
620 |
+
|
621 |
+
if use_cache:
|
622 |
+
outputs += (present_key_value,)
|
623 |
+
|
624 |
+
return outputs
|
625 |
+
|
626 |
+
|
627 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->IndicTrans
|
628 |
+
class IndicTransPreTrainedModel(PreTrainedModel):
|
629 |
+
config_class = IndicTransConfig
|
630 |
+
base_model_prefix = "model"
|
631 |
+
supports_gradient_checkpointing = True
|
632 |
+
_no_split_modules = ["IndicTransAttention"]
|
633 |
+
|
634 |
+
def _init_weights(self, module):
|
635 |
+
std = self.config.init_std
|
636 |
+
if isinstance(module, nn.Linear):
|
637 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
638 |
+
if module.bias is not None:
|
639 |
+
module.bias.data.zero_()
|
640 |
+
elif isinstance(module, nn.Embedding):
|
641 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
642 |
+
if module.padding_idx is not None:
|
643 |
+
module.weight.data[module.padding_idx].zero_()
|
644 |
+
|
645 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
646 |
+
if isinstance(module, (IndicTransDecoder, IndicTransEncoder)):
|
647 |
+
module.gradient_checkpointing = value
|
648 |
+
|
649 |
+
|
650 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->IndicTrans
|
651 |
+
class IndicTransEncoder(IndicTransPreTrainedModel):
|
652 |
+
"""
|
653 |
+
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
|
654 |
+
[`IndicTransEncoderLayer`].
|
655 |
+
|
656 |
+
Args:
|
657 |
+
config: IndicTransConfig
|
658 |
+
embed_tokens (nn.Embedding): output embedding
|
659 |
+
"""
|
660 |
+
|
661 |
+
def __init__(
|
662 |
+
self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
|
663 |
+
):
|
664 |
+
super().__init__(config)
|
665 |
+
|
666 |
+
self.dropout = config.dropout
|
667 |
+
self.layerdrop = config.encoder_layerdrop
|
668 |
+
|
669 |
+
embed_dim = config.encoder_embed_dim
|
670 |
+
self.padding_idx = config.pad_token_id
|
671 |
+
self.max_source_positions = config.max_source_positions
|
672 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
673 |
+
|
674 |
+
self.embed_tokens = nn.Embedding(
|
675 |
+
config.encoder_vocab_size, embed_dim, self.padding_idx
|
676 |
+
)
|
677 |
+
|
678 |
+
if embed_tokens is not None:
|
679 |
+
self.embed_tokens.weight = embed_tokens.weight
|
680 |
+
|
681 |
+
self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
|
682 |
+
config.max_source_positions,
|
683 |
+
embed_dim,
|
684 |
+
self.padding_idx,
|
685 |
+
)
|
686 |
+
self.layers = nn.ModuleList(
|
687 |
+
[IndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
|
688 |
+
)
|
689 |
+
self.layer_norm = (
|
690 |
+
nn.LayerNorm(embed_dim) if config.encoder_normalize_before else None
|
691 |
+
)
|
692 |
+
self.layernorm_embedding = (
|
693 |
+
nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
|
694 |
+
)
|
695 |
+
|
696 |
+
self.gradient_checkpointing = False
|
697 |
+
# Initialize weights and apply final processing
|
698 |
+
self.post_init()
|
699 |
+
|
700 |
+
def forward(
|
701 |
+
self,
|
702 |
+
input_ids: Optional[torch.Tensor] = None,
|
703 |
+
attention_mask: Optional[torch.Tensor] = None,
|
704 |
+
head_mask: Optional[torch.Tensor] = None,
|
705 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
706 |
+
output_attentions: Optional[bool] = None,
|
707 |
+
output_hidden_states: Optional[bool] = None,
|
708 |
+
return_dict: Optional[bool] = None,
|
709 |
+
):
|
710 |
+
r"""
|
711 |
+
Args:
|
712 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
713 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
714 |
+
provide it.
|
715 |
+
|
716 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
717 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
718 |
+
|
719 |
+
[What are input IDs?](../glossary#input-ids)
|
720 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
721 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
722 |
+
|
723 |
+
- 1 for tokens that are **not masked**,
|
724 |
+
- 0 for tokens that are **masked**.
|
725 |
+
|
726 |
+
[What are attention masks?](../glossary#attention-mask)
|
727 |
+
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
|
728 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
729 |
+
|
730 |
+
- 1 indicates the head is **not masked**,
|
731 |
+
- 0 indicates the head is **masked**.
|
732 |
+
|
733 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
734 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
|
735 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
736 |
+
than the model's internal embedding lookup matrix.
|
737 |
+
output_attentions (`bool`, *optional*):
|
738 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
739 |
+
returned tensors for more detail.
|
740 |
+
output_hidden_states (`bool`, *optional*):
|
741 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
742 |
+
for more detail.
|
743 |
+
return_dict (`bool`, *optional*):
|
744 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
745 |
+
"""
|
746 |
+
output_attentions = (
|
747 |
+
output_attentions
|
748 |
+
if output_attentions is not None
|
749 |
+
else self.config.output_attentions
|
750 |
+
)
|
751 |
+
output_hidden_states = (
|
752 |
+
output_hidden_states
|
753 |
+
if output_hidden_states is not None
|
754 |
+
else self.config.output_hidden_states
|
755 |
+
)
|
756 |
+
return_dict = (
|
757 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
758 |
+
)
|
759 |
+
|
760 |
+
# retrieve input_ids and inputs_embeds
|
761 |
+
if input_ids is not None and inputs_embeds is not None:
|
762 |
+
raise ValueError(
|
763 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
764 |
+
)
|
765 |
+
elif input_ids is not None:
|
766 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
767 |
+
input_shape = input_ids.size()
|
768 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
769 |
+
elif inputs_embeds is not None:
|
770 |
+
input_shape = inputs_embeds.size()[:-1]
|
771 |
+
else:
|
772 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
773 |
+
|
774 |
+
if inputs_embeds is None:
|
775 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
776 |
+
|
777 |
+
embed_pos = self.embed_positions(input_ids, inputs_embeds)
|
778 |
+
embed_pos = embed_pos.to(inputs_embeds.device)
|
779 |
+
|
780 |
+
hidden_states = inputs_embeds + embed_pos
|
781 |
+
if self.layernorm_embedding is not None:
|
782 |
+
x = self.layernorm_embedding(hidden_states)
|
783 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
784 |
+
|
785 |
+
# expand attention_mask
|
786 |
+
if attention_mask is not None:
|
787 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
788 |
+
attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
|
789 |
+
|
790 |
+
encoder_states = () if output_hidden_states else None
|
791 |
+
all_attentions = () if output_attentions else None
|
792 |
+
|
793 |
+
# check if head_mask has a correct number of layers specified if desired
|
794 |
+
if head_mask is not None:
|
795 |
+
if head_mask.size()[0] != len(self.layers):
|
796 |
+
raise ValueError(
|
797 |
+
f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
|
798 |
+
f" {head_mask.size()[0]}."
|
799 |
+
)
|
800 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
801 |
+
|
802 |
+
for idx, encoder_layer in enumerate(self.layers):
|
803 |
+
if output_hidden_states:
|
804 |
+
encoder_states = encoder_states + (hidden_states,)
|
805 |
+
|
806 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
807 |
+
dropout_probability = torch.rand([])
|
808 |
+
|
809 |
+
skip_the_layer = (
|
810 |
+
True
|
811 |
+
if self.training and (dropout_probability < self.layerdrop)
|
812 |
+
else False
|
813 |
+
)
|
814 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
815 |
+
# under deepspeed zero3 all gpus must run in sync
|
816 |
+
|
817 |
+
if self.gradient_checkpointing and self.training:
|
818 |
+
# create gradient checkpointing function
|
819 |
+
def create_custom_forward(module):
|
820 |
+
def custom_forward(*inputs):
|
821 |
+
return module(*inputs, output_attentions)
|
822 |
+
|
823 |
+
return custom_forward
|
824 |
+
|
825 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
826 |
+
create_custom_forward(encoder_layer),
|
827 |
+
hidden_states,
|
828 |
+
attention_mask,
|
829 |
+
(head_mask[idx] if head_mask is not None else None),
|
830 |
+
)
|
831 |
+
else:
|
832 |
+
layer_outputs = encoder_layer(
|
833 |
+
hidden_states,
|
834 |
+
attention_mask,
|
835 |
+
layer_head_mask=(
|
836 |
+
head_mask[idx] if head_mask is not None else None
|
837 |
+
),
|
838 |
+
output_attentions=output_attentions,
|
839 |
+
)
|
840 |
+
|
841 |
+
hidden_states = layer_outputs[0]
|
842 |
+
|
843 |
+
if skip_the_layer:
|
844 |
+
layer_outputs = (None, None)
|
845 |
+
|
846 |
+
if output_attentions:
|
847 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
848 |
+
|
849 |
+
if self.layer_norm is not None:
|
850 |
+
hidden_states = self.layer_norm(hidden_states)
|
851 |
+
|
852 |
+
if output_hidden_states:
|
853 |
+
encoder_states = encoder_states + (hidden_states,)
|
854 |
+
|
855 |
+
if not return_dict:
|
856 |
+
return tuple(
|
857 |
+
v
|
858 |
+
for v in [hidden_states, encoder_states, all_attentions]
|
859 |
+
if v is not None
|
860 |
+
)
|
861 |
+
return BaseModelOutput(
|
862 |
+
last_hidden_state=hidden_states,
|
863 |
+
hidden_states=encoder_states,
|
864 |
+
attentions=all_attentions,
|
865 |
+
)
|
866 |
+
|
867 |
+
|
868 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->IndicTrans
|
869 |
+
class IndicTransDecoder(IndicTransPreTrainedModel):
|
870 |
+
"""
|
871 |
+
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`IndicTransDecoderLayer`]
|
872 |
+
|
873 |
+
Args:
|
874 |
+
config: IndicTransConfig
|
875 |
+
embed_tokens (nn.Embedding): output embedding
|
876 |
+
"""
|
877 |
+
|
878 |
+
def __init__(
|
879 |
+
self, config: IndicTransConfig, embed_tokens: Optional[nn.Embedding] = None
|
880 |
+
):
|
881 |
+
super().__init__(config)
|
882 |
+
self.dropout = config.dropout
|
883 |
+
self.layerdrop = config.decoder_layerdrop
|
884 |
+
|
885 |
+
embed_dim = config.encoder_embed_dim
|
886 |
+
self.padding_idx = config.pad_token_id
|
887 |
+
self.max_target_positions = config.max_target_positions
|
888 |
+
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
889 |
+
|
890 |
+
self.embed_tokens = nn.Embedding(
|
891 |
+
config.decoder_vocab_size, embed_dim, self.padding_idx
|
892 |
+
)
|
893 |
+
|
894 |
+
if embed_tokens is not None:
|
895 |
+
self.embed_tokens.weight = embed_tokens.weight
|
896 |
+
|
897 |
+
self.embed_positions = IndicTransSinusoidalPositionalEmbedding(
|
898 |
+
config.max_target_positions,
|
899 |
+
embed_dim,
|
900 |
+
self.padding_idx,
|
901 |
+
)
|
902 |
+
self.layers = nn.ModuleList(
|
903 |
+
[IndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
|
904 |
+
)
|
905 |
+
self.layer_norm = (
|
906 |
+
nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
|
907 |
+
)
|
908 |
+
self.layernorm_embedding = (
|
909 |
+
nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
|
910 |
+
)
|
911 |
+
|
912 |
+
self.gradient_checkpointing = False
|
913 |
+
# Initialize weights and apply final processing
|
914 |
+
self.post_init()
|
915 |
+
|
916 |
+
def forward(
|
917 |
+
self,
|
918 |
+
input_ids: Optional[torch.Tensor] = None,
|
919 |
+
attention_mask: Optional[torch.Tensor] = None,
|
920 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
921 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
922 |
+
head_mask: Optional[torch.Tensor] = None,
|
923 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
924 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
925 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
926 |
+
use_cache: Optional[bool] = None,
|
927 |
+
output_attentions: Optional[bool] = None,
|
928 |
+
output_hidden_states: Optional[bool] = None,
|
929 |
+
return_dict: Optional[bool] = None,
|
930 |
+
):
|
931 |
+
r"""
|
932 |
+
Args:
|
933 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
934 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
|
935 |
+
provide it.
|
936 |
+
|
937 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
938 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
939 |
+
|
940 |
+
[What are input IDs?](../glossary#input-ids)
|
941 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
942 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
943 |
+
|
944 |
+
- 1 for tokens that are **not masked**,
|
945 |
+
- 0 for tokens that are **masked**.
|
946 |
+
|
947 |
+
[What are attention masks?](../glossary#attention-mask)
|
948 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
|
949 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
950 |
+
of the decoder.
|
951 |
+
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
|
952 |
+
Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
|
953 |
+
selected in `[0, 1]`:
|
954 |
+
|
955 |
+
- 1 for tokens that are **not masked**,
|
956 |
+
- 0 for tokens that are **masked**.
|
957 |
+
|
958 |
+
[What are attention masks?](../glossary#attention-mask)
|
959 |
+
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
960 |
+
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
|
961 |
+
|
962 |
+
- 1 indicates the head is **not masked**,
|
963 |
+
- 0 indicates the head is **masked**.
|
964 |
+
|
965 |
+
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
|
966 |
+
Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
|
967 |
+
cross-attention on hidden heads. Mask values selected in `[0, 1]`:
|
968 |
+
|
969 |
+
- 1 indicates the head is **not masked**,
|
970 |
+
- 0 indicates the head is **masked**.
|
971 |
+
|
972 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
973 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
974 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
975 |
+
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
976 |
+
|
977 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
978 |
+
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
979 |
+
|
980 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
981 |
+
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
982 |
+
all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
|
983 |
+
shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
|
984 |
+
`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
|
985 |
+
control over how to convert `input_ids` indices into associated vectors than the model's internal
|
986 |
+
embedding lookup matrix.
|
987 |
+
output_attentions (`bool`, *optional*):
|
988 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
989 |
+
returned tensors for more detail.
|
990 |
+
output_hidden_states (`bool`, *optional*):
|
991 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
992 |
+
for more detail.
|
993 |
+
return_dict (`bool`, *optional*):
|
994 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
995 |
+
"""
|
996 |
+
output_attentions = (
|
997 |
+
output_attentions
|
998 |
+
if output_attentions is not None
|
999 |
+
else self.config.output_attentions
|
1000 |
+
)
|
1001 |
+
output_hidden_states = (
|
1002 |
+
output_hidden_states
|
1003 |
+
if output_hidden_states is not None
|
1004 |
+
else self.config.output_hidden_states
|
1005 |
+
)
|
1006 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1007 |
+
return_dict = (
|
1008 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1009 |
+
)
|
1010 |
+
|
1011 |
+
# retrieve input_ids and inputs_embeds
|
1012 |
+
if input_ids is not None and inputs_embeds is not None:
|
1013 |
+
raise ValueError(
|
1014 |
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
1015 |
+
)
|
1016 |
+
elif input_ids is not None:
|
1017 |
+
input_shape = input_ids.size()
|
1018 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
1019 |
+
elif inputs_embeds is not None:
|
1020 |
+
input_shape = inputs_embeds.size()[:-1]
|
1021 |
+
else:
|
1022 |
+
raise ValueError(
|
1023 |
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
# past_key_values_length
|
1027 |
+
past_key_values_length = (
|
1028 |
+
past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
1029 |
+
)
|
1030 |
+
|
1031 |
+
if inputs_embeds is None:
|
1032 |
+
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
1033 |
+
|
1034 |
+
# create causal mask
|
1035 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1036 |
+
combined_attention_mask = None
|
1037 |
+
if input_shape[-1] > 1:
|
1038 |
+
combined_attention_mask = _make_causal_mask(
|
1039 |
+
input_shape,
|
1040 |
+
inputs_embeds.dtype,
|
1041 |
+
device=inputs_embeds.device,
|
1042 |
+
past_key_values_length=past_key_values_length,
|
1043 |
+
)
|
1044 |
+
|
1045 |
+
if attention_mask is not None and combined_attention_mask is not None:
|
1046 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1047 |
+
combined_attention_mask = combined_attention_mask + _expand_mask(
|
1048 |
+
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
1049 |
+
)
|
1050 |
+
|
1051 |
+
# expand encoder attention mask
|
1052 |
+
if encoder_hidden_states is not None and encoder_attention_mask is not None:
|
1053 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
1054 |
+
encoder_attention_mask = _expand_mask(
|
1055 |
+
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
|
1056 |
+
)
|
1057 |
+
|
1058 |
+
# embed positions
|
1059 |
+
positions = self.embed_positions(
|
1060 |
+
input_ids, inputs_embeds, past_key_values_length
|
1061 |
+
)
|
1062 |
+
positions = positions.to(inputs_embeds.device)
|
1063 |
+
|
1064 |
+
hidden_states = inputs_embeds + positions
|
1065 |
+
if self.layernorm_embedding is not None:
|
1066 |
+
hidden_states = self.layernorm_embedding(hidden_states)
|
1067 |
+
|
1068 |
+
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
|
1069 |
+
|
1070 |
+
if self.gradient_checkpointing and self.training:
|
1071 |
+
if use_cache:
|
1072 |
+
logger.warning_once(
|
1073 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting"
|
1074 |
+
" `use_cache=False`..."
|
1075 |
+
)
|
1076 |
+
use_cache = False
|
1077 |
+
|
1078 |
+
# decoder layers
|
1079 |
+
all_hidden_states = () if output_hidden_states else None
|
1080 |
+
all_self_attns = () if output_attentions else None
|
1081 |
+
all_cross_attentions = () if output_attentions else None
|
1082 |
+
next_decoder_cache = () if use_cache else None
|
1083 |
+
|
1084 |
+
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
1085 |
+
for attn_mask, mask_name in zip(
|
1086 |
+
[head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
|
1087 |
+
):
|
1088 |
+
if attn_mask is not None:
|
1089 |
+
if attn_mask.size()[0] != len(self.layers):
|
1090 |
+
raise ValueError(
|
1091 |
+
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
|
1092 |
+
f" {head_mask.size()[0]}."
|
1093 |
+
)
|
1094 |
+
deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
|
1095 |
+
|
1096 |
+
for idx, decoder_layer in enumerate(self.layers):
|
1097 |
+
if output_hidden_states:
|
1098 |
+
all_hidden_states += (hidden_states,)
|
1099 |
+
|
1100 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
1101 |
+
dropout_probability = torch.rand([])
|
1102 |
+
|
1103 |
+
skip_the_layer = (
|
1104 |
+
True
|
1105 |
+
if self.training and (dropout_probability < self.layerdrop)
|
1106 |
+
else False
|
1107 |
+
)
|
1108 |
+
if not skip_the_layer or deepspeed_zero3_is_enabled:
|
1109 |
+
# under deepspeed zero3 all gpus must run in sync
|
1110 |
+
|
1111 |
+
past_key_value = (
|
1112 |
+
past_key_values[idx] if past_key_values is not None else None
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
if self.gradient_checkpointing and self.training:
|
1116 |
+
|
1117 |
+
def create_custom_forward(module):
|
1118 |
+
def custom_forward(*inputs):
|
1119 |
+
# None for past_key_value
|
1120 |
+
return module(*inputs, output_attentions, use_cache)
|
1121 |
+
|
1122 |
+
return custom_forward
|
1123 |
+
|
1124 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
1125 |
+
create_custom_forward(decoder_layer),
|
1126 |
+
hidden_states,
|
1127 |
+
combined_attention_mask,
|
1128 |
+
encoder_hidden_states,
|
1129 |
+
encoder_attention_mask,
|
1130 |
+
head_mask[idx] if head_mask is not None else None,
|
1131 |
+
cross_attn_head_mask[idx]
|
1132 |
+
if cross_attn_head_mask is not None
|
1133 |
+
else None,
|
1134 |
+
None,
|
1135 |
+
)
|
1136 |
+
else:
|
1137 |
+
layer_outputs = decoder_layer(
|
1138 |
+
hidden_states,
|
1139 |
+
attention_mask=combined_attention_mask,
|
1140 |
+
encoder_hidden_states=encoder_hidden_states,
|
1141 |
+
encoder_attention_mask=encoder_attention_mask,
|
1142 |
+
layer_head_mask=(
|
1143 |
+
head_mask[idx] if head_mask is not None else None
|
1144 |
+
),
|
1145 |
+
cross_attn_layer_head_mask=(
|
1146 |
+
cross_attn_head_mask[idx]
|
1147 |
+
if cross_attn_head_mask is not None
|
1148 |
+
else None
|
1149 |
+
),
|
1150 |
+
past_key_value=past_key_value,
|
1151 |
+
output_attentions=output_attentions,
|
1152 |
+
use_cache=use_cache,
|
1153 |
+
)
|
1154 |
+
|
1155 |
+
hidden_states = layer_outputs[0]
|
1156 |
+
|
1157 |
+
if skip_the_layer:
|
1158 |
+
continue
|
1159 |
+
|
1160 |
+
if use_cache:
|
1161 |
+
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
1162 |
+
|
1163 |
+
if output_attentions:
|
1164 |
+
all_self_attns += (layer_outputs[1],)
|
1165 |
+
all_cross_attentions += (layer_outputs[2],)
|
1166 |
+
|
1167 |
+
if self.layer_norm is not None:
|
1168 |
+
hidden_states = self.layer_norm(hidden_states)
|
1169 |
+
|
1170 |
+
# add hidden states from the last decoder layer
|
1171 |
+
if output_hidden_states:
|
1172 |
+
all_hidden_states += (hidden_states,)
|
1173 |
+
|
1174 |
+
next_cache = next_decoder_cache if use_cache else None
|
1175 |
+
if not return_dict:
|
1176 |
+
return tuple(
|
1177 |
+
v
|
1178 |
+
for v in [
|
1179 |
+
hidden_states,
|
1180 |
+
next_cache,
|
1181 |
+
all_hidden_states,
|
1182 |
+
all_self_attns,
|
1183 |
+
all_cross_attentions,
|
1184 |
+
]
|
1185 |
+
if v is not None
|
1186 |
+
)
|
1187 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
1188 |
+
last_hidden_state=hidden_states,
|
1189 |
+
past_key_values=next_cache,
|
1190 |
+
hidden_states=all_hidden_states,
|
1191 |
+
attentions=all_self_attns,
|
1192 |
+
cross_attentions=all_cross_attentions,
|
1193 |
+
)
|
1194 |
+
|
1195 |
+
|
1196 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->IndicTrans
|
1197 |
+
class IndicTransModel(IndicTransPreTrainedModel):
|
1198 |
+
_tied_weights_keys = None
|
1199 |
+
|
1200 |
+
def __init__(self, config: IndicTransConfig):
|
1201 |
+
super().__init__(config)
|
1202 |
+
|
1203 |
+
self.encoder = IndicTransEncoder(config)
|
1204 |
+
self.decoder = IndicTransDecoder(config)
|
1205 |
+
|
1206 |
+
# Initialize weights and apply final processing
|
1207 |
+
self.post_init()
|
1208 |
+
|
1209 |
+
def get_encoder(self):
|
1210 |
+
return self.encoder
|
1211 |
+
|
1212 |
+
def get_decoder(self):
|
1213 |
+
return self.decoder
|
1214 |
+
|
1215 |
+
def forward(
|
1216 |
+
self,
|
1217 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1218 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1219 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1220 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1221 |
+
head_mask: Optional[torch.Tensor] = None,
|
1222 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
1223 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1224 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1225 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1226 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1227 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1228 |
+
use_cache: Optional[bool] = None,
|
1229 |
+
output_attentions: Optional[bool] = None,
|
1230 |
+
output_hidden_states: Optional[bool] = None,
|
1231 |
+
return_dict: Optional[bool] = None,
|
1232 |
+
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
1233 |
+
output_attentions = (
|
1234 |
+
output_attentions
|
1235 |
+
if output_attentions is not None
|
1236 |
+
else self.config.output_attentions
|
1237 |
+
)
|
1238 |
+
output_hidden_states = (
|
1239 |
+
output_hidden_states
|
1240 |
+
if output_hidden_states is not None
|
1241 |
+
else self.config.output_hidden_states
|
1242 |
+
)
|
1243 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1244 |
+
return_dict = (
|
1245 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1246 |
+
)
|
1247 |
+
|
1248 |
+
if encoder_outputs is None:
|
1249 |
+
encoder_outputs = self.encoder(
|
1250 |
+
input_ids=input_ids,
|
1251 |
+
attention_mask=attention_mask,
|
1252 |
+
head_mask=head_mask,
|
1253 |
+
inputs_embeds=inputs_embeds,
|
1254 |
+
output_attentions=output_attentions,
|
1255 |
+
output_hidden_states=output_hidden_states,
|
1256 |
+
return_dict=return_dict,
|
1257 |
+
)
|
1258 |
+
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
|
1259 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
1260 |
+
encoder_outputs = BaseModelOutput(
|
1261 |
+
last_hidden_state=encoder_outputs[0],
|
1262 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
1263 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
1264 |
+
)
|
1265 |
+
|
1266 |
+
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
1267 |
+
decoder_outputs = self.decoder(
|
1268 |
+
input_ids=decoder_input_ids,
|
1269 |
+
attention_mask=decoder_attention_mask,
|
1270 |
+
encoder_hidden_states=encoder_outputs[0],
|
1271 |
+
encoder_attention_mask=attention_mask,
|
1272 |
+
head_mask=decoder_head_mask,
|
1273 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1274 |
+
past_key_values=past_key_values,
|
1275 |
+
inputs_embeds=decoder_inputs_embeds,
|
1276 |
+
use_cache=use_cache,
|
1277 |
+
output_attentions=output_attentions,
|
1278 |
+
output_hidden_states=output_hidden_states,
|
1279 |
+
return_dict=return_dict,
|
1280 |
+
)
|
1281 |
+
|
1282 |
+
if not return_dict:
|
1283 |
+
return decoder_outputs + encoder_outputs
|
1284 |
+
|
1285 |
+
return Seq2SeqModelOutput(
|
1286 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
1287 |
+
past_key_values=decoder_outputs.past_key_values,
|
1288 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
1289 |
+
decoder_attentions=decoder_outputs.attentions,
|
1290 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1291 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
1292 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
1293 |
+
encoder_attentions=encoder_outputs.attentions,
|
1294 |
+
)
|
1295 |
+
|
1296 |
+
|
1297 |
+
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->IndicTrans
|
1298 |
+
class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
|
1299 |
+
base_model_prefix = "model"
|
1300 |
+
_tied_weights_keys = None
|
1301 |
+
|
1302 |
+
def __init__(self, config: IndicTransConfig):
|
1303 |
+
super().__init__(config)
|
1304 |
+
self.model = IndicTransModel(config)
|
1305 |
+
self.lm_head = nn.Linear(
|
1306 |
+
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
1307 |
+
)
|
1308 |
+
|
1309 |
+
if config.share_decoder_input_output_embed:
|
1310 |
+
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
1311 |
+
|
1312 |
+
self.post_init()
|
1313 |
+
|
1314 |
+
def tie_weights(self):
|
1315 |
+
pass
|
1316 |
+
|
1317 |
+
def get_encoder(self):
|
1318 |
+
return self.model.get_encoder()
|
1319 |
+
|
1320 |
+
def get_decoder(self):
|
1321 |
+
return self.model.get_decoder()
|
1322 |
+
|
1323 |
+
def get_output_embeddings(self):
|
1324 |
+
return self.lm_head
|
1325 |
+
|
1326 |
+
def set_output_embeddings(self, new_embeddings):
|
1327 |
+
self.lm_head = new_embeddings
|
1328 |
+
|
1329 |
+
def forward(
|
1330 |
+
self,
|
1331 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1332 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1333 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
1334 |
+
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
1335 |
+
head_mask: Optional[torch.Tensor] = None,
|
1336 |
+
decoder_head_mask: Optional[torch.Tensor] = None,
|
1337 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
1338 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1339 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
1340 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1341 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
1342 |
+
labels: Optional[torch.LongTensor] = None,
|
1343 |
+
use_cache: Optional[bool] = None,
|
1344 |
+
output_attentions: Optional[bool] = None,
|
1345 |
+
output_hidden_states: Optional[bool] = None,
|
1346 |
+
return_dict: Optional[bool] = None,
|
1347 |
+
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
1348 |
+
r"""
|
1349 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1350 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
1351 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
1352 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
1353 |
+
|
1354 |
+
Returns:
|
1355 |
+
"""
|
1356 |
+
return_dict = (
|
1357 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1358 |
+
)
|
1359 |
+
|
1360 |
+
if labels is not None:
|
1361 |
+
if decoder_input_ids is None:
|
1362 |
+
decoder_input_ids = shift_tokens_right(
|
1363 |
+
labels, self.config.pad_token_id, self.config.decoder_start_token_id
|
1364 |
+
)
|
1365 |
+
|
1366 |
+
outputs = self.model(
|
1367 |
+
input_ids,
|
1368 |
+
attention_mask=attention_mask,
|
1369 |
+
decoder_input_ids=decoder_input_ids,
|
1370 |
+
encoder_outputs=encoder_outputs,
|
1371 |
+
decoder_attention_mask=decoder_attention_mask,
|
1372 |
+
head_mask=head_mask,
|
1373 |
+
decoder_head_mask=decoder_head_mask,
|
1374 |
+
cross_attn_head_mask=cross_attn_head_mask,
|
1375 |
+
past_key_values=past_key_values,
|
1376 |
+
inputs_embeds=inputs_embeds,
|
1377 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
1378 |
+
use_cache=use_cache,
|
1379 |
+
output_attentions=output_attentions,
|
1380 |
+
output_hidden_states=output_hidden_states,
|
1381 |
+
return_dict=return_dict,
|
1382 |
+
)
|
1383 |
+
lm_logits = self.lm_head(outputs[0])
|
1384 |
+
|
1385 |
+
masked_lm_loss = None
|
1386 |
+
if labels is not None:
|
1387 |
+
# move labels to the correct device to enable PP
|
1388 |
+
labels = labels.to(lm_logits.device)
|
1389 |
+
loss_fct = nn.CrossEntropyLoss()
|
1390 |
+
masked_lm_loss = loss_fct(
|
1391 |
+
lm_logits.view(-1, self.config.vocab_size), labels.view(-1)
|
1392 |
+
)
|
1393 |
+
|
1394 |
+
if not return_dict:
|
1395 |
+
output = (lm_logits,) + outputs[1:]
|
1396 |
+
return (
|
1397 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1398 |
+
)
|
1399 |
+
|
1400 |
+
return Seq2SeqLMOutput(
|
1401 |
+
loss=masked_lm_loss,
|
1402 |
+
logits=lm_logits,
|
1403 |
+
past_key_values=outputs.past_key_values,
|
1404 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1405 |
+
decoder_attentions=outputs.decoder_attentions,
|
1406 |
+
cross_attentions=outputs.cross_attentions,
|
1407 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1408 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1409 |
+
encoder_attentions=outputs.encoder_attentions,
|
1410 |
+
)
|
1411 |
+
|
1412 |
+
def prepare_inputs_for_generation(
|
1413 |
+
self,
|
1414 |
+
decoder_input_ids,
|
1415 |
+
past_key_values=None,
|
1416 |
+
attention_mask=None,
|
1417 |
+
head_mask=None,
|
1418 |
+
decoder_head_mask=None,
|
1419 |
+
cross_attn_head_mask=None,
|
1420 |
+
use_cache=None,
|
1421 |
+
encoder_outputs=None,
|
1422 |
+
**kwargs,
|
1423 |
+
):
|
1424 |
+
# cut decoder_input_ids if past is used
|
1425 |
+
if past_key_values is not None:
|
1426 |
+
decoder_input_ids = decoder_input_ids[:, -1:]
|
1427 |
+
|
1428 |
+
return {
|
1429 |
+
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
1430 |
+
"encoder_outputs": encoder_outputs,
|
1431 |
+
"past_key_values": past_key_values,
|
1432 |
+
"decoder_input_ids": decoder_input_ids,
|
1433 |
+
"attention_mask": attention_mask,
|
1434 |
+
"head_mask": head_mask,
|
1435 |
+
"decoder_head_mask": decoder_head_mask,
|
1436 |
+
"cross_attn_head_mask": cross_attn_head_mask,
|
1437 |
+
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
|
1438 |
+
}
|
1439 |
+
|
1440 |
+
@staticmethod
|
1441 |
+
def _reorder_cache(past_key_values, beam_idx):
|
1442 |
+
reordered_past = ()
|
1443 |
+
for layer_past in past_key_values:
|
1444 |
+
reordered_past += (
|
1445 |
+
tuple(
|
1446 |
+
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1447 |
+
),
|
1448 |
+
)
|
1449 |
+
return reordered_past
|
sample.srt
ADDED
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
1
|
2 |
+
00:00:00,000 --> 00:00:01,845
|
3 |
+
Sadhguru: If you activate this dimension of energy,
|
4 |
+
2
|
5 |
+
00:00:01,845 --> 00:00:05,013
|
6 |
+
other dimensions of life will open up.
|
7 |
+
3
|
8 |
+
00:00:05,013 --> 00:00:07,545
|
9 |
+
One thing is,
|
10 |
+
4
|
11 |
+
00:00:07,545 --> 00:00:10,377
|
12 |
+
are you ready for those dimensions?
|
13 |
+
5
|
14 |
+
00:00:10,377 --> 00:00:11,280
|
15 |
+
Kundalini Yoga,
|
16 |
+
6
|
17 |
+
00:00:11,280 --> 00:00:12,318
|
18 |
+
in its essence,
|
19 |
+
7
|
20 |
+
00:00:12,318 --> 00:00:16,637
|
21 |
+
is the most dangerous form of yoga.
|
22 |
+
8
|
23 |
+
00:00:16,637 --> 00:00:18,008
|
24 |
+
I'm saying dangerous,
|
25 |
+
9
|
26 |
+
00:00:18,008 --> 00:00:20,538
|
27 |
+
because it's the most potent also.
|
28 |
+
10
|
29 |
+
00:00:20,538 --> 00:00:22,890
|
30 |
+
If you have to jump into an abyss,
|
31 |
+
11
|
32 |
+
00:00:22,890 --> 00:00:28,613
|
33 |
+
you should be insane or you should have enormous trust in somebody.
|
34 |
+
12
|
35 |
+
00:00:28:613 --> 00:00:37,465
|
36 |
+
So what is Kundalini?
|
37 |
+
13
|
38 |
+
00:00:37,465 --> 00:00:41,788
|
39 |
+
Right now,
|
40 |
+
14
|
41 |
+
00:00:41,788 --> 00:00:43,100
|
42 |
+
I'm speaking,
|
43 |
+
15
|
44 |
+
00:00:43,100 --> 00:00:46,495
|
45 |
+
this is Kundalini.
|
46 |
+
16
|
47 |
+
00:00:46,495 --> 00:00:47,843
|
48 |
+
You are alert and listening.
|
49 |
+
17
|
50 |
+
00:00:47,843 --> 00:00:49,483
|
51 |
+
If you are alert and listening,
|
52 |
+
18
|
53 |
+
00:00:49,483 --> 00:00:51,278
|
54 |
+
that is Kundalini
|
55 |
+
19
|
56 |
+
00:00:51,278 --> 00:00:54,528
|
57 |
+
A flower is blossoming,
|
58 |
+
20
|
59 |
+
00:00:54,528 --> 00:00:55,680
|
60 |
+
that is Kundalini.
|
61 |
+
21
|
62 |
+
00:00:55,680 --> 00:00:57,645
|
63 |
+
A dog is barking,
|
64 |
+
22
|
65 |
+
00:00:57,645 --> 00:00:59,480
|
66 |
+
that is also Kundalini,
|
67 |
+
23
|
68 |
+
00:00:59,480 --> 00:01:00,553
|
69 |
+
or in other words,
|
70 |
+
24
|
71 |
+
00:01:00,553 --> 00:01:03,545
|
72 |
+
the fundamental life force in the existence,
|
73 |
+
25
|
74 |
+
00:01:03,545 --> 00:01:04,880
|
75 |
+
we call it Kundalini.
|
76 |
+
26
|
77 |
+
00:01:04,880 --> 00:01:07,513
|
78 |
+
Now,
|
79 |
+
27
|
80 |
+
00:01:07,513 --> 00:01:08,555
|
81 |
+
within the system,
|
82 |
+
28
|
83 |
+
00:01:08,555 --> 00:01:10,377
|
84 |
+
within the human system,
|
85 |
+
29
|
86 |
+
00:01:10,377 --> 00:01:14,065
|
87 |
+
if you look at this as a kind of a life package,
|
88 |
+
30
|
89 |
+
00:01:14,065 --> 00:01:15,287
|
90 |
+
it's a piece of life.
|
91 |
+
31
|
92 |
+
00:01:15,287 --> 00:01:18,540
|
93 |
+
This piece of life is packed in a certain way,
|
94 |
+
32
|
95 |
+
00:01:18,540 --> 00:01:20,725
|
96 |
+
with layers of this energy.
|
97 |
+
33
|
98 |
+
00:01:20,725 --> 00:01:24,887
|
99 |
+
One dimension of energy comes alive immediately,
|
100 |
+
34
|
101 |
+
00:01:24,887 --> 00:01:28,667
|
102 |
+
because that is necessary for your survival process.
|
103 |
+
35
|
104 |
+
00:01:28,667 --> 00:01:37,369
|
105 |
+
The other dimensions of energy will not come alive unless you do something about it.
|
106 |
+
36
|
107 |
+
00:01:37,369 --> 00:01:40,203
|
108 |
+
Unless you're aware of it and activate it in a certain way,
|
109 |
+
37
|
110 |
+
00:01:40,203 --> 00:01:42,093
|
111 |
+
they do not come into existence.
|
112 |
+
38
|
113 |
+
00:01:42,093 --> 00:01:43,980
|
114 |
+
They remain dormant.
|
115 |
+
39
|
116 |
+
00:01:43,980 --> 00:01:49,932
|
117 |
+
The dormant energy is way bigger than the energy that is in use right now.
|
118 |
+
40
|
119 |
+
00:01:49,932 --> 00:01:53,582
|
120 |
+
To take care of your survival process,
|
121 |
+
41
|
122 |
+
00:01:53,582 --> 00:01:57,169
|
123 |
+
to live a physical life completely,
|
124 |
+
42
|
125 |
+
00:01:57,169 --> 00:02:00,902
|
126 |
+
to live a complete physical and intellectual life,
|
127 |
+
43
|
128 |
+
00:02:00,902 --> 00:02:08,451
|
129 |
+
you need to activate only about twenty-one of your chakras.
|
130 |
+
44
|
131 |
+
00:02:08,451 --> 00:02:12,246
|
132 |
+
Out of this one hundred and fourteen,
|
133 |
+
45
|
134 |
+
00:02:12,246 --> 00:02:14,444
|
135 |
+
if about twenty-one of them are on,
|
136 |
+
46
|
137 |
+
00:02:14,444 --> 00:02:16,248
|
138 |
+
you will live a complete life,
|
139 |
+
47
|
140 |
+
00:02:16,248 --> 00:02:18,322
|
141 |
+
you will not feel any inadequacy.
|
142 |
+
48
|
143 |
+
00:02:18,322 --> 00:02:21,267
|
144 |
+
You will live a complete physical life.
|
145 |
+
49
|
146 |
+
00:02:21,267 --> 00:02:23,112
|
147 |
+
There'll be no problem with your life,
|
148 |
+
50
|
149 |
+
00:02:23,112 --> 00:02:25,249
|
150 |
+
you will think you're a great success,
|
151 |
+
51
|
152 |
+
00:02:25,249 --> 00:02:29,458
|
153 |
+
but you're only twenty-one,
|
154 |
+
52
|
155 |
+
00:02:29,458 --> 00:02:34,000
|
156 |
+
that is less than twenty-one percent out of one hundred and fourteen,
|
157 |
+
53
|
158 |
+
00:02:34,000 --> 00:02:35,949
|
159 |
+
less than twenty percent.
|
160 |
+
54
|
161 |
+
00:02:35,949 --> 00:02:37,733
|
162 |
+
At let… less than twenty percent,
|
163 |
+
55
|
164 |
+
00:02:37,733 --> 00:02:41,984
|
165 |
+
you will feel like a complete life without any inadequacies.
|
166 |
+
56
|
167 |
+
00:02:41,984 --> 00:02:44,618
|
168 |
+
The remaining percentage of life,
|
169 |
+
57
|
170 |
+
00:02:44,618 --> 00:02:46,339
|
171 |
+
what is it about?
|
172 |
+
58
|
173 |
+
00:02:46,339 --> 00:02:50,632
|
174 |
+
It is not even needed if your intention is just to live well.
|
175 |
+
59
|
176 |
+
00:02:50,632 --> 00:02:53,846
|
177 |
+
If you activate this dimension of energy,
|
178 |
+
60
|
179 |
+
00:02:53,846 --> 00:02:57,143
|
180 |
+
other dimensions of life will open up.
|
181 |
+
61
|
182 |
+
00:02:57,143 --> 00:02:58,865
|
183 |
+
One thing is,
|
184 |
+
62
|
185 |
+
00:02:58,865 --> 00:03:00,980
|
186 |
+
“Are you ready for those dimensions?”
|
187 |
+
63
|
188 |
+
00:03:00,980 --> 00:03:05,003
|
189 |
+
The question is not about whether it's good or bad.
|
190 |
+
64
|
191 |
+
00:03:05,003 --> 00:03:06,890
|
192 |
+
The question is just about,
|
193 |
+
65
|
194 |
+
00:03:06,890 --> 00:03:08,902
|
195 |
+
“Are you ready for it?”
|
196 |
+
66
|
197 |
+
00:03:08,902 --> 00:03:13,920
|
198 |
+
Because even if the best things in life come to your life,
|
199 |
+
67
|
200 |
+
00:03:13,920 --> 00:03:16,803
|
201 |
+
when you are not ready for it,
|
202 |
+
68
|
203 |
+
00:03:16,803 --> 00:03:19,084
|
204 |
+
it will not be a good thing for you.
|
205 |
+
69
|
206 |
+
00:03:19,084 --> 00:03:22,713
|
207 |
+
In your experience it will not be a good thing
|
208 |
+
70
|
209 |
+
00:03:22,713 --> 00:03:25,223
|
210 |
+
if something came to you when you're not ready for it,
|
211 |
+
71
|
212 |
+
00:03:25,223 --> 00:03:26,467
|
213 |
+
isn't it so?
|
214 |
+
72
|
215 |
+
00:03:26,467 --> 00:03:28,250
|
216 |
+
Even if it's a greatest thing,
|
217 |
+
73
|
218 |
+
00:03:28,250 --> 00:03:29,702
|
219 |
+
it may be the greatest thing,
|
220 |
+
74
|
221 |
+
00:03:29,702 --> 00:03:33,539
|
222 |
+
but it came to you when you are not ready for it.
|
223 |
+
75
|
224 |
+
00:03:33,539 --> 00:03:34,969
|
225 |
+
Then it is not a good thing,
|
226 |
+
76
|
227 |
+
00:03:34,969 --> 00:03:36,525
|
228 |
+
isn't it?
|
229 |
+
77
|
230 |
+
00:03:36,525 --> 00:03:38,018
|
231 |
+
So are you ready for it,
|
232 |
+
78
|
233 |
+
00:03:38,018 --> 00:03:39,034
|
234 |
+
is the first question.
|
235 |
+
79
|
236 |
+
00:03:39,034 --> 00:03:40,071
|
237 |
+
If you're ready for it,
|
238 |
+
80
|
239 |
+
00:03:40,071 --> 00:03:41,709
|
240 |
+
what can we do for it?
|
241 |
+
81
|
242 |
+
00:03:41,709 --> 00:03:44,073
|
243 |
+
What can we do to activate it?
|
244 |
+
82
|
245 |
+
00:03:44,073 --> 00:03:46,541
|
246 |
+
The various ways of doing this,
|
247 |
+
83
|
248 |
+
00:03:46,541 --> 00:03:47,143
|
249 |
+
many,
|
250 |
+
84
|
251 |
+
00:03:47,143--> 00:03:48,200
|
252 |
+
many ways,
|
253 |
+
85
|
254 |
+
00:03:48,200 --> 00:03:50,938
|
255 |
+
but the Kundalini Yoga…
|
256 |
+
86
|
257 |
+
00:03:50,938 --> 00:03:56,558
|
258 |
+
are people familiar with Kundalini Yoga practicing or…?
|
259 |
+
87
|
260 |
+
00:03:56,558 --> 00:04:00,083
|
261 |
+
Okay.
|
262 |
+
88
|
263 |
+
00:04:00,083 --> 00:04:01,535
|
264 |
+
Kundalini Yoga.
|
265 |
+
89
|
266 |
+
00:04:01,535 --> 00:04:04,708
|
267 |
+
I…I'm not making a comment about anybody,
|
268 |
+
90
|
269 |
+
00:04:04,708 --> 00:04:07,279
|
270 |
+
okay?
|
271 |
+
91
|
272 |
+
00:04:07,279 --> 00:04:08,378
|
273 |
+
Kundalini Yoga,
|
274 |
+
92
|
275 |
+
00:04:08,378 --> 00:04:13,687
|
276 |
+
in its essence is the most dangerous form of yoga.
|
277 |
+
93
|
278 |
+
00:04:13,687 --> 00:04:15,471
|
279 |
+
I'm saying dangerous,
|
280 |
+
94
|
281 |
+
00:04:15,471 --> 00:04:19,203
|
282 |
+
because it's the most potent also.
|
283 |
+
95
|
284 |
+
00:04:19,203 --> 00:04:27,561
|
285 |
+
What is most potent is always the most dangerous if improperly handled.
|
286 |
+
96
|
287 |
+
00:04:27,561 --> 00:04:31,169
|
288 |
+
There are various kinds of energy in the world right now,
|
289 |
+
97
|
290 |
+
00:04:31,169 --> 00:04:33,886
|
291 |
+
even the electricity is being manufactured in… I mean,
|
292 |
+
98
|
293 |
+
00:04:33,886 --> 00:04:36,126
|
294 |
+
produced in so many different ways.
|
295 |
+
99
|
296 |
+
00:04:36,126 --> 00:04:40,170
|
297 |
+
One of the ways that we do it is through nuclear reactions,
|
298 |
+
100
|
299 |
+
00:04:40,170 --> 00:04:42,907
|
300 |
+
nuclear reactors rather..
|
301 |
+
101
|
302 |
+
00:04:42,907 --> 00:04:46,889
|
303 |
+
It is the most efficient way of producing energy that we know right now,
|
304 |
+
102
|
305 |
+
00:04:46,889 --> 00:04:49,211
|
306 |
+
but it is also the most dangerous way,
|
307 |
+
103
|
308 |
+
00:04:49,211 --> 00:04:51,223
|
309 |
+
isn't it?
|
310 |
+
104
|
311 |
+
00:04:51,223 --> 00:04:52,965
|
312 |
+
When things go wrong,
|
313 |
+
105
|
314 |
+
00:04:52,965 --> 00:04:55,744
|
315 |
+
they go seriously wrong.
|
316 |
+
106
|
317 |
+
00:04:55,744 --> 00:04:57,299
|
318 |
+
When they're going right,
|
319 |
+
107
|
320 |
+
00:04:57,299 --> 00:05:04,288
|
321 |
+
it is the easiest and the best way to produce energy on the planet is nuclear energy actually.
|
322 |
+
108
|
323 |
+
00:05:04,288 --> 00:05:06,258
|
324 |
+
But when it goes bad,
|
325 |
+
109
|
326 |
+
00:05:06,258 --> 00:05:07,502
|
327 |
+
it goes bad,
|
328 |
+
110
|
329 |
+
00:05:07,502 --> 00:05:08,539
|
330 |
+
really bad,
|
331 |
+
111
|
332 |
+
00:05:08,539 --> 00:05:11,256
|
333 |
+
like in ways that you can't fix it.
|
334 |
+
112
|
335 |
+
00:05:11,256 --> 00:05:11,691
|
336 |
+
So,
|
337 |
+
113
|
338 |
+
00:05:11,691 --> 00:05:13,454
|
339 |
+
similarly with Kundalini Yoga,
|
340 |
+
114
|
341 |
+
00:05:13,454 --> 00:05:17,353
|
342 |
+
it is the most potent and it is the most dangerous.
|
343 |
+
115
|
344 |
+
00:05:17,353 --> 00:05:21,106
|
345 |
+
Without the necessary preparation and guidance,
|
346 |
+
116
|
347 |
+
00:05:21,106 --> 00:05:23,387
|
348 |
+
without expert guidance,
|
349 |
+
117
|
350 |
+
00:05:23,387 --> 00:05:25,793
|
351 |
+
constant guidance and observation,
|
352 |
+
118
|
353 |
+
00:05:25,793 --> 00:05:28,219
|
354 |
+
nobody should ever attempt it.
|
355 |
+
119
|
356 |
+
00:05:28,219 --> 00:05:29,588
|
357 |
+
But the problem is,
|
358 |
+
120
|
359 |
+
00:05:29,588 --> 00:05:33,611
|
360 |
+
books have been written about it and everybody wants to do the highest yoga.
|
361 |
+
121
|
362 |
+
00:05:33,611 --> 00:05:35,830
|
363 |
+
Nobody wants to start with A,
|
364 |
+
122
|
365 |
+
00:05:35,830 --> 00:05:38,650
|
366 |
+
everybody wants to start the alphabet with Z.
|
367 |
+
123
|
368 |
+
00:05:38,650 --> 00:05:42,383
|
369 |
+
This attitude itself is dangerous.
|
370 |
+
124
|
371 |
+
00:05:42,383 --> 00:05:50,326
|
372 |
+
What can be a life-transforming force can become a life-destructive force
|
373 |
+
125
|
374 |
+
00:05:50,326 --> 00:05:55,780
|
375 |
+
simply because without the necessary commitment and dedication and focus and understanding
|
376 |
+
126
|
377 |
+
00:05:55,780 --> 00:05:57,294
|
378 |
+
it is being handled.
|
379 |
+
127
|
380 |
+
00:05:57,294 --> 00:05:58,849
|
381 |
+
Anyway,
|
382 |
+
128
|
383 |
+
00:05:58,849 --> 00:06:00,550
|
384 |
+
about raising the Kundalini,
|
385 |
+
129
|
386 |
+
00:06:00,550 --> 00:06:02,292
|
387 |
+
if the Kundalini rises,
|
388 |
+
130
|
389 |
+
00:06:02,292 --> 00:06:08,534
|
390 |
+
the dimensions of your life will change so rapidly that you must be willing
|
391 |
+
131
|
392 |
+
00:06:08,534 --> 00:06:11,562
|
393 |
+
to make the outside admin… adjustments equally,
|
394 |
+
132
|
395 |
+
00:06:11,562 --> 00:06:12,806
|
396 |
+
quick.
|
397 |
+
133
|
398 |
+
00:06:12,806 --> 00:06:14,320
|
399 |
+
Otherwise,
|
400 |
+
134
|
401 |
+
00:06:14,320 --> 00:06:18,426
|
402 |
+
things will fall apart in a big way.
|
403 |
+
135
|
404 |
+
00:06:18,426 --> 00:06:20,624
|
405 |
+
In the classical yogic traditions,
|
406 |
+
136
|
407 |
+
00:06:20,624 --> 00:06:26,161
|
408 |
+
there is a certain type of yoga we teach for people who live in family situations.
|
409 |
+
137
|
410 |
+
00:06:26,161 --> 00:06:30,205
|
411 |
+
There is a certain other type of yoga we teach for ascetics.
|
412 |
+
138
|
413 |
+
00:06:30,205 --> 00:06:35,159
|
414 |
+
In Isha,
|
415 |
+
139
|
416 |
+
00:06:35,159 --> 00:06:32,486
|
417 |
+
we have both the forms,
|
418 |
+
140
|
419 |
+
00:06:32,486 --> 00:06:36,219
|
420 |
+
we have ascetic yoga and we have the general yoga.
|
421 |
+
141
|
422 |
+
00:06:36,219 --> 00:06:38,915
|
423 |
+
We will never teach you the ascetic form.
|
424 |
+
142
|
425 |
+
00:06:38,915 --> 00:06:41,590
|
426 |
+
That is the most pos… potent way to do it.
|
427 |
+
143
|
428 |
+
00:06:41,590 --> 00:06:46,733
|
429 |
+
But it will demand a certain dimension of discipline and focus,
|
430 |
+
144
|
431 |
+
00:06:46,733 --> 00:06:50,113
|
432 |
+
which your regular lives will not allow.
|
433 |
+
145
|
434 |
+
00:06:50,113 --> 00:06:52,498
|
435 |
+
If you do that kind of yoga,
|
436 |
+
146
|
437 |
+
00:06:52,498 --> 00:06:56,273
|
438 |
+
it will dismantle your outside life instantly.
|
439 |
+
147
|
440 |
+
00:06:56,273 --> 00:06:59,632
|
441 |
+
Now this Yoga is not designed to dismantle your life,
|
442 |
+
148
|
443 |
+
00:06:59,632 --> 00:07:05,874
|
444 |
+
this Yoga is designed to make your life happen better.
|
445 |
+
149
|
446 |
+
00:07:05,874 --> 00:07:07,907
|
447 |
+
When life happens better,
|
448 |
+
150
|
449 |
+
00:07:07,907 --> 00:07:09,835
|
450 |
+
when things happen better,
|
451 |
+
151
|
452 |
+
00:07:09,835 --> 00:07:11,328
|
453 |
+
you make more money,
|
454 |
+
152
|
455 |
+
00:07:11,328 --> 00:07:13,112
|
456 |
+
your business is going better,
|
457 |
+
153
|
458 |
+
00:07:13,112 --> 00:07:14,854
|
459 |
+
your profession is happening better.
|
460 |
+
154
|
461 |
+
00:07:14,854 --> 00:07:19,644
|
462 |
+
You're generally unfortunately,
|
463 |
+
155
|
464 |
+
00:07:19,644 --> 00:07:23,833
|
465 |
+
you are longing to seek the higher becomes slower.
|
466 |
+
156
|
467 |
+
00:07:23,833 --> 00:07:25,512
|
468 |
+
Yes.
|
469 |
+
157
|
470 |
+
00:07:25,512 --> 00:07:27,006
|
471 |
+
So in the real sense,
|
472 |
+
158
|
473 |
+
00:07:27,006 --> 00:07:28,873
|
474 |
+
it is not the good way to do it.
|
475 |
+
159
|
476 |
+
00:07:28,873 --> 00:07:31,714
|
477 |
+
But it's the only way it works in today's world.
|
478 |
+
160
|
479 |
+
00:07:31,714 --> 00:07:34,596
|
480 |
+
And it's the only way it works for majority of the people.
|
481 |
+
161
|
482 |
+
00:07:34,596 --> 00:07:36,255
|
483 |
+
For a small number of people,
|
484 |
+
162
|
485 |
+
00:07:36,255 --> 00:07:38,018
|
486 |
+
we can do it other ways.
|
487 |
+
163
|
488 |
+
00:07:38,018 --> 00:07:43,680
|
489 |
+
We can bypass all these things and just do very powerful ways of doing things.
|
490 |
+
164
|
491 |
+
00:07:43,680 --> 00:07:47,329
|
492 |
+
But it will dismantle all social structures around them,
|
493 |
+
165
|
494 |
+
00:07:47,329 --> 00:07:49,631
|
495 |
+
which is not good for everybody to do.
|
496 |
+
166
|
497 |
+
00:07:49,631 --> 00:07:51,809
|
498 |
+
So these are different dimensions.
|
499 |
+
167
|
500 |
+
00:07:51,809 --> 00:07:52,887
|
501 |
+
Kundalini Yoga,
|
502 |
+
168
|
503 |
+
00:07:52,887 --> 00:07:54,443
|
504 |
+
if it has to be practiced,
|
505 |
+
169
|
506 |
+
00:07:54,443 --> 00:07:57,263
|
507 |
+
you must be in a certain kind of atmosphere.
|
508 |
+
170
|
509 |
+
00:07:57,263 --> 00:08:01,618
|
510 |
+
You cannot live in social situations and do Kundalini Yoga
|
511 |
+
171
|
512 |
+
00:08:01,618 --> 00:08:02,468
|
513 |
+
Otherwise,
|
514 |
+
172
|
515 |
+
00:08:02,468 --> 00:08:03,878
|
516 |
+
in the name of Kundalini Yoga,
|
517 |
+
173
|
518 |
+
00:08:03,878 --> 00:08:06,263
|
519 |
+
you're doing something simplistic.
|
520 |
+
174
|
521 |
+
00:08:06,263 --> 00:08:07,176
|
522 |
+
Otherwise,
|
523 |
+
175
|
524 |
+
00:08:07,176 --> 00:08:10,971
|
525 |
+
Kundalini yoga can transform the way you are within days.
|
526 |
+
176
|
527 |
+
00:08:10,971 --> 00:08:16,674
|
528 |
+
Suddenly you find you're a stranger in your own home within two days of practice,
|
529 |
+
177
|
530 |
+
00:08:16,674 --> 00:08:20,469
|
531 |
+
because it will change everything about you.
|
532 |
+
178
|
533 |
+
00:08:20,469 --> 00:08:20,759
|
534 |
+
So,
|
535 |
+
179
|
536 |
+
00:08:20,759 --> 00:08:22,169
|
537 |
+
can we raise the Kundalini?
|
538 |
+
180
|
539 |
+
00:08:22,169 --> 00:08:22,480
|
540 |
+
Yes,
|
541 |
+
181
|
542 |
+
00:08:22,480 --> 00:08:25,135
|
543 |
+
we can.
|
544 |
+
182
|
545 |
+
00:08:25,135 --> 00:08:30,651
|
546 |
+
One way is to create a conducive atmosphere so that slowly it rises.
|
547 |
+
183
|
548 |
+
00:08:30,651 --> 00:08:35,110
|
549 |
+
The other way is to provoke it in such a way that it rises quickly.
|
550 |
+
184
|
551 |
+
00:08:35,110 --> 00:08:37,225
|
552 |
+
If it rises quickly,
|
553 |
+
185
|
554 |
+
00:08:37,225 --> 00:08:39,423
|
555 |
+
then everything changes dramatically.
|
556 |
+
186
|
557 |
+
00:08:39,423 --> 00:08:42,223
|
558 |
+
If it rises slowly over a period of time,
|
559 |
+
187
|
560 |
+
00:08:42,223 --> 00:08:44,234
|
561 |
+
changes will happen slowly,
|
562 |
+
188
|
563 |
+
00:08:44,234 --> 00:08:48,216
|
564 |
+
you will be capable of handling these changes over a period of time.
|
565 |
+
189
|
566 |
+
00:08:48,216 --> 00:08:49,958
|
567 |
+
But if it happens very quick,
|
568 |
+
190
|
569 |
+
00:08:49,958 --> 00:08:53,193
|
570 |
+
then you will not be able to handle the changes,
|
571 |
+
191
|
572 |
+
00:08:53,193 --> 00:08:56,677
|
573 |
+
things will look like things are falling apart.
|
574 |
+
192
|
575 |
+
00:08:56,677 --> 00:08:58,958
|
576 |
+
So there are different ways of doing this.
|
577 |
+
193
|
578 |
+
00:08:58,958 --> 00:08:59,705
|
579 |
+
How many ways?
|
580 |
+
194
|
581 |
+
00:08:59,705 --> 00:09:00,887
|
582 |
+
There are too many ways,
|
583 |
+
195
|
584 |
+
00:09:00,887 --> 00:09:03,023
|
585 |
+
I will not go into how many ways.
|
586 |
+
196
|
587 |
+
00:09:03,023 --> 00:09:04,972
|
588 |
+
There are so many ways of doing it.
|
589 |
+
197
|
590 |
+
00:09:04,972 --> 00:09:05,885
|
591 |
+
Essentially,
|
592 |
+
198
|
593 |
+
00:09:05,885 --> 00:09:10,903
|
594 |
+
there are one hundred and twelve ways of doing it.
|
595 |
+
199
|
596 |
+
00:09:10,903 --> 00:09:14,595
|
597 |
+
There are hundred and twelve ways in which you can take this up,
|
598 |
+
200
|
599 |
+
00:09:14,595 --> 00:09:17,602
|
600 |
+
from the base to…
|
601 |
+
201
|
602 |
+
00:09:17,602 --> 00:09:18,099
|
603 |
+
Oh!
|
604 |
+
202
|
605 |
+
00:09:18,099 --> 00:09:19,675
|
606 |
+
for this you have to know the structure,
|
607 |
+
203
|
608 |
+
00:09:19,675 --> 00:09:22,122
|
609 |
+
otherwise it will become very elaborate.
|
610 |
+
204
|
611 |
+
00:09:22,122 --> 00:09:25,606
|
612 |
+
Hmm,
|
613 |
+
205
|
614 |
+
00:09:25,606 --> 00:09:28,385
|
615 |
+
out of this one hundred and fourteen chakras,
|
616 |
+
206
|
617 |
+
00:09:28,385 --> 00:09:33,093
|
618 |
+
there are seven which we recognize as seven dimensions.
|
619 |
+
207
|
620 |
+
00:09:33,093 --> 00:09:34,462
|
621 |
+
Out of this,
|
622 |
+
208
|
623 |
+
00:09:34,462 --> 00:09:36,287
|
624 |
+
six are within the body,
|
625 |
+
209
|
626 |
+
00:09:36,287 --> 00:09:39,750
|
627 |
+
one just outside the physical body.
|
628 |
+
210
|
629 |
+
00:09:39,750 --> 00:09:40,144
|
630 |
+
So,
|
631 |
+
211
|
632 |
+
00:09:40,144 --> 00:09:43,192
|
633 |
+
if you employ this one hundred and twelve methods,
|
634 |
+
212
|
635 |
+
00:09:43,192 --> 00:09:45,639
|
636 |
+
you will handle the six chakras,
|
637 |
+
213
|
638 |
+
00:09:45,639 --> 00:09:48,190
|
639 |
+
the seventh one you cannot handle.
|
640 |
+
214
|
641 |
+
00:09:48,190 --> 00:09:52,151
|
642 |
+
There are hundred and twelve ways in which you can at… attain to a chakra
|
643 |
+
215
|
644 |
+
00:09:52,151 --> 00:09:53,976
|
645 |
+
which we refer to as Agna,
|
646 |
+
216
|
647 |
+
00:09:53,976 --> 00:09:55,448
|
648 |
+
but from Agna to Sahasrara,
|
649 |
+
217
|
650 |
+
00:09:55,448 --> 00:09:57,398
|
651 |
+
there is no way.
|
652 |
+
218
|
653 |
+
00:09:57,398 --> 00:09:59,264
|
654 |
+
There is no way to do it,
|
655 |
+
219
|
656 |
+
00:09:59,264 --> 00:10:02,520
|
657 |
+
you just have to jump into an abyss.
|
658 |
+
220
|
659 |
+
00:10:02,520 --> 00:10:04,843
|
660 |
+
If you have to jump into an abyss,
|
661 |
+
221
|
662 |
+
00:10:04,843 --> 00:10:11,417
|
663 |
+
you should be insane or you should have enormous trust in somebody.
|
664 |
+
222
|
665 |
+
00:10:11,417 --> 00:10:15,543
|
666 |
+
Somebody says jump and you are jumping because
|
667 |
+
223
|
668 |
+
00:10:15,543 --> 00:10:19,214
|
669 |
+
you have such a deep trust in somebody that when he says jump,
|
670 |
+
224
|
671 |
+
00:10:19,214 --> 00:10:21,226
|
672 |
+
it has to be good for you.
|
673 |
+
225
|
674 |
+
00:10:21,226 --> 00:10:26,078
|
675 |
+
You simply jump into a bottomless pit.
|
676 |
+
226
|
677 |
+
00:10:26,078 --> 00:10:27,011
|
678 |
+
So,
|
679 |
+
227
|
680 |
+
00:10:27,011 --> 00:10:33,088
|
681 |
+
the journey from the Mooladhara to Agna there are one hundred and twelve ways to get there.
|
682 |
+
228
|
683 |
+
00:10:33,088 --> 00:10:34,311
|
684 |
+
But from there to there,
|
685 |
+
229
|
686 |
+
00:10:34,311 --> 00:10:35,307
|
687 |
+
there is no way.
|
688 |
+
230
|
689 |
+
00:10:35,307 --> 00:10:36,841
|
690 |
+
It is just one jump,
|
691 |
+
231
|
692 |
+
00:10:36,841 --> 00:10:38,998
|
693 |
+
that can happen in trust,
|
694 |
+
232
|
695 |
+
00:10:38,998 --> 00:10:42,316
|
696 |
+
in devotion or in madness.
|
697 |
+
233
|
698 |
+
00:10:42,316 --> 00:10:57,780
|
699 |
+
Choice is yours.
|