Commit bf95274 by jaygala24
1 Parent(s): 5b70ce9

Upload 5 files

Files changed (6)
  1. .gitattributes +1 -0
  2. dict.SRC.json +0 -0
  3. dict.TGT.json +0 -0
  4. model.SRC +3 -0
  5. model.TGT +0 -0
  6. tokenization_indictrans.py +239 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model.SRC filter=lfs diff=lfs merge=lfs -text
dict.SRC.json ADDED
The diff for this file is too large to render.
 
dict.TGT.json ADDED
The diff for this file is too large to render.
 
model.SRC ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
+ size 3256903
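
The three lines above are a Git LFS pointer rather than the SentencePiece model itself: `oid` is the SHA-256 of the real file's contents and `size` is its length in bytes. As a minimal sketch (not part of this commit; the local paths are hypothetical), a downloaded blob can be checked against such a pointer like this:

    import hashlib
    from pathlib import Path

    def verify_lfs_pointer(pointer_path: str, blob_path: str) -> bool:
        """Check a downloaded file against the oid/size recorded in its LFS pointer."""
        fields = dict(
            line.split(" ", 1)
            for line in Path(pointer_path).read_text().splitlines()
            if line.strip()
        )
        expected_oid = fields["oid"].split(":", 1)[1].strip()   # e.g. "sha256:ac92..." -> "ac92..."
        expected_size = int(fields["size"])
        blob = Path(blob_path).read_bytes()
        return hashlib.sha256(blob).hexdigest() == expected_oid and len(blob) == expected_size

    # Hypothetical paths; adjust to wherever the pointer file and the actual blob live.
    # print(verify_lfs_pointer("model.SRC.pointer", "model.SRC"))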
model.TGT ADDED
Binary file (759 kB).
 
tokenization_indictrans.py ADDED
@@ -0,0 +1,239 @@
+ import os
+ import json
+
+ from typing import Dict, List, Optional, Union, Tuple
+
+ from transformers.utils import logging
+ from sentencepiece import SentencePieceProcessor
+ from transformers.tokenization_utils import PreTrainedTokenizer
+
+
+ logger = logging.get_logger(__name__)
+
+ SPIECE_UNDERLINE = "▁"
+ SUPPORTED_LANGUAGES = [
+     "asm_Beng",
+     "awa_Deva",
+     "ben_Beng",
+     "bho_Deva",
+     "brx_Deva",
+     "doi_Deva",
+     "eng_Latn",
+     "gom_Deva",
+     "gon_Deva",
+     "guj_Gujr",
+     "hin_Deva",
+     "hne_Deva",
+     "kan_Knda",
+     "kas_Arab",
+     "kas_Deva",
+     "kha_Latn",
+     "lus_Latn",
+     "mag_Deva",
+     "mai_Deva",
+     "mal_Mlym",
+     "mar_Deva",
+     "mni_Beng",
+     "mni_Mtei",
+     "npi_Deva",
+     "ory_Orya",
+     "pan_Guru",
+     "san_Deva",
+     "sat_Olck",
+     "snd_Arab",
+     "snd_Deva",
+     "tam_Taml",
+     "tel_Telu",
+     "urd_Arab",
+     "unr_Deva",
+ ]
+
+ VOCAB_FILES_NAMES = {
+     "src_vocab_fp": "dict.SRC.json",
+     "tgt_vocab_fp": "dict.TGT.json",
+     "src_spm_fp": "model.SRC",
+     "tgt_spm_fp": "model.TGT",
+ }
+
+
+ class IndicTransTokenizer(PreTrainedTokenizer):
+     _added_tokens_encoder = {}
+     _added_tokens_decoder = {}
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     def __init__(
+         self,
+         src_vocab_fp=None,
+         tgt_vocab_fp=None,
+         src_spm_fp=None,
+         tgt_spm_fp=None,
+         unk_token="<unk>",
+         bos_token="<s>",
+         eos_token="</s>",
+         pad_token="<pad>",
+         do_lower_case=False,
+         **kwargs
+     ):
+
+         self.src = True
+
+         self.src_vocab_fp = src_vocab_fp
+         self.tgt_vocab_fp = tgt_vocab_fp
+         self.src_spm_fp = src_spm_fp
+         self.tgt_spm_fp = tgt_spm_fp
+
+         self.unk_token = unk_token
+         self.pad_token = pad_token
+         self.eos_token = eos_token
+         self.bos_token = bos_token
+
+         self.encoder = self._load_json(self.src_vocab_fp)
+         if self.unk_token not in self.encoder:
+             raise KeyError("<unk> token must be in vocab")
+         assert self.pad_token in self.encoder
+         self.encoder_rev = {v: k for k, v in self.encoder.items()}
+
+         self.decoder = self._load_json(self.tgt_vocab_fp)
+         if self.unk_token not in self.decoder:
+             raise KeyError("<unk> token must be in vocab")
+         assert self.pad_token in self.decoder
+         self.decoder_rev = {v: k for k, v in self.decoder.items()}
+
+         # load SentencePiece model for pre-processing
+         self.src_spm = self._load_spm(self.src_spm_fp)
+         self.tgt_spm = self._load_spm(self.tgt_spm_fp)
+
+         self.current_spm = self.src_spm
+         self.current_encoder = self.encoder
+         self.current_encoder_rev = self.encoder_rev
+
+         self.unk_token_id = self.encoder[self.unk_token]
+         self.pad_token_id = self.encoder[self.pad_token]
+         self.eos_token_id = self.encoder[self.eos_token]
+         self.bos_token_id = self.encoder[self.bos_token]
+
+         super().__init__(
+             src_vocab_file=self.src_vocab_fp,
+             tgt_vocab_file=self.tgt_vocab_fp,
+             do_lower_case=do_lower_case,
+             unk_token=unk_token,
+             bos_token=bos_token,
+             eos_token=eos_token,
+             pad_token=pad_token,
+             **kwargs,
+         )
+
+     def _switch_to_input_mode(self):
+         self.src = True
+         self.padding_side = "left"
+         self.current_spm = self.src_spm
+         self.current_encoder = self.encoder
+         self.current_encoder_rev = self.encoder_rev
+
+     def _switch_to_target_mode(self):
+         self.src = False
+         self.padding_side = "right"
+         self.current_spm = self.tgt_spm
+         self.current_encoder = self.decoder
+         self.current_encoder_rev = self.decoder_rev
+
+     def _load_spm(self, path: str) -> SentencePieceProcessor:
+         return SentencePieceProcessor(model_file=path)
+
+     def _save_json(self, data, path: str) -> None:
+         with open(path, "w", encoding="utf-8") as f:
+             json.dump(data, f, indent=2)
+
+     def _load_json(self, path: str) -> Union[Dict, List]:
+         with open(path, "r", encoding="utf-8") as f:
+             return json.load(f)
+
+     @property
+     def src_vocab_size(self) -> int:
+         return len(self.encoder)
+
+     @property
+     def tgt_vocab_size(self) -> int:
+         return len(self.decoder)
+
+     def get_src_vocab(self) -> Dict[str, int]:
+         return dict(self.encoder, **self.added_tokens_encoder)
+
+     def get_tgt_vocab(self) -> Dict[str, int]:
+         return dict(self.decoder, **self.added_tokens_decoder)
+
+     # hack override
+     def get_vocab(self) -> Dict[str, int]:
+         return self.get_src_vocab()
+
+     # hack override
+     @property
+     def vocab_size(self) -> int:
+         return self.src_vocab_size
+
+     def _convert_token_to_id(self, token: str) -> int:
+         """Converts a token (str) into an index (integer) using the source/target vocabulary map."""
+         return self.current_encoder.get(token, self.current_encoder[self.unk_token])
+
+     def _convert_id_to_token(self, index: int) -> str:
+         """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
+         return self.current_encoder_rev.get(index, self.unk_token)
+
+     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+         """Uses sentencepiece model for detokenization"""
+         pad_tokens = [token for token in tokens if token == self.pad_token]
+         tokens = [token for token in tokens if token != self.pad_token]
+         if self.src:
+             return (
+                 " ".join(pad_tokens)
+                 + " "
+                 + " ".join(tokens[:2])
+                 + " "
+                 + "".join(tokens[2:]).replace(SPIECE_UNDERLINE, " ").strip()
+             )
+         return (
+             "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+             + " "
+             + " ".join(pad_tokens)
+         )
+
+     def _tokenize(self, text) -> List[str]:
+         if self.src:
+             tokens = text.split(" ")
+             tags = tokens[:2]
+             text = " ".join(tokens[2:])
+             tokens = self.current_spm.EncodeAsPieces(text)
+             return tags + tokens
+         else:
+             return self.current_spm.EncodeAsPieces(text)
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         if token_ids_1 is None:
+             return token_ids_0 + [self.eos_token_id]
+         # We don't expect to process pairs, but leave the pair logic for API consistency
+         return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+         if not os.path.isdir(save_directory):
+             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+             return
+
+         src_spm_fp = os.path.join(save_directory, "model.SRC")
+         tgt_spm_fp = os.path.join(save_directory, "model.TGT")
+         src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
+         tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
+
+         self._save_json(self.encoder, src_vocab_fp)
+         self._save_json(self.decoder, tgt_vocab_fp)
+
+         with open(src_spm_fp, 'wb') as f:
+             f.write(self.src_spm.serialized_model_proto())
+
+         with open(tgt_spm_fp, 'wb') as f:
+             f.write(self.tgt_spm.serialized_model_proto())
+
+         return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
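
For reference, a minimal usage sketch once these files live in a Hub repo. The repo id below is a placeholder, and the sketch assumes the repo also ships a tokenizer_config.json whose auto_map points at IndicTransTokenizer (not part of this commit), plus a transformers version whose __call__ routes text_target through _switch_to_target_mode. The "src_lang tgt_lang sentence" source format is inferred from _tokenize above: the first two whitespace-separated fields are treated as language tags drawn from SUPPORTED_LANGUAGES.

    from transformers import AutoTokenizer

    # Placeholder repo id; trust_remote_code=True lets transformers import
    # tokenization_indictrans.py from the repo instead of a built-in tokenizer class.
    tokenizer = AutoTokenizer.from_pretrained(
        "your-namespace/your-indictrans-repo", trust_remote_code=True
    )

    # Source text carries the two language tags up front; the target sentence is passed
    # via text_target, which switches the tokenizer to the target-side SentencePiece
    # model and vocabulary and returns the encoded target as "labels".
    batch = tokenizer(
        "eng_Latn hin_Deva How are you?",
        text_target="आप कैसे हैं?",
        return_tensors="pt",
    )
    print(batch["input_ids"], batch["labels"])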