singletongue committed on
Commit
2601e53
1 Parent(s): 380f2db

Update tokenizer

Files changed (3)
  1. README.md +22 -1
  2. entity_vocab.json +2 -2
  3. tokenization_luke_bert_japanese.py +412 -44
README.md CHANGED
@@ -13,4 +13,25 @@ tags:
13
  - Pretrained on Japanese Wikipedia data as of July 1, 2023
14
  - Added support for the `[UNK]` (unknown) entity
15
 
16
- For details, please see the [blog post](https://tech.uzabase.com/entry/2023/09/07/172958).
13
  - Pretrained on Japanese Wikipedia data as of July 1, 2023
14
  - Added support for the `[UNK]` (unknown) entity
15
 
16
+ For details, please see the [blog post](https://tech.uzabase.com/entry/2023/09/07/172958).
17
+
18
+ ## Usage
19
+
20
+ ```python
21
+ from transformers import AutoTokenizer, AutoModel
22
+
23
+ # trust_remote_code=True is required in order to use the custom tokenizer code for this model
24
+ tokenizer = AutoTokenizer.from_pretrained("uzabase/luke-japanese-wordpiece-base", trust_remote_code=True)
25
+
26
+ model = AutoModel.from_pretrained("uzabase/luke-japanese-wordpiece-base")
27
+ ```
28
+
29
+ ## Update history
30
+
31
+ - **2023/11/28:** Made the following updates.
32
+   - Fixed an issue where the tokenizer could not be loaded with transformers v4.34.0 and later.
33
+   - Changed the tokenizer to include `position_ids` in its output.
34
+     - Previously, the `position_ids` [automatically generated](https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/luke/modeling_luke.py#L424) by the LUKE model were used, but these follow the RoBERTa convention and were not correct for this BERT-based model. The tokenizer therefore now includes `position_ids` explicitly in its output, so that the model receives the correct BERT-style values.
35
+   - Removed the `"None:"` prefix from each token in the tokenizer's `entity_vocab` (except special tokens such as `"[PAD]"`).
36
+     - For example, the token that used to be `"None:聖徳太子"` is now `"聖徳太子"`.
37
+ - **2023/09/07:** Released the model.
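
A minimal sketch of what the 2023/11/28 update notes above mean in practice, assuming the public model id from the README, network access to download it, and an illustrative input sentence that is not taken from the repository:

```python
from transformers import AutoTokenizer

# trust_remote_code=True is needed to load the custom tokenizer shipped with the model.
tokenizer = AutoTokenizer.from_pretrained(
    "uzabase/luke-japanese-wordpiece-base", trust_remote_code=True
)

encoding = tokenizer("聖徳太子は飛鳥時代の皇族です。")

# The tokenizer now emits BERT-style position_ids (0, 1, 2, ...) itself,
# instead of relying on the ids the LUKE model would create on its own.
assert "position_ids" in encoding
assert encoding["position_ids"] == list(range(len(encoding["input_ids"])))

# entity_vocab entries no longer carry the "None:" prefix.
assert "聖徳太子" in tokenizer.entity_vocab
assert "None:聖徳太子" not in tokenizer.entity_vocab
```

Emitting `position_ids` from the tokenizer sidesteps the RoBERTa-style position ids that `modeling_luke.py` would otherwise construct, which are offset by the padding index and so do not match what a BERT checkpoint expects.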
entity_vocab.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:be6327e7cafc2f2b5f694a594d57113fd2bf6b620c592929202f75683b18b67d
3
- size 23721849
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44b62a4236024bcfbc396e434fb137edecbb106e7f6bc36bc2465016d99d84dd
3
+ size 20763373
tokenization_luke_bert_japanese.py CHANGED
@@ -18,7 +18,7 @@ import collections
18
  import copy
19
  import json
20
  import os
21
- from typing import List, Optional, Tuple
22
 
23
  from transformers.models.bert_japanese.tokenization_bert_japanese import (
24
  BasicTokenizer,
@@ -31,7 +31,9 @@ from transformers.models.bert_japanese.tokenization_bert_japanese import (
31
  load_vocab,
32
  )
33
  from transformers.models.luke import LukeTokenizer
34
- from transformers.tokenization_utils_base import AddedToken
35
  from transformers.utils import logging
36
 
37
 
@@ -53,7 +55,7 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
53
  vocab_files_names = VOCAB_FILES_NAMES
54
  pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
55
  max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
56
- model_input_names = ["input_ids", "attention_mask"]
57
 
58
  def __init__(
59
  self,
@@ -85,35 +87,6 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
85
  jumanpp_kwargs=None,
86
  **kwargs,
87
  ):
88
- # We call the grandparent's init, not the parent's.
89
- super(LukeTokenizer, self).__init__(
90
- spm_file=spm_file,
91
- unk_token=unk_token,
92
- sep_token=sep_token,
93
- pad_token=pad_token,
94
- cls_token=cls_token,
95
- mask_token=mask_token,
96
- do_lower_case=do_lower_case,
97
- do_word_tokenize=do_word_tokenize,
98
- do_subword_tokenize=do_subword_tokenize,
99
- word_tokenizer_type=word_tokenizer_type,
100
- subword_tokenizer_type=subword_tokenizer_type,
101
- never_split=never_split,
102
- mecab_kwargs=mecab_kwargs,
103
- sudachi_kwargs=sudachi_kwargs,
104
- jumanpp_kwargs=jumanpp_kwargs,
105
- task=task,
106
- max_entity_length=32,
107
- max_mention_length=30,
108
- entity_token_1="<ent>",
109
- entity_token_2="<ent2>",
110
- entity_unk_token=entity_unk_token,
111
- entity_pad_token=entity_pad_token,
112
- entity_mask_token=entity_mask_token,
113
- entity_mask2_token=entity_mask2_token,
114
- **kwargs,
115
- )
116
-
117
  if subword_tokenizer_type == "sentencepiece":
118
  if not os.path.isfile(spm_file):
119
  raise ValueError(
@@ -161,11 +134,11 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
161
  self.subword_tokenizer_type = subword_tokenizer_type
162
  if do_subword_tokenize:
163
  if subword_tokenizer_type == "wordpiece":
164
- self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
165
  elif subword_tokenizer_type == "character":
166
- self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=self.unk_token)
167
  elif subword_tokenizer_type == "sentencepiece":
168
- self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=self.unk_token)
169
  else:
170
  raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
171
 
@@ -212,6 +185,35 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
212
 
213
  self.max_mention_length = max_mention_length
214
 
215
  @property
216
  # Copied from BertJapaneseTokenizer
217
  def do_lower_case(self):
@@ -298,16 +300,13 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
298
  """
299
  Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
300
  adding special tokens. A BERT sequence has the following format:
301
-
302
  - single sequence: `[CLS] X [SEP]`
303
  - pair of sequences: `[CLS] A [SEP] B [SEP]`
304
-
305
  Args:
306
  token_ids_0 (`List[int]`):
307
  List of IDs to which the special tokens will be added.
308
  token_ids_1 (`List[int]`, *optional*):
309
  Optional second list of IDs for sequence pairs.
310
-
311
  Returns:
312
  `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
313
  """
@@ -324,7 +323,6 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
324
  """
325
  Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
326
  special tokens using the tokenizer `prepare_for_model` method.
327
-
328
  Args:
329
  token_ids_0 (`List[int]`):
330
  List of IDs.
@@ -332,7 +330,6 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
332
  Optional second list of IDs for sequence pairs.
333
  already_has_special_tokens (`bool`, *optional*, defaults to `False`):
334
  Whether or not the token list is already formatted with special tokens for the model.
335
-
336
  Returns:
337
  `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
338
  """
@@ -353,20 +350,16 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
353
  """
354
  Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
355
  pair mask has the following format:
356
-
357
  ```
358
  0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
359
  | first sequence | second sequence |
360
  ```
361
-
362
  If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
363
-
364
  Args:
365
  token_ids_0 (`List[int]`):
366
  List of IDs.
367
  token_ids_1 (`List[int]`, *optional*):
368
  Optional second list of IDs for sequence pairs.
369
-
370
  Returns:
371
  `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
372
  """
@@ -376,9 +369,384 @@ class LukeBertJapaneseTokenizer(LukeTokenizer):
376
  return len(cls + token_ids_0 + sep) * [0]
377
  return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
378
 
 
379
  def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
380
  return (text, kwargs)
381
382
  def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
383
  if os.path.isdir(save_directory):
384
  if self.subword_tokenizer_type == "sentencepiece":
 
18
  import copy
19
  import json
20
  import os
21
+ from typing import Dict, List, Optional, Tuple, Union
22
 
23
  from transformers.models.bert_japanese.tokenization_bert_japanese import (
24
  BasicTokenizer,
 
31
  load_vocab,
32
  )
33
  from transformers.models.luke import LukeTokenizer
34
+ from transformers.tokenization_utils_base import (
35
+ AddedToken, BatchEncoding, EncodedInput, PaddingStrategy, TensorType, TruncationStrategy
36
+ )
37
  from transformers.utils import logging
38
 
39
 
 
55
  vocab_files_names = VOCAB_FILES_NAMES
56
  pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
57
  max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
58
+ model_input_names = ["input_ids", "attention_mask", "position_ids"]
59
 
60
  def __init__(
61
  self,
 
87
  jumanpp_kwargs=None,
88
  **kwargs,
89
  ):
90
  if subword_tokenizer_type == "sentencepiece":
91
  if not os.path.isfile(spm_file):
92
  raise ValueError(
 
134
  self.subword_tokenizer_type = subword_tokenizer_type
135
  if do_subword_tokenize:
136
  if subword_tokenizer_type == "wordpiece":
137
+ self.subword_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=str(unk_token))
138
  elif subword_tokenizer_type == "character":
139
+ self.subword_tokenizer = CharacterTokenizer(vocab=self.vocab, unk_token=str(unk_token))
140
  elif subword_tokenizer_type == "sentencepiece":
141
+ self.subword_tokenizer = SentencepieceTokenizer(vocab=self.spm_file, unk_token=str(unk_token))
142
  else:
143
  raise ValueError(f"Invalid subword_tokenizer_type '{subword_tokenizer_type}' is specified.")
144
 
 
185
 
186
  self.max_mention_length = max_mention_length
187
 
188
+ # We call the grandparent's init, not the parent's.
189
+ super(LukeTokenizer, self).__init__(
190
+ spm_file=spm_file,
191
+ unk_token=unk_token,
192
+ sep_token=sep_token,
193
+ pad_token=pad_token,
194
+ cls_token=cls_token,
195
+ mask_token=mask_token,
196
+ do_lower_case=do_lower_case,
197
+ do_word_tokenize=do_word_tokenize,
198
+ do_subword_tokenize=do_subword_tokenize,
199
+ word_tokenizer_type=word_tokenizer_type,
200
+ subword_tokenizer_type=subword_tokenizer_type,
201
+ never_split=never_split,
202
+ mecab_kwargs=mecab_kwargs,
203
+ sudachi_kwargs=sudachi_kwargs,
204
+ jumanpp_kwargs=jumanpp_kwargs,
205
+ task=task,
206
+ max_entity_length=32,
207
+ max_mention_length=30,
208
+ entity_token_1="<ent>",
209
+ entity_token_2="<ent2>",
210
+ entity_unk_token=entity_unk_token,
211
+ entity_pad_token=entity_pad_token,
212
+ entity_mask_token=entity_mask_token,
213
+ entity_mask2_token=entity_mask2_token,
214
+ **kwargs,
215
+ )
216
+
217
  @property
218
  # Copied from BertJapaneseTokenizer
219
  def do_lower_case(self):
 
300
  """
301
  Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
302
  adding special tokens. A BERT sequence has the following format:
 
303
  - single sequence: `[CLS] X [SEP]`
304
  - pair of sequences: `[CLS] A [SEP] B [SEP]`
 
305
  Args:
306
  token_ids_0 (`List[int]`):
307
  List of IDs to which the special tokens will be added.
308
  token_ids_1 (`List[int]`, *optional*):
309
  Optional second list of IDs for sequence pairs.
 
310
  Returns:
311
  `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
312
  """
 
323
  """
324
  Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
325
  special tokens using the tokenizer `prepare_for_model` method.
 
326
  Args:
327
  token_ids_0 (`List[int]`):
328
  List of IDs.
 
330
  Optional second list of IDs for sequence pairs.
331
  already_has_special_tokens (`bool`, *optional*, defaults to `False`):
332
  Whether or not the token list is already formatted with special tokens for the model.
 
333
  Returns:
334
  `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
335
  """
 
350
  """
351
  Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
352
  pair mask has the following format:
 
353
  ```
354
  0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
355
  | first sequence | second sequence |
356
  ```
 
357
  If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
 
358
  Args:
359
  token_ids_0 (`List[int]`):
360
  List of IDs.
361
  token_ids_1 (`List[int]`, *optional*):
362
  Optional second list of IDs for sequence pairs.
 
363
  Returns:
364
  `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
365
  """
 
369
  return len(cls + token_ids_0 + sep) * [0]
370
  return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
371
 
372
+ # Copied and modified from LukeTokenizer, removing the `add_prefix_space` process
373
  def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
374
  return (text, kwargs)
375
 
376
+ # Copied and modified from LukeTokenizer, adding `position_ids` to the output
377
+ def prepare_for_model(
378
+ self,
379
+ ids: List[int],
380
+ pair_ids: Optional[List[int]] = None,
381
+ entity_ids: Optional[List[int]] = None,
382
+ pair_entity_ids: Optional[List[int]] = None,
383
+ entity_token_spans: Optional[List[Tuple[int, int]]] = None,
384
+ pair_entity_token_spans: Optional[List[Tuple[int, int]]] = None,
385
+ add_special_tokens: bool = True,
386
+ padding: Union[bool, str, PaddingStrategy] = False,
387
+ truncation: Union[bool, str, TruncationStrategy] = None,
388
+ max_length: Optional[int] = None,
389
+ max_entity_length: Optional[int] = None,
390
+ stride: int = 0,
391
+ pad_to_multiple_of: Optional[int] = None,
392
+ return_tensors: Optional[Union[str, TensorType]] = None,
393
+ return_token_type_ids: Optional[bool] = None,
394
+ return_attention_mask: Optional[bool] = None,
395
+ return_overflowing_tokens: bool = False,
396
+ return_special_tokens_mask: bool = False,
397
+ return_offsets_mapping: bool = False,
398
+ return_length: bool = False,
399
+ verbose: bool = True,
400
+ prepend_batch_axis: bool = False,
401
+ **kwargs,
402
+ ) -> BatchEncoding:
403
+ """
404
+ Prepares a sequence of input id, entity id and entity span, or a pair of sequences of inputs ids, entity ids,
405
+ entity spans so that it can be used by the model. It adds special tokens, truncates sequences if overflowing
406
+ while taking into account the special tokens and manages a moving window (with user defined stride) for
407
+ overflowing tokens. Please Note, for *pair_ids* different than `None` and *truncation_strategy = longest_first*
408
+ or `True`, it is not possible to return overflowing tokens. Such a combination of arguments will raise an
409
+ error.
410
+
411
+ Args:
412
+ ids (`List[int]`):
413
+ Tokenized input ids of the first sequence.
414
+ pair_ids (`List[int]`, *optional*):
415
+ Tokenized input ids of the second sequence.
416
+ entity_ids (`List[int]`, *optional*):
417
+ Entity ids of the first sequence.
418
+ pair_entity_ids (`List[int]`, *optional*):
419
+ Entity ids of the second sequence.
420
+ entity_token_spans (`List[Tuple[int, int]]`, *optional*):
421
+ Entity spans of the first sequence.
422
+ pair_entity_token_spans (`List[Tuple[int, int]]`, *optional*):
423
+ Entity spans of the second sequence.
424
+ max_entity_length (`int`, *optional*):
425
+ The maximum length of the entity sequence.
426
+ """
427
+
428
+ # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
429
+ padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
430
+ padding=padding,
431
+ truncation=truncation,
432
+ max_length=max_length,
433
+ pad_to_multiple_of=pad_to_multiple_of,
434
+ verbose=verbose,
435
+ **kwargs,
436
+ )
437
+
438
+ # Compute lengths
439
+ pair = bool(pair_ids is not None)
440
+ len_ids = len(ids)
441
+ len_pair_ids = len(pair_ids) if pair else 0
442
+
443
+ if return_token_type_ids and not add_special_tokens:
444
+ raise ValueError(
445
+ "Asking to return token_type_ids while setting add_special_tokens to False "
446
+ "results in an undefined behavior. Please set add_special_tokens to True or "
447
+ "set return_token_type_ids to None."
448
+ )
449
+ if (
450
+ return_overflowing_tokens
451
+ and truncation_strategy == TruncationStrategy.LONGEST_FIRST
452
+ and pair_ids is not None
453
+ ):
454
+ raise ValueError(
455
+ "Not possible to return overflowing tokens for pair of sequences with the "
456
+ "`longest_first`. Please select another truncation strategy than `longest_first`, "
457
+ "for instance `only_second` or `only_first`."
458
+ )
459
+
460
+ # Load from model defaults
461
+ if return_token_type_ids is None:
462
+ return_token_type_ids = "token_type_ids" in self.model_input_names
463
+ if return_attention_mask is None:
464
+ return_attention_mask = "attention_mask" in self.model_input_names
465
+
466
+ encoded_inputs = {}
467
+
468
+ # Compute the total size of the returned word encodings
469
+ total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
470
+
471
+ # Truncation: Handle max sequence length and max_entity_length
472
+ overflowing_tokens = []
473
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
474
+ # truncate words up to max_length
475
+ ids, pair_ids, overflowing_tokens = self.truncate_sequences(
476
+ ids,
477
+ pair_ids=pair_ids,
478
+ num_tokens_to_remove=total_len - max_length,
479
+ truncation_strategy=truncation_strategy,
480
+ stride=stride,
481
+ )
482
+
483
+ if return_overflowing_tokens:
484
+ encoded_inputs["overflowing_tokens"] = overflowing_tokens
485
+ encoded_inputs["num_truncated_tokens"] = total_len - max_length
486
+
487
+ # Add special tokens
488
+ if add_special_tokens:
489
+ sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
490
+ token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
491
+ entity_token_offset = 1 # 1 * <s> token
492
+ pair_entity_token_offset = len(ids) + 3 # 1 * <s> token & 2 * <sep> tokens
493
+ else:
494
+ sequence = ids + pair_ids if pair else ids
495
+ token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
496
+ entity_token_offset = 0
497
+ pair_entity_token_offset = len(ids)
498
+
499
+ # Build output dictionary
500
+ encoded_inputs["input_ids"] = sequence
501
+ encoded_inputs["position_ids"] = list(range(len(sequence)))
502
+ if return_token_type_ids:
503
+ encoded_inputs["token_type_ids"] = token_type_ids
504
+ if return_special_tokens_mask:
505
+ if add_special_tokens:
506
+ encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
507
+ else:
508
+ encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
509
+
510
+ # Set max entity length
511
+ if not max_entity_length:
512
+ max_entity_length = self.max_entity_length
513
+
514
+ if entity_ids is not None:
515
+ total_entity_len = 0
516
+ num_invalid_entities = 0
517
+ valid_entity_ids = [ent_id for ent_id, span in zip(entity_ids, entity_token_spans) if span[1] <= len(ids)]
518
+ valid_entity_token_spans = [span for span in entity_token_spans if span[1] <= len(ids)]
519
+
520
+ total_entity_len += len(valid_entity_ids)
521
+ num_invalid_entities += len(entity_ids) - len(valid_entity_ids)
522
+
523
+ valid_pair_entity_ids, valid_pair_entity_token_spans = None, None
524
+ if pair_entity_ids is not None:
525
+ valid_pair_entity_ids = [
526
+ ent_id
527
+ for ent_id, span in zip(pair_entity_ids, pair_entity_token_spans)
528
+ if span[1] <= len(pair_ids)
529
+ ]
530
+ valid_pair_entity_token_spans = [span for span in pair_entity_token_spans if span[1] <= len(pair_ids)]
531
+ total_entity_len += len(valid_pair_entity_ids)
532
+ num_invalid_entities += len(pair_entity_ids) - len(valid_pair_entity_ids)
533
+
534
+ if num_invalid_entities != 0:
535
+ logger.warning(
536
+ f"{num_invalid_entities} entities are ignored because their entity spans are invalid due to the"
537
+ " truncation of input tokens"
538
+ )
539
+
540
+ if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and total_entity_len > max_entity_length:
541
+ # truncate entities up to max_entity_length
542
+ valid_entity_ids, valid_pair_entity_ids, overflowing_entities = self.truncate_sequences(
543
+ valid_entity_ids,
544
+ pair_ids=valid_pair_entity_ids,
545
+ num_tokens_to_remove=total_entity_len - max_entity_length,
546
+ truncation_strategy=truncation_strategy,
547
+ stride=stride,
548
+ )
549
+ valid_entity_token_spans = valid_entity_token_spans[: len(valid_entity_ids)]
550
+ if valid_pair_entity_token_spans is not None:
551
+ valid_pair_entity_token_spans = valid_pair_entity_token_spans[: len(valid_pair_entity_ids)]
552
+
553
+ if return_overflowing_tokens:
554
+ encoded_inputs["overflowing_entities"] = overflowing_entities
555
+ encoded_inputs["num_truncated_entities"] = total_entity_len - max_entity_length
556
+
557
+ final_entity_ids = valid_entity_ids + valid_pair_entity_ids if valid_pair_entity_ids else valid_entity_ids
558
+ encoded_inputs["entity_ids"] = list(final_entity_ids)
559
+ entity_position_ids = []
560
+ entity_start_positions = []
561
+ entity_end_positions = []
562
+ for token_spans, offset in (
563
+ (valid_entity_token_spans, entity_token_offset),
564
+ (valid_pair_entity_token_spans, pair_entity_token_offset),
565
+ ):
566
+ if token_spans is not None:
567
+ for start, end in token_spans:
568
+ start += offset
569
+ end += offset
570
+ position_ids = list(range(start, end))[: self.max_mention_length]
571
+ position_ids += [-1] * (self.max_mention_length - end + start)
572
+ entity_position_ids.append(position_ids)
573
+ entity_start_positions.append(start)
574
+ entity_end_positions.append(end - 1)
575
+
576
+ encoded_inputs["entity_position_ids"] = entity_position_ids
577
+ if self.task == "entity_span_classification":
578
+ encoded_inputs["entity_start_positions"] = entity_start_positions
579
+ encoded_inputs["entity_end_positions"] = entity_end_positions
580
+
581
+ if return_token_type_ids:
582
+ encoded_inputs["entity_token_type_ids"] = [0] * len(encoded_inputs["entity_ids"])
583
+
584
+ # Check lengths
585
+ self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
586
+
587
+ # Padding
588
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
589
+ encoded_inputs = self.pad(
590
+ encoded_inputs,
591
+ max_length=max_length,
592
+ max_entity_length=max_entity_length,
593
+ padding=padding_strategy.value,
594
+ pad_to_multiple_of=pad_to_multiple_of,
595
+ return_attention_mask=return_attention_mask,
596
+ )
597
+
598
+ if return_length:
599
+ encoded_inputs["length"] = len(encoded_inputs["input_ids"])
600
+
601
+ batch_outputs = BatchEncoding(
602
+ encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
603
+ )
604
+
605
+ return batch_outputs
606
+
607
+ # Copied and modified from LukeTokenizer, adding the padding of `position_ids`
608
+ def _pad(
609
+ self,
610
+ encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
611
+ max_length: Optional[int] = None,
612
+ max_entity_length: Optional[int] = None,
613
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
614
+ pad_to_multiple_of: Optional[int] = None,
615
+ return_attention_mask: Optional[bool] = None,
616
+ ) -> dict:
617
+ """
618
+ Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
619
+
620
+
621
+ Args:
622
+ encoded_inputs:
623
+ Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
624
+ max_length: maximum length of the returned list and optionally padding length (see below).
625
+ Will truncate by taking into account the special tokens.
626
+ max_entity_length: The maximum length of the entity sequence.
627
+ padding_strategy: PaddingStrategy to use for padding.
628
+
629
+
630
+ - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
631
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
632
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
633
+ The tokenizer padding sides are defined in self.padding_side:
634
+
635
+
636
+ - 'left': pads on the left of the sequences
637
+ - 'right': pads on the right of the sequences
638
+ pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
639
+ This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
640
+ `>= 7.5` (Volta).
641
+ return_attention_mask:
642
+ (optional) Set to False to avoid returning attention mask (default: set to model specifics)
643
+ """
644
+ entities_provided = bool("entity_ids" in encoded_inputs)
645
+
646
+ # Load from model defaults
647
+ if return_attention_mask is None:
648
+ return_attention_mask = "attention_mask" in self.model_input_names
649
+
650
+ if padding_strategy == PaddingStrategy.LONGEST:
651
+ max_length = len(encoded_inputs["input_ids"])
652
+ if entities_provided:
653
+ max_entity_length = len(encoded_inputs["entity_ids"])
654
+
655
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
656
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
657
+
658
+ if (
659
+ entities_provided
660
+ and max_entity_length is not None
661
+ and pad_to_multiple_of is not None
662
+ and (max_entity_length % pad_to_multiple_of != 0)
663
+ ):
664
+ max_entity_length = ((max_entity_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
665
+
666
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and (
667
+ len(encoded_inputs["input_ids"]) != max_length
668
+ or (entities_provided and len(encoded_inputs["entity_ids"]) != max_entity_length)
669
+ )
670
+
671
+ # Initialize attention mask if not present.
672
+ if return_attention_mask and "attention_mask" not in encoded_inputs:
673
+ encoded_inputs["attention_mask"] = [1] * len(encoded_inputs["input_ids"])
674
+ if entities_provided and return_attention_mask and "entity_attention_mask" not in encoded_inputs:
675
+ encoded_inputs["entity_attention_mask"] = [1] * len(encoded_inputs["entity_ids"])
676
+
677
+ if needs_to_be_padded:
678
+ difference = max_length - len(encoded_inputs["input_ids"])
679
+ if entities_provided:
680
+ entity_difference = max_entity_length - len(encoded_inputs["entity_ids"])
681
+ if self.padding_side == "right":
682
+ if return_attention_mask:
683
+ encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
684
+ if entities_provided:
685
+ encoded_inputs["entity_attention_mask"] = (
686
+ encoded_inputs["entity_attention_mask"] + [0] * entity_difference
687
+ )
688
+ if "token_type_ids" in encoded_inputs:
689
+ encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"] + [0] * difference
690
+ if entities_provided:
691
+ encoded_inputs["entity_token_type_ids"] = (
692
+ encoded_inputs["entity_token_type_ids"] + [0] * entity_difference
693
+ )
694
+ if "special_tokens_mask" in encoded_inputs:
695
+ encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
696
+ encoded_inputs["input_ids"] = encoded_inputs["input_ids"] + [self.pad_token_id] * difference
697
+ encoded_inputs["position_ids"] = encoded_inputs["position_ids"] + [0] * difference
698
+ if entities_provided:
699
+ encoded_inputs["entity_ids"] = (
700
+ encoded_inputs["entity_ids"] + [self.entity_pad_token_id] * entity_difference
701
+ )
702
+ encoded_inputs["entity_position_ids"] = (
703
+ encoded_inputs["entity_position_ids"] + [[-1] * self.max_mention_length] * entity_difference
704
+ )
705
+ if self.task == "entity_span_classification":
706
+ encoded_inputs["entity_start_positions"] = (
707
+ encoded_inputs["entity_start_positions"] + [0] * entity_difference
708
+ )
709
+ encoded_inputs["entity_end_positions"] = (
710
+ encoded_inputs["entity_end_positions"] + [0] * entity_difference
711
+ )
712
+
713
+ elif self.padding_side == "left":
714
+ if return_attention_mask:
715
+ encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
716
+ if entities_provided:
717
+ encoded_inputs["entity_attention_mask"] = [0] * entity_difference + encoded_inputs[
718
+ "entity_attention_mask"
719
+ ]
720
+ if "token_type_ids" in encoded_inputs:
721
+ encoded_inputs["token_type_ids"] = [0] * difference + encoded_inputs["token_type_ids"]
722
+ if entities_provided:
723
+ encoded_inputs["entity_token_type_ids"] = [0] * entity_difference + encoded_inputs[
724
+ "entity_token_type_ids"
725
+ ]
726
+ if "special_tokens_mask" in encoded_inputs:
727
+ encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
728
+ encoded_inputs["input_ids"] = [self.pad_token_id] * difference + encoded_inputs["input_ids"]
729
+ encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
730
+ if entities_provided:
731
+ encoded_inputs["entity_ids"] = [self.entity_pad_token_id] * entity_difference + encoded_inputs[
732
+ "entity_ids"
733
+ ]
734
+ encoded_inputs["entity_position_ids"] = [
735
+ [-1] * self.max_mention_length
736
+ ] * entity_difference + encoded_inputs["entity_position_ids"]
737
+ if self.task == "entity_span_classification":
738
+ encoded_inputs["entity_start_positions"] = [0] * entity_difference + encoded_inputs[
739
+ "entity_start_positions"
740
+ ]
741
+ encoded_inputs["entity_end_positions"] = [0] * entity_difference + encoded_inputs[
742
+ "entity_end_positions"
743
+ ]
744
+ else:
745
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
746
+
747
+ return encoded_inputs
748
+
749
+ # Copied and modified from BertJapaneseTokenizer and LukeTokenizer
750
  def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
751
  if os.path.isdir(save_directory):
752
  if self.subword_tokenizer_type == "sentencepiece":
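
As a rough end-to-end sketch of the `prepare_for_model` and `_pad` overrides added above, assuming the model id from the README, an illustrative sentence, and that the entity `聖徳太子` is present in the updated entity vocabulary (as the update notes suggest):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "uzabase/luke-japanese-wordpiece-base", trust_remote_code=True
)

text = "聖徳太子は推古天皇の摂政を務めた。"  # illustrative sentence
entity_spans = [(0, 4)]                       # character span of "聖徳太子" in `text`

enc = tokenizer(
    text,
    entities=["聖徳太子"],
    entity_spans=entity_spans,
    padding="max_length",
    max_length=32,
)

# The overridden _pad extends position_ids with 0s in step with the
# [PAD]-padded input_ids, so the word-level features stay aligned.
assert len(enc["input_ids"]) == len(enc["position_ids"]) == 32
assert len(enc["attention_mask"]) == 32
```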