mzltest committed
Commit be5a7dc
1 Parent(s): 01254e1

Create tokenization_bert_word_level.py

tokenizations/tokenization_bert_word_level.py ADDED
@@ -0,0 +1,453 @@
+ # coding=utf-8
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Tokenization classes."""
+
+ from __future__ import absolute_import, division, print_function, unicode_literals
+
+ import collections
+ import logging
+ import os
+ import unicodedata
+ import thulac
+ from io import open
+
+ from transformers.tokenization_utils import PreTrainedTokenizer
+
+ logger = logging.getLogger(__name__)
+
+ lac = thulac.thulac(user_dict='tokenizations/thulac_dict/seg', seg_only=True)
+
+ VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}
+
+ PRETRAINED_VOCAB_FILES_MAP = {
+     'vocab_file':
+     {
+         'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
+         'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
+         'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
+         'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
+         'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
+         'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
+         'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+         'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
+         'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
+         'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
+         'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
+         'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
+         'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
+     }
+ }
+
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+     'bert-base-uncased': 512,
+     'bert-large-uncased': 512,
+     'bert-base-cased': 512,
+     'bert-large-cased': 512,
+     'bert-base-multilingual-uncased': 512,
+     'bert-base-multilingual-cased': 512,
+     'bert-base-chinese': 512,
+     'bert-base-german-cased': 512,
+     'bert-large-uncased-whole-word-masking': 512,
+     'bert-large-cased-whole-word-masking': 512,
+     'bert-large-uncased-whole-word-masking-finetuned-squad': 512,
+     'bert-large-cased-whole-word-masking-finetuned-squad': 512,
+     'bert-base-cased-finetuned-mrpc': 512,
+ }
+
+ def load_vocab(vocab_file):
+     """Loads a vocabulary file into a dictionary."""
+     vocab = collections.OrderedDict()
+     with open(vocab_file, "r", encoding="utf-8") as reader:
+         tokens = reader.readlines()
+     for index, token in enumerate(tokens):
+         token = token.rstrip('\n')
+         vocab[token] = index
+     return vocab
+
+
+ def whitespace_tokenize(text):
+     """Runs basic whitespace cleaning and splitting on a piece of text."""
+     text = text.strip()
+     if not text:
+         return []
+     tokens = text.split()
+     return tokens
+
+
+ class BertTokenizer(PreTrainedTokenizer):
+     r"""
+     Constructs a BertTokenizer.
+     :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece
+
+     Args:
+         vocab_file: Path to a one-wordpiece-per-line vocabulary file
+         do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
+         do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+         max_len: An artificial maximum length to truncate tokenized sequences to; the effective maximum length is always the
+             minimum of this value (if specified) and the underlying BERT model's sequence length.
+         never_split: List of tokens which will never be split during tokenization. Only has an effect when
+             do_wordpiece_only=False
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+     def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None,
+                  unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
+                  mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs):
+         """Constructs a BertTokenizer.
+
+         Args:
+             **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
+             **do_lower_case**: (`optional`) boolean (default True)
+                 Whether to lower case the input
+                 Only has an effect when do_basic_tokenize=True
+             **do_basic_tokenize**: (`optional`) boolean (default True)
+                 Whether to do basic tokenization before wordpiece.
+             **never_split**: (`optional`) list of string
+                 List of tokens which will never be split during tokenization.
+                 Only has an effect when do_basic_tokenize=True
+             **tokenize_chinese_chars**: (`optional`) boolean (default True)
+                 Whether to tokenize Chinese characters.
+                 This should likely be deactivated for Japanese:
+                 see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
+         """
+         super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
+                                             pad_token=pad_token, cls_token=cls_token,
+                                             mask_token=mask_token, **kwargs)
+         if not os.path.isfile(vocab_file):
+             raise ValueError(
+                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
+                 "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
+         self.vocab = load_vocab(vocab_file)
+         self.ids_to_tokens = collections.OrderedDict(
+             [(ids, tok) for tok, ids in self.vocab.items()])
+         self.do_basic_tokenize = do_basic_tokenize
+         if do_basic_tokenize:
+             self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+                                                   never_split=never_split,
+                                                   tokenize_chinese_chars=tokenize_chinese_chars)
+         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
+
+     @property
+     def vocab_size(self):
+         return len(self.vocab)
+
+     def _tokenize(self, text):
+         split_tokens = []
+         if self.do_basic_tokenize:
+             for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+                 for sub_token in self.wordpiece_tokenizer.tokenize(token):
+                     split_tokens.append(sub_token)
+         else:
+             split_tokens = self.wordpiece_tokenizer.tokenize(text)
+         return split_tokens
+
+     def _convert_token_to_id(self, token):
+         """ Converts a token (str/unicode) to an id using the vocab. """
+         return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) to a token (string/unicode) using the vocab."""
+         return self.ids_to_tokens.get(index, self.unk_token)
+
+     def convert_tokens_to_string(self, tokens):
+         """ Converts a sequence of tokens (string) to a single string. """
+         out_string = ' '.join(tokens).replace(' ##', '').strip()
+         return out_string
+
+     def save_vocabulary(self, vocab_path):
+         """Save the tokenizer vocabulary to a directory or file."""
+         index = 0
+         if os.path.isdir(vocab_path):
+             vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                 if index != token_index:
+                     logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
+                                    " Please check that the vocabulary is not corrupted!".format(vocab_file))
+                     index = token_index
+                 writer.write(token + u'\n')
+                 index += 1
+         return (vocab_file,)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
+         """ Instantiate a BertTokenizer from pre-trained vocabulary files.
+         """
+         if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES:
+             if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True):
+                 logger.warning("The pre-trained model you are loading is a cased model but you have not set "
+                                "`do_lower_case` to False. We are setting `do_lower_case=False` for you but "
+                                "you may want to check this behavior.")
+                 kwargs['do_lower_case'] = False
+             elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True):
+                 logger.warning("The pre-trained model you are loading is an uncased model but you have set "
+                                "`do_lower_case` to False. We are setting `do_lower_case=True` for you "
+                                "but you may want to check this behavior.")
+                 kwargs['do_lower_case'] = True
+
+         return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
+
+
+ class BasicTokenizer(object):
+     """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+     def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
+         """ Constructs a BasicTokenizer.
+
+         Args:
+             **do_lower_case**: Whether to lower case the input.
+             **never_split**: (`optional`) list of str
+                 Kept for backward compatibility purposes.
+                 Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
+                 List of tokens not to split.
+             **tokenize_chinese_chars**: (`optional`) boolean (default True)
+                 Whether to tokenize Chinese characters.
+                 This should likely be deactivated for Japanese:
+                 see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
+         """
+         if never_split is None:
+             never_split = []
+         self.do_lower_case = do_lower_case
+         self.never_split = never_split
+         self.tokenize_chinese_chars = tokenize_chinese_chars
+
+     def tokenize(self, text, never_split=None):
+         """ Basic Tokenization of a piece of text.
+         Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
+
+         Args:
+             **never_split**: (`optional`) list of str
+                 Kept for backward compatibility purposes.
+                 Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
+                 List of tokens not to split.
+         """
+         never_split = self.never_split + (never_split if never_split is not None else [])
+         text = self._clean_text(text)
+         # This was added on November 1st, 2018 for the multilingual and Chinese
+         # models. This is also applied to the English models now, but it doesn't
+         # matter since the English models were not trained on any Chinese data
+         # and generally don't have any Chinese data in them (there are Chinese
+         # characters in the vocabulary because Wikipedia does have some Chinese
+         # words in the English Wikipedia.).
+         if self.tokenize_chinese_chars:
+             text = self._tokenize_chinese_chars(text)
+         orig_tokens = whitespace_tokenize(text)
+         split_tokens = []
+         for token in orig_tokens:
+             if self.do_lower_case and token not in never_split:
+                 token = token.lower()
+                 token = self._run_strip_accents(token)
+             split_tokens.extend(self._run_split_on_punc(token))
+
+         output_tokens = whitespace_tokenize(" ".join(split_tokens))
+         return output_tokens
+
+     def _run_strip_accents(self, text):
+         """Strips accents from a piece of text."""
+         text = unicodedata.normalize("NFD", text)
+         output = []
+         for char in text:
+             cat = unicodedata.category(char)
+             if cat == "Mn":
+                 continue
+             output.append(char)
+         return "".join(output)
+
+     def _run_split_on_punc(self, text, never_split=None):
+         """Splits punctuation on a piece of text."""
+         if never_split is not None and text in never_split:
+             return [text]
+         chars = list(text)
+         i = 0
+         start_new_word = True
+         output = []
+         while i < len(chars):
+             char = chars[i]
+             if _is_punctuation(char):
+                 output.append([char])
+                 start_new_word = True
+             else:
+                 if start_new_word:
+                     output.append([])
+                 start_new_word = False
+                 output[-1].append(char)
+             i += 1
+
+         return ["".join(x) for x in output]
+
+     # def _tokenize_chinese_chars(self, text):
+     #     """Adds whitespace around any CJK character."""
+     #     output = []
+     #     for char in text:
+     #         cp = ord(char)
+     #         if self._is_chinese_char(cp) or char.isdigit():
+     #             output.append(" ")
+     #             output.append(char)
+     #             output.append(" ")
+     #         else:
+     #             output.append(char)
+     #     return "".join(output)
+     def _tokenize_chinese_chars(self, text):
+         """Adds whitespace around digits and segments the text into words with thulac."""
+         output = []
+         for char in text:
+             if char.isdigit():
+                 output.append(" ")
+                 output.append(char)
+                 output.append(" ")
+             else:
+                 output.append(char)
+         text = "".join(output)
+         text = [item[0].strip() for item in lac.cut(text)]
+         text = [item for item in text if item]
+         return " ".join(text)
+
+     def _is_chinese_char(self, cp):
+         """Checks whether CP is the codepoint of a CJK character."""
+         # This defines a "chinese character" as anything in the CJK Unicode block:
+         #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+         #
+         # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+         # despite its name. The modern Korean Hangul alphabet is a different block,
+         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+         # space-separated words, so they are not treated specially and handled
+         # like all of the other languages.
+         if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
+                 (cp >= 0x3400 and cp <= 0x4DBF) or  #
+                 (cp >= 0x20000 and cp <= 0x2A6DF) or  #
+                 (cp >= 0x2A700 and cp <= 0x2B73F) or  #
+                 (cp >= 0x2B740 and cp <= 0x2B81F) or  #
+                 (cp >= 0x2B820 and cp <= 0x2CEAF) or
+                 (cp >= 0xF900 and cp <= 0xFAFF) or  #
+                 (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
+             return True
+
+         return False
+
+     def _clean_text(self, text):
+         """Performs invalid character removal and whitespace cleanup on text."""
+         output = []
+         for char in text:
+             cp = ord(char)
+             if cp == 0 or cp == 0xfffd or _is_control(char):
+                 continue
+             if _is_whitespace(char):
+                 output.append(" ")
+             else:
+                 output.append(char)
+         return "".join(output)
+
+
+ class WordpieceTokenizer(object):
+     """Runs WordPiece tokenization."""
+
+     def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+         self.vocab = vocab
+         self.unk_token = unk_token
+         self.max_input_chars_per_word = max_input_chars_per_word
+
+     def tokenize(self, text):
+         """Tokenizes a piece of text into its word pieces.
+
+         This uses a greedy longest-match-first algorithm to perform tokenization
+         using the given vocabulary.
+
+         For example:
+             input = "unaffable"
+             output = ["un", "##aff", "##able"]
+
+         Args:
+             text: A single token or whitespace separated tokens. This should have
+                 already been passed through `BasicTokenizer`.
+
+         Returns:
+             A list of wordpiece tokens.
+         """
+
+         output_tokens = []
+         for token in whitespace_tokenize(text):
+             chars = list(token)
+             if len(chars) > self.max_input_chars_per_word:
+                 output_tokens.append(self.unk_token)
+                 continue
+
+             is_bad = False
+             start = 0
+             sub_tokens = []
+             while start < len(chars):
+                 end = len(chars)
+                 cur_substr = None
+                 while start < end:
+                     substr = "".join(chars[start:end])
+                     if start > 0:
+                         substr = "##" + substr
+                     if substr in self.vocab:
+                         cur_substr = substr
+                         break
+                     end -= 1
+                 if cur_substr is None:
+                     is_bad = True
+                     break
+                 sub_tokens.append(cur_substr)
+                 start = end
+
+             if is_bad:
+                 output_tokens.append(self.unk_token)
+             else:
+                 output_tokens.extend(sub_tokens)
+         return output_tokens
+
+
+ def _is_whitespace(char):
+     """Checks whether `chars` is a whitespace character."""
+     # \t, \n, and \r are technically control characters but we treat them
+     # as whitespace since they are generally considered as such.
+     if char == " " or char == "\t" or char == "\n" or char == "\r":
+         return True
+     cat = unicodedata.category(char)
+     if cat == "Zs":
+         return True
+     return False
+
+
+ def _is_control(char):
+     """Checks whether `chars` is a control character."""
+     # These are technically control characters but we count them as whitespace
+     # characters.
+     if char == "\t" or char == "\n" or char == "\r":
+         return False
+     cat = unicodedata.category(char)
+     if cat.startswith("C"):
+         return True
+     return False
+
+
+ def _is_punctuation(char):
+     """Checks whether `chars` is a punctuation character."""
+     cp = ord(char)
+     # We treat all non-letter/number ASCII as punctuation.
+     # Characters such as "^", "$", and "`" are not in the Unicode
+     # Punctuation class but we treat them as punctuation anyways, for
+     # consistency.
+     if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
+             (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+         return True
+     cat = unicodedata.category(char)
+     if cat.startswith("P"):
+         return True
+     return False
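
For reference, a minimal usage sketch (not part of the commit) shows how the word-level tokenizer added above might be driven end to end. The vocab path is hypothetical, and the snippet assumes thulac and transformers are installed and that it is run from the repository root so the thulac user dictionary loaded at module import ('tokenizations/thulac_dict/seg') resolves.

# Hypothetical usage sketch; the vocab path below is an assumption, not part of this commit.
from tokenizations.tokenization_bert_word_level import BertTokenizer

# Build the tokenizer from a word-level vocabulary file (one token per line).
tokenizer = BertTokenizer(vocab_file='cache/vocab_wordlevel.txt')

# BasicTokenizer segments the Chinese text into words with thulac,
# then WordpieceTokenizer matches each word against the vocabulary.
text = "今天天气很好，我们去公园散步。"
tokens = tokenizer.tokenize(text)
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)
print(ids)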