jaygala24 committed on
Commit 7e1f3dc
1 Parent(s): ddf188b

Create tokenization_indictrans.py

Files changed (1)
  1. tokenization_indictrans.py +238 -0
tokenization_indictrans.py ADDED
@@ -0,0 +1,238 @@
import os
import json

from typing import Dict, List, Optional, Union, Tuple

from transformers.utils import logging
from sentencepiece import SentencePieceProcessor
from transformers.tokenization_utils import PreTrainedTokenizer


logger = logging.get_logger(__name__)

SPIECE_UNDERLINE = "▁"
SUPPORTED_LANGUAGES = [
    "asm_Beng",
    "awa_Deva",
    "ben_Beng",
    "bho_Deva",
    "brx_Deva",
    "doi_Deva",
    "eng_Latn",
    "gom_Deva",
    "gon_Deva",
    "guj_Gujr",
    "hin_Deva",
    "hne_Deva",
    "kan_Knda",
    "kas_Arab",
    "kas_Deva",
    "kha_Latn",
    "lus_Latn",
    "mag_Deva",
    "mai_Deva",
    "mal_Mlym",
    "mar_Deva",
    "mni_Beng",
    "mni_Mtei",
    "npi_Deva",
    "ory_Orya",
    "pan_Guru",
    "san_Deva",
    "sat_Olck",
    "snd_Arab",
    "snd_Deva",
    "tam_Taml",
    "tel_Telu",
    "urd_Arab",
    "unr_Deva",
]

VOCAB_FILES_NAMES = {
    "src_vocab_fp": "dict.SRC.json",
    "tgt_vocab_fp": "dict.TGT.json",
    "src_spm_fp": "model.SRC",
    "tgt_spm_fp": "model.TGT",
}

class IndicTransTokenizer(PreTrainedTokenizer):
    _added_tokens_encoder = {}
    _added_tokens_decoder = {}

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        src_vocab_fp=None,
        tgt_vocab_fp=None,
        src_spm_fp=None,
        tgt_spm_fp=None,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
        do_lower_case=False,
        **kwargs
    ):
        self.src = True

        self.src_vocab_fp = src_vocab_fp
        self.tgt_vocab_fp = tgt_vocab_fp
        self.src_spm_fp = src_spm_fp
        self.tgt_spm_fp = tgt_spm_fp

        self.unk_token = unk_token
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.bos_token = bos_token

        # source-side (SRC) vocabulary
        self.encoder = self._load_json(self.src_vocab_fp)
        if self.unk_token not in self.encoder:
            raise KeyError("<unk> token must be in vocab")
        assert self.pad_token in self.encoder
        self.encoder_rev = {v: k for k, v in self.encoder.items()}

        # target-side (TGT) vocabulary
        self.decoder = self._load_json(self.tgt_vocab_fp)
        if self.unk_token not in self.decoder:
            raise KeyError("<unk> token must be in vocab")
        assert self.pad_token in self.decoder
        self.decoder_rev = {v: k for k, v in self.decoder.items()}

        # load SentencePiece models for pre-processing
        self.src_spm = self._load_spm(self.src_spm_fp)
        self.tgt_spm = self._load_spm(self.tgt_spm_fp)

        # start in source (input) mode
        self.current_spm = self.src_spm
        self.current_encoder = self.encoder
        self.current_encoder_rev = self.encoder_rev

        self.unk_token_id = self.encoder[self.unk_token]
        self.pad_token_id = self.encoder[self.pad_token]
        self.eos_token_id = self.encoder[self.eos_token]
        self.bos_token_id = self.encoder[self.bos_token]

        super().__init__(
            src_vocab_file=self.src_vocab_fp,
            tgt_vocab_file=self.tgt_vocab_fp,
            do_lower_case=do_lower_case,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    def _switch_to_input_mode(self):
        # encode source text: SRC SentencePiece model + SRC vocab, left padding
        self.src = True
        self.padding_side = "left"
        self.current_spm = self.src_spm
        self.current_encoder = self.encoder
        self.current_encoder_rev = self.encoder_rev

    def _switch_to_target_mode(self):
        # encode/decode target text: TGT SentencePiece model + TGT vocab, right padding
        self.src = False
        self.padding_side = "right"
        self.current_spm = self.tgt_spm
        self.current_encoder = self.decoder
        self.current_encoder_rev = self.decoder_rev

    def _load_spm(self, path: str) -> SentencePieceProcessor:
        return SentencePieceProcessor(model_file=path)

    def _save_json(self, data, path: str) -> None:
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    def _load_json(self, path: str) -> Union[Dict, List]:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    @property
    def src_vocab_size(self) -> int:
        return len(self.encoder)

    @property
    def tgt_vocab_size(self) -> int:
        return len(self.decoder)

    def get_src_vocab(self) -> Dict[str, int]:
        return dict(self.encoder, **self.added_tokens_encoder)

    def get_tgt_vocab(self) -> Dict[str, int]:
        return dict(self.decoder, **self.added_tokens_decoder)

    # hack override
    def get_vocab(self) -> Dict[str, int]:
        return self.get_src_vocab()

    # hack override
    @property
    def vocab_size(self) -> int:
        return self.src_vocab_size

    def _convert_token_to_id(self, token: str) -> int:
        """Converts a token (str) into an index (integer) using the source/target vocabulary map."""
        return self.current_encoder.get(token, self.current_encoder[self.unk_token])

    def _convert_id_to_token(self, index: int) -> str:
        """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
        return self.current_encoder_rev.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Uses sentencepiece model for detokenization"""
        pad_tokens = [token for token in tokens if token == self.pad_token]
        tokens = [token for token in tokens if token != self.pad_token]
        if self.src:
            return (
                " ".join(pad_tokens)
                + " "
                + " ".join(tokens[:2])
                + " "
                + "".join(tokens[2:]).replace(SPIECE_UNDERLINE, " ").strip()
            )
        return (
            "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
            + " "
            + " ".join(pad_tokens)
        )

    def _tokenize(self, text) -> List[str]:
        if self.src:
            # the first two whitespace-separated tokens are the source and target
            # language tags; keep them intact and SentencePiece-encode the rest
            tokens = text.split(" ")
            tags = tokens[:2]
            text = " ".join(tokens[2:])
            tokens = self.current_spm.EncodeAsPieces(text)
            return tags + tokens
        else:
            return self.current_spm.EncodeAsPieces(text)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        if token_ids_1 is None:
            return token_ids_0 + [self.eos_token_id]
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        src_spm_fp = os.path.join(save_directory, "model.SRC")
        tgt_spm_fp = os.path.join(save_directory, "model.TGT")
        src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
        tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")

        self._save_json(self.encoder, src_vocab_fp)
        self._save_json(self.decoder, tgt_vocab_fp)

        with open(src_spm_fp, 'wb') as f:
            f.write(self.src_spm.serialized_model_proto())

        with open(tgt_spm_fp, 'wb') as f:
            f.write(self.tgt_spm.serialized_model_proto())

        return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
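
For context, here is a minimal usage sketch; it is not part of the committed file. The paths and example sentences are illustrative assumptions: the four file paths must point to a real checkpoint's dict.SRC.json, dict.TGT.json, model.SRC and model.TGT. Source text is expected to begin with a source and a target language tag from SUPPORTED_LANGUAGES, which _tokenize passes through unchanged before SentencePiece-encoding the rest.

# Illustrative sketch only: the file paths below are hypothetical placeholders.
from tokenization_indictrans import IndicTransTokenizer

tokenizer = IndicTransTokenizer(
    src_vocab_fp="dict.SRC.json",
    tgt_vocab_fp="dict.TGT.json",
    src_spm_fp="model.SRC",
    tgt_spm_fp="model.TGT",
)

# Source sentences carry "<src_lang> <tgt_lang>" tags, e.g. English -> Hindi.
enc = tokenizer("eng_Latn hin_Deva How are you?")
print(enc["input_ids"])  # SRC-vocab ids; build_inputs_with_special_tokens appends </s>

# Target-side text (e.g. training labels) goes through the TGT vocab and TGT
# SentencePiece model, which __call__ selects via _switch_to_target_mode.
labels = tokenizer(text_target="आप कैसे हैं?")
print(labels["input_ids"])

The same switch matters when decoding generated ids: the tokenizer should be in target mode (for example via the as_target_tokenizer() context manager) so that batch_decode looks ids up in the TGT vocabulary rather than the SRC one.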