import os
from typing import List, Optional, Union

import tensorflow as tf
from tensorflow_text import BertTokenizer as BertTokenizerLayer
from tensorflow_text import FastBertTokenizer, ShrinkLongestTrimmer, case_fold_utf8, combine_segments, pad_model_inputs

from .tokenization_bert import BertTokenizer


class TFBertTokenizer(tf.keras.layers.Layer):
    """
    This is an in-graph tokenizer for BERT. It should be initialized similarly to other tokenizers, using the
    `from_pretrained()` method. It can also be initialized with the `from_tokenizer()` method, which imports settings
    from an existing standard tokenizer object.

    In-graph tokenizers, unlike other Hugging Face tokenizers, are actually Keras layers and are designed to be run
    when the model is called, rather than during preprocessing. As a result, they have somewhat more limited options
    than standard tokenizer classes. They are most useful when you want to create an end-to-end model that goes
    straight from `tf.string` inputs to outputs.

    Args:
        vocab_list (`list`):
            List containing the vocabulary. The position of a token in this list is used as its token id.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        cls_token_id (`int`, *optional*):
            The id of the classifier token, which is used when doing sequence classification (classification of the
            whole sequence instead of per-token classification). It is the first token of the sequence when built
            with special tokens. If not given, the index of `"[CLS]"` in `vocab_list` is used.
        sep_token_id (`int`, *optional*):
            The id of the separator token, which is used when building a sequence from multiple sequences, e.g. two
            sequences for sequence classification or for a text and a question for question answering. It is also
            used as the last token of a sequence built with special tokens. If not given, the index of `"[SEP]"` in
            `vocab_list` is used.
        pad_token_id (`int`, *optional*):
            The id of the token used for padding, for example when batching sequences of different lengths. If not
            given, the index of `"[PAD]"` in `vocab_list` is used.
        padding (`str`, defaults to `"longest"`):
            The type of padding to use. Can be either `"longest"`, to pad only up to the longest sample in the batch,
            or `"max_length"`, to pad all inputs to `max_length`.
        truncation (`bool`, *optional*, defaults to `True`):
            Whether to truncate the sequence to the maximum length.
        max_length (`int`, *optional*, defaults to `512`):
            The maximum length of the sequence, used for padding (if `padding` is `"max_length"`) and/or truncation
            (if `truncation` is `True`).
        pad_to_multiple_of (`int`, *optional*, defaults to `None`):
            If set, the sequence will be padded to a multiple of this value. Only applies when `padding` is
            `"longest"`.
        return_token_type_ids (`bool`, *optional*, defaults to `True`):
            Whether to return token_type_ids.
        return_attention_mask (`bool`, *optional*, defaults to `True`):
            Whether to return the attention_mask.
        use_fast_bert_tokenizer (`bool`, *optional*, defaults to `True`):
            If True, will use the FastBertTokenizer class from Tensorflow Text. If False, will use the BertTokenizer
            class instead. BertTokenizer supports some additional options, but is slower and cannot be exported to
            TFLite.
        tokenizer_kwargs:
            Extra keyword arguments forwarded unchanged to the underlying Tensorflow Text tokenizer.
    """

    def __init__(
        self,
        vocab_list: List,
        do_lower_case: bool = True,
        cls_token_id: Optional[int] = None,
        sep_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        padding: str = "longest",
        truncation: bool = True,
        max_length: int = 512,
        pad_to_multiple_of: Optional[int] = None,
        return_token_type_ids: bool = True,
        return_attention_mask: bool = True,
        use_fast_bert_tokenizer: bool = True,
        **tokenizer_kwargs,
    ):
        super().__init__()
        if use_fast_bert_tokenizer:
            self.tf_tokenizer = FastBertTokenizer(
                vocab_list, token_out_type=tf.int64, lower_case_nfd_strip_accents=do_lower_case, **tokenizer_kwargs
            )
        else:
            # The non-fast tokenizer needs an explicit token -> id table; ids are the positions in vocab_list,
            # with one extra bucket for out-of-vocabulary tokens.
            lookup_table = tf.lookup.StaticVocabularyTable(
                tf.lookup.KeyValueTensorInitializer(
                    keys=vocab_list,
                    key_dtype=tf.string,
                    values=tf.range(tf.size(vocab_list, out_type=tf.int64), dtype=tf.int64),
                    value_dtype=tf.int64,
                ),
                num_oov_buckets=1,
            )
            self.tf_tokenizer = BertTokenizerLayer(
                lookup_table, token_out_type=tf.int64, lower_case=do_lower_case, **tokenizer_kwargs
            )

        self.vocab_list = vocab_list
        self.do_lower_case = do_lower_case
        # Compare against None rather than relying on truthiness: 0 is a valid (and common) token id, and an
        # explicitly passed 0 must not be replaced by a vocab lookup that may fail.
        self.cls_token_id = vocab_list.index("[CLS]") if cls_token_id is None else cls_token_id
        self.sep_token_id = vocab_list.index("[SEP]") if sep_token_id is None else sep_token_id
        self.pad_token_id = vocab_list.index("[PAD]") if pad_token_id is None else pad_token_id
        # Paired inputs get [CLS] A [SEP] B [SEP], so leave room for 3 special tokens.
        self.paired_trimmer = ShrinkLongestTrimmer(max_length - 3, axis=1)
        self.max_length = max_length
        self.padding = padding
        self.truncation = truncation
        self.pad_to_multiple_of = pad_to_multiple_of
        self.return_token_type_ids = return_token_type_ids
        self.return_attention_mask = return_attention_mask
        # Kept so get_config() can faithfully reproduce the layer.
        self.use_fast_bert_tokenizer = use_fast_bert_tokenizer

    @classmethod
    def from_tokenizer(cls, tokenizer: "PreTrainedTokenizerBase", **kwargs):  # noqa: F821
        """
        Initialize a `TFBertTokenizer` from an existing `Tokenizer`.

        Args:
            tokenizer (`PreTrainedTokenizerBase`):
                The tokenizer to use to initialize the `TFBertTokenizer`. `do_lower_case` and the CLS/SEP/PAD token
                ids are taken from it unless overridden in `kwargs`.

        Examples:

        ```python
        from transformers import AutoTokenizer, TFBertTokenizer

        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        tf_tokenizer = TFBertTokenizer.from_tokenizer(tokenizer)
        ```
        """
        do_lower_case = kwargs.pop("do_lower_case", None)
        do_lower_case = tokenizer.do_lower_case if do_lower_case is None else do_lower_case
        cls_token_id = kwargs.pop("cls_token_id", None)
        cls_token_id = tokenizer.cls_token_id if cls_token_id is None else cls_token_id
        sep_token_id = kwargs.pop("sep_token_id", None)
        sep_token_id = tokenizer.sep_token_id if sep_token_id is None else sep_token_id
        pad_token_id = kwargs.pop("pad_token_id", None)
        pad_token_id = tokenizer.pad_token_id if pad_token_id is None else pad_token_id

        # Order the vocabulary by id so that list position == token id.
        vocab = tokenizer.get_vocab()
        vocab = sorted(vocab.items(), key=lambda x: x[1])
        vocab_list = [entry[0] for entry in vocab]
        return cls(
            vocab_list=vocab_list,
            do_lower_case=do_lower_case,
            cls_token_id=cls_token_id,
            sep_token_id=sep_token_id,
            pad_token_id=pad_token_id,
            **kwargs,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], *init_inputs, **kwargs):
        """
        Instantiate a `TFBertTokenizer` from a pre-trained tokenizer.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                The name or path to the pre-trained tokenizer.

        Examples:

        ```python
        from transformers import TFBertTokenizer

        tf_tokenizer = TFBertTokenizer.from_pretrained("bert-base-uncased")
        ```
        """
        # NOTE(review): `kwargs` is forwarded both to the tokenizer loader and to `from_tokenizer` (and from there
        # to `__init__`); loader-only options may therefore end up in `tokenizer_kwargs`. Confirm intended usage.
        try:
            tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        except Exception:
            # Fall back to the fast tokenizer (e.g. when only a tokenizer.json is available). Catch Exception rather
            # than using a bare except so KeyboardInterrupt/SystemExit still propagate.
            from .tokenization_bert_fast import BertTokenizerFast

            tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, *init_inputs, **kwargs)
        return cls.from_tokenizer(tokenizer, **kwargs)

    def unpaired_tokenize(self, texts):
        """
        Tokenize a batch of strings into one flat sequence of token ids per input.

        When `do_lower_case` is set, the text is case-folded first. The underlying tokenizer's output is merged
        from axis 1 onward (`merge_dims(1, -1)`), so the word and wordpiece dimensions collapse into a single
        ragged token dimension.
        """
        if self.do_lower_case:
            texts = case_fold_utf8(texts)
        tokens = self.tf_tokenizer.tokenize(texts)
        return tokens.merge_dims(1, -1)

    def call(
        self,
        text,
        text_pair=None,
        padding=None,
        truncation=None,
        max_length=None,
        pad_to_multiple_of=None,
        return_token_type_ids=None,
        return_attention_mask=None,
    ):
        """
        Tokenize `text` (and optionally `text_pair`) into padded model inputs.

        Any argument left as `None` falls back to the value configured at construction time. `text` may also be
        a 2-D tensor whose two columns are interpreted as (text, text_pair) when `text_pair` is not given.

        Returns:
            `dict` with `"input_ids"`, plus `"attention_mask"` and `"token_type_ids"` when requested.

        Raises:
            ValueError: if `padding` is not `"longest"` or `"max_length"`, if `max_length` is overridden for paired
                input, or if either input is multidimensional when `text_pair` is supplied.
        """
        if padding is None:
            padding = self.padding
        if padding not in ("longest", "max_length"):
            raise ValueError("Padding must be either 'longest' or 'max_length'!")
        if max_length is not None and text_pair is not None:
            # Because we have to instantiate a Trimmer to do it properly
            raise ValueError("max_length cannot be overridden at call time when truncating paired texts!")
        if max_length is None:
            max_length = self.max_length
        if truncation is None:
            truncation = self.truncation
        if pad_to_multiple_of is None:
            pad_to_multiple_of = self.pad_to_multiple_of
        if return_token_type_ids is None:
            return_token_type_ids = self.return_token_type_ids
        if return_attention_mask is None:
            return_attention_mask = self.return_attention_mask
        if not isinstance(text, tf.Tensor):
            text = tf.convert_to_tensor(text)
        if text_pair is not None and not isinstance(text_pair, tf.Tensor):
            text_pair = tf.convert_to_tensor(text_pair)
        if text_pair is not None:
            if text.shape.rank > 1:
                raise ValueError("text argument should not be multidimensional when a text pair is supplied!")
            if text_pair.shape.rank > 1:
                raise ValueError("text_pair should not be multidimensional!")
        if text.shape.rank == 2:
            # A 2-column input is treated as (text, text_pair) pairs.
            text, text_pair = text[:, 0], text[:, 1]
        text = self.unpaired_tokenize(text)
        if text_pair is None:  # Unpaired text
            if truncation:
                text = text[:, : max_length - 2]  # Allow room for special tokens
            input_ids, token_type_ids = combine_segments(
                (text,), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id
            )
        else:  # Paired text
            text_pair = self.unpaired_tokenize(text_pair)
            if truncation:
                text, text_pair = self.paired_trimmer.trim([text, text_pair])
            input_ids, token_type_ids = combine_segments(
                (text, text_pair), start_of_sequence_id=self.cls_token_id, end_of_segment_id=self.sep_token_id
            )
        if padding == "longest":
            pad_length = input_ids.bounding_shape(axis=1)
            if pad_to_multiple_of is not None:
                # No ceiling division in tensorflow, so we negate floordiv instead
                pad_length = pad_to_multiple_of * (-tf.math.floordiv(-pad_length, pad_to_multiple_of))
        else:
            pad_length = max_length

        input_ids, attention_mask = pad_model_inputs(input_ids, max_seq_length=pad_length, pad_value=self.pad_token_id)
        output = {"input_ids": input_ids}
        if return_attention_mask:
            output["attention_mask"] = attention_mask
        if return_token_type_ids:
            # Segment ids are padded with 0 (the first segment), not with the pad *token* id; the two only coincide
            # when [PAD] happens to be id 0.
            token_type_ids, _ = pad_model_inputs(token_type_ids, max_seq_length=pad_length, pad_value=0)
            output["token_type_ids"] = token_type_ids
        return output

    def get_config(self):
        """
        Return the constructor arguments needed to rebuild this layer (`cls(**config)`).

        All explicit `__init__` options are included so a layer restored from its config keeps the same padding,
        truncation and output settings. Extra `tokenizer_kwargs` are not recorded.
        """
        return {
            "vocab_list": self.vocab_list,
            "do_lower_case": self.do_lower_case,
            "cls_token_id": self.cls_token_id,
            "sep_token_id": self.sep_token_id,
            "pad_token_id": self.pad_token_id,
            "padding": self.padding,
            "truncation": self.truncation,
            "max_length": self.max_length,
            "pad_to_multiple_of": self.pad_to_multiple_of,
            "return_token_type_ids": self.return_token_type_ids,
            "return_attention_mask": self.return_attention_mask,
            "use_fast_bert_tokenizer": self.use_fast_bert_tokenizer,
        }