Source code for transformers.tokenization_bert_generation

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model BertGeneration."""

import os
from shutil import copyfile
from typing import List

from .tokenization_utils import PreTrainedTokenizer
from .utils import logging

# Module-level logger obtained from the package's own logging utilities.
logger = logging.get_logger(__name__)

# Maps the tokenizer's vocabulary-file argument name to the file name used on disk;
# ``save_vocabulary`` writes the SentencePiece model into the target directory under this name.
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}

tokenizer_url = (

class BertGenerationTokenizer(PreTrainedTokenizer):
    """
    Constructs a BertGeneration tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__ .

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods.
    Users should refer to the superclass for more information regarding methods.

    Args:
        vocab_file (:obj:`string`):
            `SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension)
            that contains the vocabulary necessary to instantiate a tokenizer.
        bos_token (:obj:`string`, `optional`, defaults to :obj:`"<s>"`):
            The begin of sequence token.
        eos_token (:obj:`string`, `optional`, defaults to :obj:`"</s>"`):
            The end of sequence token.
        unk_token (:obj:`string`, `optional`, defaults to :obj:`"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead.
        pad_token (:obj:`string`, `optional`, defaults to :obj:`"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        sep_token (:obj:`string`, `optional`, defaults to :obj:`"<::::>"`):
            The separator token; forwarded unchanged to the superclass constructor.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    prefix_tokens: List[int] = []

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        sep_token="<::::>",
        **kwargs
    ):
        # Register the special tokens with the base class before touching SentencePiece.
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            sep_token=sep_token,
            **kwargs,
        )

        spm = self._import_sentencepiece()

        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @staticmethod
    def _import_sentencepiece():
        """Import and return the ``sentencepiece`` module.

        Logs an installation hint and re-raises the original :obj:`ImportError` when the package is missing.
        Shared by ``__init__`` and ``__setstate__`` so both report the same message.
        """
        try:
            import sentencepiece as spm
        except ImportError:
            logger.warning(
                "You need to install SentencePiece to use BertGenerationTokenizer: "
                "https://github.com/google/sentencepiece "
                "pip install sentencepiece"
            )
            raise
        return spm

    @property
    def vocab_size(self):
        """Number of pieces in the loaded SentencePiece model (added tokens are not counted)."""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Return a token -> id dict covering the SentencePiece pieces plus any added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        """Return a copy of the instance state with ``sp_model`` set to ``None``.

        The processor is rebuilt from ``self.vocab_file`` in ``__setstate__``.
        """
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        """Restore the instance state and reload the SentencePiece model from ``self.vocab_file``."""
        self.__dict__ = d
        spm = self._import_sentencepiece()
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text, sample=False):
        """Take as input a string and return a list of strings (tokens) for words/sub-words.

        When ``sample`` is true, ``SampleEncodeAsPieces`` is called with the fixed arguments ``64`` and ``0.1``
        instead of the plain ``EncodeAsPieces``.
        """
        if not sample:
            pieces = self.sp_model.EncodeAsPieces(text)
        else:
            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
        return pieces

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        out_string = self.sp_model.decode_pieces(tokens)
        return out_string

    def save_vocabulary(self, save_directory):
        """Save the sentencepiece vocabulary (copy original file) to a directory.

        Args:
            save_directory (:obj:`str`): Existing directory that receives the vocabulary file.

        Returns:
            A one-element tuple with the path of the saved file, or ``None`` (after logging an error) when
            ``save_directory`` is not a directory.
        """
        if not os.path.isdir(save_directory):
            logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
            return
        out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])

        # Skip the copy when the model file already lives at the destination path.
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)