satyaalmasian committed on
Commit 0d47233
1 Parent(s): 89c6521

Upload NumBertTokenizer.py

Files changed (1)
  1. NumBertTokenizer.py +484 -0
NumBertTokenizer.py ADDED
@@ -0,0 +1,484 @@
import collections
import os
import unicodedata
from typing import List, Optional, Tuple

from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES, PRETRAINED_VOCAB_FILES_MAP, \
    PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES, PRETRAINED_INIT_CONFIGURATION, load_vocab, BasicTokenizer, \
    whitespace_tokenize, _is_whitespace, _is_control, _is_punctuation
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)


class NumBertTokenizer(PreTrainedTokenizer):
    r"""
    Construct a BERT tokenizer. Based on WordPiece.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`.
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
            for sequence classification or for a text and a question for question answering. It is also used as the
            last token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole
            sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs
    ):
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = NumBERTBasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = NumberWordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) to a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A BERT sequence has the following format:

        - single sequence: ``[CLS] X [SEP]``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ::

            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |

        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


class NumberWordpieceTokenizer(object):
    """Runs WordPiece tokenization, with digit-level splitting for numeric tokens."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            ########## Code changed for the number tokenization ##############
            # Purely numeric tokens (optionally containing "." or ",") are split
            # into single digits instead of greedy WordPiece subwords.
            chars_copy = chars.copy()
            if "." in chars_copy:
                chars_copy.remove(".")
            if "," in chars_copy:
                chars_copy.remove(",")
            if all(s.isdigit() for s in chars_copy):
                for index, char in enumerate(chars):
                    if index == 0:
                        output_tokens.append(char)
                    else:
                        output_tokens.append("##" + char)
            ########## end change ##############
            else:
                is_bad = False
                start = 0
                sub_tokens = []
                while start < len(chars):
                    end = len(chars)
                    cur_substr = None
                    while start < end:
                        substr = "".join(chars[start:end])
                        if start > 0:
                            substr = "##" + substr
                        if substr in self.vocab:
                            cur_substr = substr
                            break
                        end -= 1
                    if cur_substr is None:
                        is_bad = True
                        break
                    sub_tokens.append(cur_substr)
                    start = end

                if is_bad:
                    output_tokens.append(self.unk_token)
                else:
                    output_tokens.extend(sub_tokens)

        return output_tokens


class NumBERTBasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic tokenization of a piece of text. Splits on "white spaces" only; for sub-word tokenization, see
        WordPieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                :func:`PreTrainedTokenizer.tokenize`). List of tokens not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []

        ########## Code changed for the number tokenization ##############
        # Purely numeric tokens (optionally containing "." or ",") are kept
        # intact instead of being split on punctuation.
        chars_copy = chars.copy()
        if "." in chars_copy:
            chars_copy.remove(".")
        if "," in chars_copy:
            chars_copy.remove(",")
        if all(s.isdigit() for s in chars_copy):
            return [text]
        ########## end change ##############
        else:
            while i < len(chars):
                char = chars[i]
                if _is_punctuation(char):
                    output.append([char])
                    start_new_word = True
                else:
                    if start_new_word:
                        output.append([])
                        start_new_word = False
                    output[-1].append(char)
                i += 1

            return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
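
For quick sanity checking, here is a minimal usage sketch (not part of the uploaded file). It assumes a local WordPiece vocabulary file named vocab.txt (a hypothetical path; any BERT-style vocab should do) and that digit pieces such as "##2" and "##." are present in that vocabulary. It only illustrates the intended digit-level splitting of numbers; it is not the author's test code.

from NumBertTokenizer import NumBertTokenizer

# Hypothetical local vocabulary file used for illustration only.
tokenizer = NumBertTokenizer("vocab.txt")

# Ordinary words go through standard greedy WordPiece, while purely numeric
# tokens (optionally containing "." or ",") are emitted digit by digit,
# e.g. "12.5" -> ["1", "##2", "##.", "##5"]; pieces missing from the vocab fall back to [UNK].
print(tokenizer.tokenize("The device costs 12.5 million dollars."))

# Encoding and decoding behave like any other BERT-style tokenizer.
ids = tokenizer.encode("12.5", add_special_tokens=True)
print(tokenizer.convert_ids_to_tokens(ids))  # e.g. ['[CLS]', '1', '##2', '##.', '##5', '[SEP]']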