import os
from typing import List, Union

import tensorflow as tf
from tensorflow_text import BertTokenizer as BertTokenizerLayer
from tensorflow_text import FastBertTokenizer, ShrinkLongestTrimmer, case_fold_utf8, combine_segments, pad_model_inputs

from .tokenization_bert import BertTokenizer
class TFBertTokenizer(tf.keras.layers.Layer):
    """
    This is an in-graph tokenizer for BERT. It should be initialized similarly to other tokenizers, using the
    `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
    from an existing standard tokenizer object.

    In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
    when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
    than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
    straight from `tf.string` inputs to outputs.

    Args:
        vocab_list (`list`):
            List containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        cls_token_id (`int`, *optional*):
            The id of the classifier token, which is used when doing sequence classification (classification of the
            whole sequence instead of per-token classification). It is the first token of the sequence when built with
            special tokens. Defaults to the id of `"[CLS]"` in the vocabulary.
        sep_token_id (`int`, *optional*):
            The id of the separator token, which is used when building a sequence from multiple sequences, e.g. two
            sequences for sequence classification or for a text and a question for question answering. It is also used
            as the last token of a sequence built with special tokens. Defaults to the id of `"[SEP]"` in the
            vocabulary.
        pad_token_id (`int`, *optional*):
            The id of the token used for padding, for example when batching sequences of different lengths. Defaults
            to the id of `"[PAD]"` in the vocabulary.
        padding (`str`, defaults to `"longest"`):
            The type of padding to use. Can be either `"longest"`, to pad only up to the longest sample in the batch,
            or `"max_length"`, to pad all inputs to the maximum length supported by the tokenizer.
        truncation (`bool`, *optional*, defaults to `True`):
            Whether to truncate the sequence to the maximum length.
        max_length (`int`, *optional*, defaults to `512`):
            The maximum length of the sequence, used for padding (if `padding` is `"max_length"`) and/or truncation
            (if `truncation` is `True`).
        pad_to_multiple_of (`int`, *optional*, defaults to `None`):
            If set, the sequence will be padded to a multiple of this value.
        return_token_type_ids (`bool`, *optional*, defaults to `True`):
            Whether to return token_type_ids.
        return_attention_mask (`bool`, *optional*, defaults to `True`):
            Whether to return the attention_mask.
        use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`):
            If `True`, will use the `FastBertTokenizer` class from TensorFlow Text. If `False`, will use the
            `BertTokenizer` class instead. `BertTokenizer` supports some additional options, but is slower and cannot
            be exported to TFLite.
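
    Examples:

    A minimal end-to-end sketch (illustrative; it assumes the `bert-base-uncased` checkpoint can be downloaded):

    ```python
    import tensorflow as tf
    from transformers import TFAutoModel, TFBertTokenizer

    tokenizer = TFBertTokenizer.from_pretrained("bert-base-uncased")
    model = TFAutoModel.from_pretrained("bert-base-uncased")

    # Tokenization happens inside the graph, so the pipeline accepts raw strings directly
    test_inputs = tf.constant(["This is a test sentence.", "And this is another one!"])
    outputs = model(tokenizer(test_inputs))
    ```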
""" | |

    def __init__(
        self,
        vocab_list: List,
        do_lower_case: bool,
        cls_token_id: int = None,
        sep_token_id: int = None,
        pad_token_id: int = None,
        padding: str = "longest",
        truncation: bool = True,
        max_length: int = 512,
        pad_to_multiple_of: int = None,
        return_token_type_ids: bool = True,
        return_attention_mask: bool = True,
        use_fast_bert_tokenizer: bool = True,
        **tokenizer_kwargs,
    ):
        super().__init__()
        if use_fast_bert_tokenizer:
            self.tf_tokenizer = FastBertTokenizer(
                vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
            )
        else:
            lookup_table = tf.lookup.StaticVocabularyTable(
                tf.lookup.KeyValueTensorInitializer(
                    keys=vocab_list,
                    key_dtype=tf.string,
                    values=tf.range(tf.size(vocab_list, out_type=tf.int64), dtype=tf.int64),
                    value_dtype=tf.int64,
                ),
                num_oov_buckets=1,
            )
            self.tf_tokenizer = BertTokenizerLayer(
                lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
            )
        self.vocab_list = vocab_list
        self.do_lower_case = do_lower_case
        self.cls_token_id = cls_token_id or vocab_list.index("[CLS]")
        self.sep_token_id = sep_token_id or vocab_list.index("[SEP]")
        self.pad_token_id = pad_token_id or vocab_list.index("[PAD]")
        self.paired_trimmer = ShrinkLongestTrimmer(max_length - 3, axis=1)  # Allow room for special tokens
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation
        self.pad_to_multiple_of = pad_to_multiple_of
        self.return_token_type_ids = return_token_type_ids
        self.return_attention_mask = return_attention_mask

    @classmethod
    def from_tokenizer(cls, tokenizer: "PreTrainedTokenizerBase", **kwargs):  # noqa: F821
        """
        Initialize a `TFBertTokenizer` from an existing `Tokenizer`.

        Args:
            tokenizer (`PreTrainedTokenizerBase`):
                The tokenizer to use to initialize the `TFBertTokenizer`.

        Examples:

        ```python
        from transformers import AutoTokenizer, TFBertTokenizer

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer)
        ```
        """
        do_lower_case = kwargs.pop("do_lower_case", None)
        do_lower_case = tokenizer.do_lower_case if do_lower_case is None else do_lower_case
        cls_token_id = kwargs.pop("cls_token_id", None)
        cls_token_id = tokenizer.cls_token_id if cls_token_id is None else cls_token_id
        sep_token_id = kwargs.pop("sep_token_id", None)
        sep_token_id = tokenizer.sep_token_id if sep_token_id is None else sep_token_id
        pad_token_id = kwargs.pop("pad_token_id", None)
        pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id
        vocab = tokenizer.get_vocab()
        vocab = sorted(vocab.items(), key=lambda x: x[1])
        vocab_list = [entry[0] for entry in vocab]
        return cls(
            vocab_list=vocab_list,
            do_lower_case=do_lower_case,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
        """
        Instantiate a `TFBertTokenizer` from a pre-trained tokenizer.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name or path to the pre-trained tokenizer.

        Examples:

        ```python
        from transformers import TFBertTokenizer

        tf_tokenizer = TFBertTokenizer.from_pretrained("bert-base-uncased")
        ```
        """
        try:
            tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        except:  # noqa: E722
            from .tokenization_bert_fast import BertTokenizerFast

            tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        return cls.from_tokenizer(tokenizer, **kwargs)

    def unpaired_tokenize(self, texts):
        if self.do_lower_case:
            texts = case_fold_utf8(texts)
        tokens = self.tf_tokenizer.tokenize(texts)
        # Flatten any per-word wordpiece dimension so each text becomes a single sequence of token ids
        return tokens.merge_dims(1, -1)

    def call(
        self,
        text,
        text_pair=None,
        padding=None,
        truncation=None,
        max_length=None,
        pad_to_multiple_of=None,
        return_token_type_ids=None,
        return_attention_mask=None,
    ):
        if padding is None:
            padding = self.padding
        if padding not in ("longest", "max_length"):
            raise ValueError("Padding must be either 'longest' or 'max_length'!")
        if max_length is not None and text_pair is not None:
            # Because we have to instantiate a Trimmer to do it properly
            raise ValueError("max_length cannot be overridden at call time when truncating paired texts!")
        if max_length is None:
            max_length = self.max_length
        if truncation is None:
            truncation = self.truncation
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.pad_to_multiple_of
        if return_token_type_ids is None:
            return_token_type_ids = self.return_token_type_ids
        if return_attention_mask is None:
            return_attention_mask = self.return_attention_mask
        if not isinstance(text, tf.Tensor):
            text = tf.convert_to_tensor(text)
        if text_pair is not None and not isinstance(text_pair, tf.Tensor):
            text_pair = tf.convert_to_tensor(text_pair)
        if text_pair is not None:
            if text.shape.rank > 1:
                raise ValueError("text argument should not be multidimensional when a text pair is supplied!")
            if text_pair.shape.rank > 1:
                raise ValueError("text_pair should not be multidimensional!")
        if text.shape.rank == 2:
            text, text_pair = text[:, 0], text[:, 1]
        text = self.unpaired_tokenize(text)
        if text_pair is None:  # Unpaired text
            if truncation:
                text = text[:, : max_length - 2]  # Allow room for special tokens
            input_ids, token_type_ids = combine_segments(
                (text,), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id
            )
        else:  # Paired text
            text_pair = self.unpaired_tokenize(text_pair)
            if truncation:
                text, text_pair = self.paired_trimmer.trim([text, text_pair])
            input_ids, token_type_ids = combine_segments(
                (text, text_pair), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id
            )
        if padding == "longest":
            pad_length = input_ids.bounding_shape(axis=1)
            if pad_to_multiple_of is not None:
                # No ceiling division in tensorflow, so we negate floordiv instead
                # (e.g. pad_length 13 with pad_to_multiple_of 8 gives 8 * -(-13 // 8) = 16)
                pad_length = pad_to_multiple_of * (-tf.math.floordiv(-pad_length, pad_to_multiple_of))
        else:
            pad_length = max_length
        input_ids, attention_mask = pad_model_inputs(input_ids, max_seq_length=pad_length, pad_value=self.pad_token_id)
        output = {"input_ids": input_ids}
        if return_attention_mask:
            output["attention_mask"] = attention_mask
        if return_token_type_ids:
            token_type_ids, _ = pad_model_inputs(
                token_type_ids, max_seq_length=pad_length, pad_value=self.pad_token_id
            )
            output["token_type_ids"] = token_type_ids
        return output

    def get_config(self):
        return {
            "vocab_list": self.vocab_list,
            "do_lower_case": self.do_lower_case,
            "cls_token_id": self.cls_token_id,
            "sep_token_id": self.sep_token_id,
            "pad_token_id": self.pad_token_id,
        }
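

# Illustrative sketch, not part of the original module: `get_config()` is what lets Keras serialize this layer,
# so a `TFBertTokenizer` can be rebuilt from its config dictionary alone. The tiny vocabulary below is made up
# for demonstration; real checkpoints supply a full WordPiece vocabulary via `from_pretrained()` or
# `from_tokenizer()`. Note that `get_config()` captures only the vocabulary and special token ids, not padding
# or truncation settings.
if __name__ == "__main__":
    toy_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "hello", "world"]
    original = TFBertTokenizer(vocab_list=toy_vocab, do_lower_case=True, max_length=8)
    rebuilt = TFBertTokenizer(**original.get_config())
    # The rebuilt layer shares the vocabulary and special token ids with the original
    print(rebuilt(tf.constant(["hello world"])))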