# Source code for transformers.tokenization_bert_generation

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
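# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.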
# limitations under the License.
""" Tokenization class for model BertGeneration."""

import os
from shutil import copyfile
from typing import List, Optional, Tuple

import sentencepiece as spm

from .tokenization_utils import PreTrainedTokenizer
from .utils import logging

# Module-level logger; `save_vocabulary` uses it to report an invalid save directory.
logger = logging.get_logger(__name__)

# The key matches the `vocab_file` argument of the tokenizer's `__init__`; the value is
# the file name that `save_vocabulary` writes the SentencePiece model to.
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

# Location of the SentencePiece model for the "bert_for_seq_generation" checkpoint, as
# referenced by the tokenizer's `pretrained_vocab_files_map`.
# NOTE(review): this is an empty string here, which looks like the URL was stripped when
# this file was extracted — confirm the intended value before relying on it.
tokenizer_url = ""

class BertGenerationTokenizer(PreTrainedTokenizer):
    """
    Construct a BertGeneration tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`str`):
            Path to the `SentencePiece <https://github.com/google/sentencepiece>`__ model file that contains the
            vocabulary necessary to instantiate a tokenizer. :meth:`save_vocabulary` writes this file back out
            under the name ``spiece.model``.
        bos_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
            The beginning of sequence token.
        eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
            The end of sequence token.
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"<::::>"`):
            The separator token.
        kwargs:
            Any other keyword arguments are forwarded unchanged to :class:`~transformers.PreTrainedTokenizer`.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = {"vocab_file": {"bert_for_seq_generation": tokenizer_url}}
    max_model_input_sizes = {"bert_for_seq_generation": 512}
    # NOTE(review): not read anywhere in this class; kept for interface compatibility. This is a
    # class-level list shared by every instance, so it must not be mutated in place.
    prefix_tokens: List[int] = []

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        sep_token="<::::>",
        **kwargs
    ):
        # Register the special tokens with the base class; extra keyword arguments
        # are passed through untouched.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            sep_token=sep_token,
            **kwargs,
        )

        # Keep the path: __setstate__ reloads the SentencePiece model from it after unpickling.
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @property
    def vocab_size(self):
        """Number of pieces in the SentencePiece model (added tokens are not counted)."""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Return a ``{token: id}`` dict of the SentencePiece vocabulary plus any added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        # Exclude the SentencePiece processor from the pickled state; __setstate__ rebuilds it.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        # Restore attributes, then reload the processor from `vocab_file`. The file must
        # therefore still exist at that path when the object is unpickled.
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text, sample=False):
        """
        Take as input a string and return a list of strings (tokens) for words/sub-words.

        Args:
            text (:obj:`str`): The text to split into pieces.
            sample (:obj:`bool`, `optional`, defaults to :obj:`False`):
                If true, use SentencePiece subword sampling (``SampleEncodeAsPieces`` with
                ``nbest_size=64`` and ``alpha=0.1``) instead of the deterministic segmentation,
                so repeated calls may return different pieces.
        """
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
        out_string = self.sp_model.decode_pieces(tokens)
        return out_string

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Copy the SentencePiece model file into ``save_directory``.

        Args:
            save_directory (:obj:`str`):
                Existing directory in which to save the vocabulary file.
            filename_prefix (:obj:`str`, `optional`):
                If given, ``"<prefix>-"`` is prepended to the saved file name.

        Returns:
            :obj:`Tuple(str)`: A one-element tuple with the path of the saved file. If ``save_directory`` is not
            a directory, an error is logged and :obj:`None` is returned instead.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        # Skip the copy when saving back to the file we loaded from (copyfile would
        # fail on identical source and destination).
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)