fclong commited on
Commit
bda0379
·
1 Parent(s): 4b743ae

Create tokenizers_pegasus.py

Browse files
Files changed (1) hide show
  1. tokenizers_pegasus.py +597 -0
tokenizers_pegasus.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fengshen.examples.pegasus.data_utils import (
2
+ _is_control,
3
+ _is_punctuation,
4
+ _is_whitespace,
5
+ _is_chinese_char)
6
+ from transformers import PreTrainedTokenizer
7
+ from transformers import logging
8
+ from typing import List, Optional, Tuple, Union
9
+ import collections
10
+ import os
11
+ import unicodedata
12
+ import re
13
+ import jieba
14
+ import sys
15
+
16
+ sys.path.append("../../../../")
17
+
18
+ jieba.dt.tmp_dir = os.path.expanduser("~/.cache/")
19
+ # jieba.enable_parallel(8)
20
+ jieba.initialize()
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
25
+
26
+
27
+ def load_vocab(vocab_file):
28
+ """Loads a vocabulary file into a dictionary."""
29
+ vocab = collections.OrderedDict()
30
+ with open(vocab_file, "r", encoding="utf-8") as reader:
31
+ tokens = reader.readlines()
32
+ for index, token in enumerate(tokens):
33
+ token = token.rstrip("\n")
34
+ vocab[token] = index
35
+ return vocab
36
+
37
+
38
+ def whitespace_tokenize(text):
39
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
40
+ text = text.strip()
41
+ if not text:
42
+ return []
43
+ tokens = text.split()
44
+ return tokens
45
+
46
+
47
+ class PegasusTokenizer(PreTrainedTokenizer):
48
+ # copy from BertTokenizer
49
+ r"""
50
+ Construct a Pegasus tokenizer. Based on WordPiece.
51
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
52
+ this superclass for more information regarding those methods.
53
+ Args:
54
+ vocab_file (`str`):
55
+ File containing the vocabulary.
56
+ do_lower_case (`bool`, *optional*, defaults to `True`):
57
+ Whether or not to lowercase the input when tokenizing.
58
+ do_basic_tokenize (`bool`, *optional*, defaults to `True`):
59
+ Whether or not to do basic tokenization before WordPiece.
60
+ never_split (`Iterable`, *optional*):
61
+ Collection of tokens which will never be split during tokenization. Only has an effect when
62
+ `do_basic_tokenize=True`
63
+ unk_token (`str`, *optional*, defaults to `"[UNK]"`):
64
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
65
+ token instead.
66
+ sep_token (`str`, *optional*, defaults to `"[SEP]"`):
67
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
68
+ sequence classification or for a text and a question for question answering. It is also used as the last
69
+ token of a sequence built with special tokens.
70
+ pad_token (`str`, *optional*, defaults to `"[PAD]"`):
71
+ The token used for padding, for example when batching sequences of different lengths.
72
+ cls_token (`str`, *optional*, defaults to `"[CLS]"`):
73
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
74
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
75
+ mask_token (`str`, *optional*, defaults to `"[MASK]"`):
76
+ The token used for masking values. This is the token used when training this model with masked language
77
+ modeling. This is the token which the model will try to predict.
78
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
79
+ Whether or not to tokenize Chinese characters.
80
+ This should likely be deactivated for Japanese (see this
81
+ [issue](https://github.com/huggingface/transformers/issues/328)).
82
+ strip_accents (`bool`, *optional*):
83
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
84
+ value for `lowercase` (as in the original BERT).
85
+ """
86
+
87
+ vocab_files_names = VOCAB_FILES_NAMES
88
+ model_input_names = ["input_ids", "attention_mask"]
89
+
90
+ # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
91
+ # pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
92
+ # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
93
+
94
+ def __init__(self,
95
+ vocab_file,
96
+ do_lower_case=True,
97
+ do_basic_tokenize=True,
98
+ never_split=None,
99
+ pad_token="<pad>",
100
+ eos_token="</s>",
101
+ unk_token="<unk>",
102
+ mask_token="<mask_2>",
103
+ mask_token_sent="<mask_1>",
104
+ additional_special_tokens=None,
105
+ sep_token="[SEP]",
106
+ cls_token="[CLS]",
107
+ tokenize_chinese_chars=True,
108
+ strip_accents=None,
109
+ offset=100,
110
+ pre_tokenizer=lambda x: jieba.cut(x, HMM=False),
111
+ **kwargs):
112
+ self.offset = offset
113
+
114
+ if additional_special_tokens is not None:
115
+ if not isinstance(additional_special_tokens, list):
116
+ raise TypeError(
117
+ f"additional_special_tokens should be of type {type(list)}, \
118
+ but is {type(additional_special_tokens)}"
119
+ )
120
+
121
+ additional_special_tokens_extended = (
122
+ ([mask_token_sent] + additional_special_tokens)
123
+ if mask_token_sent not in additional_special_tokens
124
+ and mask_token_sent is not None else additional_special_tokens)
125
+
126
+ # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
127
+ additional_special_tokens_extended += [
128
+ f"<unk_{i}>" for i in range(
129
+ len(additional_special_tokens_extended), self.offset - 1)
130
+ ]
131
+
132
+ if len(set(additional_special_tokens_extended)) != len(
133
+ additional_special_tokens_extended):
134
+ raise ValueError(
135
+ f"Please make sure that the provided additional_special_tokens \
136
+ do not contain an incorrectly shifted list of <unk_x> tokens. \
137
+ Found {additional_special_tokens_extended}."
138
+ )
139
+ additional_special_tokens = additional_special_tokens_extended
140
+ else:
141
+ additional_special_tokens = [
142
+ mask_token_sent
143
+ ] if mask_token_sent is not None else []
144
+ # additional_special_tokens += [f"<unk_{i}>" for i in range(3, self.offset)]
145
+
146
+ # print("additional_special_tokens: ", additional_special_tokens)
147
+
148
+ if not os.path.isfile(vocab_file):
149
+ raise ValueError(
150
+ f"Can't find a vocabulary file at path '{vocab_file}'. \
151
+ To load the vocabulary from a Google pretrained "
152
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
153
+ )
154
+
155
+ super().__init__(
156
+ do_lower_case=do_lower_case,
157
+ do_basic_tokenize=do_basic_tokenize,
158
+ never_split=never_split,
159
+ unk_token=unk_token,
160
+ sep_token=sep_token,
161
+ pad_token=pad_token,
162
+ cls_token=cls_token,
163
+ mask_token=mask_token,
164
+ eos_token=eos_token,
165
+ tokenize_chinese_chars=tokenize_chinese_chars,
166
+ additional_special_tokens=additional_special_tokens,
167
+ strip_accents=strip_accents,
168
+ **kwargs,
169
+ )
170
+
171
+ self.pre_tokenizer = pre_tokenizer
172
+ self.mask_token_sent = mask_token_sent
173
+ self.vocab = load_vocab(vocab_file)
174
+
175
+ self.vocab[self.eos_token] = self.vocab.pop("[unused1]")
176
+ # self.vocab[self.eos_token] = self.vocab.pop("[unused2]")
177
+ self.vocab[self.pad_token] = self.vocab.pop("[PAD]")
178
+ self.vocab[self.unk_token] = self.vocab.pop("[UNK]")
179
+
180
+ if self.mask_token_sent is not None:
181
+ self.vocab[self.mask_token] = self.vocab.pop("[unused3]")
182
+ self.vocab[self.mask_token_sent] = self.vocab.pop("[unused2]")
183
+
184
+ self.ids_to_tokens = collections.OrderedDict([
185
+ (ids, tok) for tok, ids in self.vocab.items()
186
+ ])
187
+ self.do_basic_tokenize = do_basic_tokenize
188
+ if do_basic_tokenize:
189
+ self.basic_tokenizer = BasicTokenizer(
190
+ do_lower_case=do_lower_case,
191
+ never_split=never_split,
192
+ tokenize_chinese_chars=tokenize_chinese_chars,
193
+ strip_accents=strip_accents,
194
+ )
195
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
196
+ unk_token=self.unk_token)
197
+
198
+ @property
199
+ def do_lower_case(self):
200
+ return self.basic_tokenizer.do_lower_case
201
+
202
+ @property
203
+ def vocab_size(self):
204
+ return len(self.vocab)
205
+
206
+ def get_vocab(self):
207
+ return dict(self.vocab, **self.added_tokens_encoder)
208
+
209
+ def _tokenize(self, text):
210
+ split_tokens = []
211
+ # print("pegasus_tokenizer: ", text)
212
+ for text in self.pre_tokenizer(text):
213
+ if text in self.vocab:
214
+ split_tokens.append(text)
215
+ else:
216
+ if self.do_basic_tokenize:
217
+ for token in self.basic_tokenizer.tokenize(
218
+ text, never_split=self.all_special_tokens):
219
+
220
+ # If the token is part of the never_split set
221
+ if token in self.basic_tokenizer.never_split:
222
+ split_tokens.append(token)
223
+ else:
224
+ split_tokens += self.wordpiece_tokenizer.tokenize(
225
+ token)
226
+ else:
227
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
228
+ return split_tokens
229
+
230
+ def _convert_token_to_id(self, token):
231
+ """Converts a token (str) in an id using the vocab."""
232
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
233
+
234
+ def _convert_id_to_token(self, index):
235
+ """Converts an index (integer) in a token (str) using the vocab."""
236
+ return self.ids_to_tokens.get(index, self.unk_token)
237
+
238
+ @staticmethod
239
+ def _cjk_punctuation():
240
+ return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\
241
+ \uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\
242
+ \uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\
243
+ \u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\
244
+ \u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'
245
+
246
+ def convert_ids_to_tokens(
247
+ self,
248
+ ids: Union[int, List[int]],
249
+ skip_special_tokens: bool = False) -> Union[str, List[str]]:
250
+ """
251
+ Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
252
+ added tokens.
253
+ Args:
254
+ ids (`int` or `List[int]`):
255
+ The token id (or token ids) to convert to tokens.
256
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
257
+ Whether or not to remove special tokens in the decoding.
258
+ Returns:
259
+ `str` or `List[str]`: The decoded token(s).
260
+ """
261
+ if isinstance(ids, int):
262
+ if ids in self.added_tokens_decoder:
263
+ return self.added_tokens_decoder[ids]
264
+ else:
265
+ return self._convert_id_to_token(ids)
266
+ tokens = []
267
+ for index in ids:
268
+ index = int(index)
269
+ if skip_special_tokens and index in self.all_special_ids and index != 2:
270
+ continue
271
+ if index in self.added_tokens_decoder:
272
+ tokens.append(self.added_tokens_decoder[index])
273
+ else:
274
+ tokens.append(self._convert_id_to_token(index))
275
+ return tokens
276
+
277
+ def convert_tokens_to_string(self, tokens):
278
+ """Converts a sequence of tokens (string) in a single string."""
279
+ # for token in
280
+ # tokens = tokens or self.ids_to_tokens(ids)
281
+ # tokens = [token for token in tokens if not self._is_special(token)]
282
+
283
+ text = ''
284
+ for i, token in enumerate(tokens):
285
+ if token[:2] == '##':
286
+ text += token[2:]
287
+ elif len(token) == 1 and _is_chinese_char(ord(token)):
288
+ text += token
289
+ elif len(token) == 1 and _is_punctuation(token):
290
+ text += token
291
+ text += ' '
292
+ elif i > 0 and _is_chinese_char(ord(text[-1])):
293
+ text += token
294
+ elif tokens == "</s>":
295
+ continue
296
+ else:
297
+ text += ' '
298
+ text += token
299
+
300
+ text = re.sub(' +', ' ', text)
301
+ text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
302
+ punctuation = re.sub(' +', '', self._cjk_punctuation()).strip() + '+-/={(<['
303
+ punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
304
+ punctuation_regex = '(%s) ' % punctuation_regex
305
+ text = re.sub(punctuation_regex, '\\1', text)
306
+ text = re.sub(r'(\d\.) (\d)', '\\1\\2', text)
307
+
308
+ return text.strip()
309
+ # out_string = " ".join(tokens).replace(" ##", "").strip()
310
+
311
+ def build_inputs_with_special_tokens(
312
+ self,
313
+ token_ids_0: List[int],
314
+ token_ids_1: Optional[List[int]] = None) -> List[int]:
315
+ """
316
+ Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
317
+ and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
318
+ - single sequence: `X </s>`
319
+ - pair of sequences: `A B </s>` (not intended use)
320
+ BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
321
+ separator.
322
+ Args:
323
+ token_ids_0 (`List[int]`):
324
+ List of IDs to which the special tokens will be added.
325
+ token_ids_1 (`List[int]`, *optional*):
326
+ Optional second list of IDs for sequence pairs.
327
+ Returns:
328
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
329
+ """
330
+ if token_ids_1 is None:
331
+ return token_ids_0 + [self.eos_token_id]
332
+ return token_ids_0 + token_ids_1 + [self.eos_token_id]
333
+
334
+ def _special_token_mask(self, seq):
335
+ all_special_ids = set(
336
+ self.all_special_ids) # call it once instead of inside list comp
337
+ # all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
338
+
339
+ return [1 if x in all_special_ids else 0 for x in seq]
340
+
341
+ def get_special_tokens_mask(
342
+ self,
343
+ token_ids_0: List[int],
344
+ token_ids_1: Optional[List[int]] = None,
345
+ already_has_special_tokens: bool = False) -> List[int]:
346
+ """
347
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
348
+ special tokens using the tokenizer `prepare_for_model` method.
349
+ Args:
350
+ token_ids_0 (`List[int]`):
351
+ List of IDs.
352
+ token_ids_1 (`List[int]`, *optional*):
353
+ Optional second list of IDs for sequence pairs.
354
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
355
+ Whether or not the token list is already formatted with special tokens for the model.
356
+ Returns:
357
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
358
+ """
359
+
360
+ if already_has_special_tokens:
361
+ return self._special_token_mask(token_ids_0)
362
+ elif token_ids_1 is None:
363
+ return self._special_token_mask(token_ids_0) + [self.eos_token_id]
364
+ else:
365
+ return self._special_token_mask(token_ids_0 +
366
+ token_ids_1) + [self.eos_token_id]
367
+
368
+ def num_special_tokens_to_add(self, pair=False):
369
+ """Just EOS"""
370
+ return 1
371
+
372
+ def save_vocabulary(self,
373
+ save_directory: str,
374
+ filename_prefix: Optional[str] = None) -> Tuple[str]:
375
+ index = 0
376
+ if os.path.isdir(save_directory):
377
+ vocab_file = os.path.join(
378
+ save_directory,
379
+ (filename_prefix + "-" if filename_prefix else "") +
380
+ VOCAB_FILES_NAMES["vocab_file"])
381
+ else:
382
+ vocab_file = (filename_prefix +
383
+ "-" if filename_prefix else "") + save_directory
384
+ with open(vocab_file, "w", encoding="utf-8") as writer:
385
+ for token, token_index in sorted(self.vocab.items(),
386
+ key=lambda kv: kv[1]):
387
+ if index != token_index:
388
+ logger.warning(
389
+ f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
390
+ " Please check that the vocabulary is not corrupted!")
391
+ index = token_index
392
+ writer.write(token + "\n")
393
+ index += 1
394
+ return (vocab_file, )
395
+
396
+
397
+ class BasicTokenizer(object):
398
+ """
399
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
400
+ Args:
401
+ do_lower_case (`bool`, *optional*, defaults to `True`):
402
+ Whether or not to lowercase the input when tokenizing.
403
+ never_split (`Iterable`, *optional*):
404
+ Collection of tokens which will never be split during tokenization. Only has an effect when
405
+ `do_basic_tokenize=True`
406
+ tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
407
+ Whether or not to tokenize Chinese characters.
408
+ This should likely be deactivated for Japanese (see this
409
+ [issue](https://github.com/huggingface/transformers/issues/328)).
410
+ strip_accents: (`bool`, *optional*):
411
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
412
+ value for `lowercase` (as in the original BERT).
413
+ """
414
+
415
+ def __init__(self,
416
+ do_lower_case=True,
417
+ never_split=None,
418
+ tokenize_chinese_chars=True,
419
+ strip_accents=None):
420
+ if never_split is None:
421
+ never_split = []
422
+ self.do_lower_case = do_lower_case
423
+ self.never_split = set(never_split)
424
+ self.tokenize_chinese_chars = tokenize_chinese_chars
425
+ self.strip_accents = strip_accents
426
+
427
+ def tokenize(self, text, never_split=None):
428
+ """
429
+ Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
430
+ WordPieceTokenizer.
431
+ Args:
432
+ never_split (`List[str]`, *optional*)
433
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
434
+ [`PreTrainedTokenizer.tokenize`]) List of token not to split.
435
+ """
436
+ # union() returns a new set by concatenating the two sets.
437
+ never_split = self.never_split.union(
438
+ set(never_split)) if never_split else self.never_split
439
+ text = self._clean_text(text)
440
+
441
+ # This was added on November 1st, 2018 for the multilingual and Chinese
442
+ # models. This is also applied to the English models now, but it doesn't
443
+ # matter since the English models were not trained on any Chinese data
444
+ # and generally don't have any Chinese data in them (there are Chinese
445
+ # characters in the vocabulary because Wikipedia does have some Chinese
446
+ # words in the English Wikipedia.).
447
+ if self.tokenize_chinese_chars:
448
+ text = self._tokenize_chinese_chars(text)
449
+ orig_tokens = whitespace_tokenize(text)
450
+ split_tokens = []
451
+ for token in orig_tokens:
452
+ if token not in never_split:
453
+ if self.do_lower_case:
454
+ token = token.lower()
455
+ if self.strip_accents is not False:
456
+ token = self._run_strip_accents(token)
457
+ elif self.strip_accents:
458
+ token = self._run_strip_accents(token)
459
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
460
+
461
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
462
+ return output_tokens
463
+
464
+ def _run_strip_accents(self, text):
465
+ """Strips accents from a piece of text."""
466
+ text = unicodedata.normalize("NFD", text)
467
+ output = []
468
+ for char in text:
469
+ cat = unicodedata.category(char)
470
+ if cat == "Mn":
471
+ continue
472
+ output.append(char)
473
+ return "".join(output)
474
+
475
+ def _run_split_on_punc(self, text, never_split=None):
476
+ """Splits punctuation on a piece of text."""
477
+ if never_split is not None and text in never_split:
478
+ return [text]
479
+ chars = list(text)
480
+ i = 0
481
+ start_new_word = True
482
+ output = []
483
+ while i < len(chars):
484
+ char = chars[i]
485
+ if _is_punctuation(char):
486
+ output.append([char])
487
+ start_new_word = True
488
+ else:
489
+ if start_new_word:
490
+ output.append([])
491
+ start_new_word = False
492
+ output[-1].append(char)
493
+ i += 1
494
+
495
+ return ["".join(x) for x in output]
496
+
497
+ def _tokenize_chinese_chars(self, text):
498
+ """Adds whitespace around any CJK character."""
499
+ output = []
500
+ for char in text:
501
+ cp = ord(char)
502
+ if self._is_chinese_char(cp):
503
+ output.append(" ")
504
+ output.append(char)
505
+ output.append(" ")
506
+ else:
507
+ output.append(char)
508
+ return "".join(output)
509
+
510
+ def _is_chinese_char(self, cp):
511
+ """Checks whether CP is the codepoint of a CJK character."""
512
+ # This defines a "chinese character" as anything in the CJK Unicode block:
513
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
514
+ #
515
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
516
+ # despite its name. The modern Korean Hangul alphabet is a different block,
517
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
518
+ # space-separated words, so they are not treated specially and handled
519
+ # like the all of the other languages.
520
+ if ((cp >= 0x4E00 and cp <= 0x9FFF)
521
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
522
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
523
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
524
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
525
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
526
+ or (cp >= 0xF900 and cp <= 0xFAFF)
527
+ or (cp >= 0x2F800 and cp <= 0x2FA1F)): #
528
+ return True
529
+
530
+ return False
531
+
532
+ def _clean_text(self, text):
533
+ """Performs invalid character removal and whitespace cleanup on text."""
534
+ output = []
535
+ for char in text:
536
+ cp = ord(char)
537
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
538
+ continue
539
+ if _is_whitespace(char):
540
+ output.append(" ")
541
+ else:
542
+ output.append(char)
543
+ return "".join(output)
544
+
545
+
546
+ class WordpieceTokenizer(object):
547
+ """Runs WordPiece tokenization."""
548
+
549
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
550
+ self.vocab = vocab
551
+ self.unk_token = unk_token
552
+ self.max_input_chars_per_word = max_input_chars_per_word
553
+
554
+ def tokenize(self, text):
555
+ """
556
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
557
+ tokenization using the given vocabulary.
558
+ For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.
559
+ Args:
560
+ text: A single token or whitespace separated tokens. This should have
561
+ already been passed through *BasicTokenizer*.
562
+ Returns:
563
+ A list of wordpiece tokens.
564
+ """
565
+
566
+ output_tokens = []
567
+ for token in whitespace_tokenize(text):
568
+ chars = list(token)
569
+ if len(chars) > self.max_input_chars_per_word:
570
+ output_tokens.append(self.unk_token)
571
+ continue
572
+
573
+ is_bad = False
574
+ start = 0
575
+ sub_tokens = []
576
+ while start < len(chars):
577
+ end = len(chars)
578
+ cur_substr = None
579
+ while start < end:
580
+ substr = "".join(chars[start:end])
581
+ if start > 0:
582
+ substr = "##" + substr
583
+ if substr in self.vocab:
584
+ cur_substr = substr
585
+ break
586
+ end -= 1
587
+ if cur_substr is None:
588
+ is_bad = True
589
+ break
590
+ sub_tokens.append(cur_substr)
591
+ start = end
592
+
593
+ if is_bad:
594
+ output_tokens.append(self.unk_token)
595
+ else:
596
+ output_tokens.extend(sub_tokens)
597
+ return output_tokens