File size: 5,484 Bytes
ecca05a
 
 
 
 
 
0ad13f5
ecca05a
0ad13f5
ecca05a
0ad13f5
ecca05a
0ad13f5
 
 
ecca05a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ad13f5
ecca05a
 
0ad13f5
ecca05a
0ad13f5
ecca05a
0ad13f5
 
 
 
 
 
 
 
 
 
 
 
 
ecca05a
 
 
 
 
 
 
 
 
 
 
 
 
 
0ad13f5
 
 
 
 
 
 
 
ecca05a
0ad13f5
 
 
 
ecca05a
0ad13f5
ecca05a
0ad13f5
ecca05a
 
 
0ad13f5
 
ecca05a
0ad13f5
 
ecca05a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104

"""
Forked from the file src/transformers/models/bert_generation/tokenization_bert_generation.py from the HuggingFace Transformers library.
Permalink: https://github.com/huggingface/transformers/blob/04ab5605fbb4ef207b10bf2772d88c53fc242e83/src/transformers/models/bert_generation/tokenization_bert_generation.py

"""
import os
import sentencepiece as spm
from shutil import copyfile
from transformers import PreTrainedTokenizer
from typing import Any, Dict, List, Optional, Tuple
# Maps the `vocab_file` constructor argument to the file name used when the
# tokenizer is saved with `save_pretrained` / `save_vocabulary`.
VOCAB_FILES_NAMES = dict(vocab_file='tokenizer.model')

class BNTokenizer(PreTrainedTokenizer):
    """
      Construct a BNTokenizer tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
      This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods.

      Args:
          vocab_file (`str`):
              Path to a [SentencePiece](https://github.com/google/sentencepiece) model file that contains the
              vocabulary necessary to instantiate a tokenizer. `save_vocabulary` writes it out under the name
              given by `VOCAB_FILES_NAMES['vocab_file']`.
          bos_token (`str`, *optional*, defaults to `None`):
              The begin of sequence token.
          eos_token (`str`, *optional*, defaults to `"</s>"`):
              The end of sequence token.
          unk_token (`str`, *optional*, defaults to `"<unk>"`):
              The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
              token instead.
          pad_token (`str`, *optional*, defaults to `"<|reserved001|>"`):
              The token used for padding, for example when batching sequences of different lengths.
          sep_token (`str`, *optional*, defaults to `None`):
              The separator token.
          sp_model_kwargs (`dict`, *optional*):
              Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
              SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
              to set:
              - `enable_sampling`: Enable subword regularization.
              - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
                - `nbest_size = {0,1}`: No sampling is performed.
                - `nbest_size > 1`: samples from the nbest_size results.
                - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
                  using forward-filtering-and-backward-sampling algorithm.
              - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
                BPE-dropout.
      """
    vocab_files_names = VOCAB_FILES_NAMES
    prefix_tokens: List[int] = []
    model_input_names = ['input_ids', 'attention_mask']

    def __init__(self, vocab_file, bos_token=None, eos_token='</s>', unk_token='<unk>', pad_token='<|reserved001|>', sep_token=None, sp_model_kwargs: Optional[Dict[str, Any]]=None, **kwargs) -> None:
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        # The SentencePiece model must be loaded *before* the base-class
        # constructor runs: newer `PreTrainedTokenizer.__init__` implementations
        # look up the vocabulary (via `vocab_size` / `get_vocab`, which read
        # `self.sp_model`) while registering special tokens, and would otherwise
        # fail with an AttributeError.
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, pad_token=pad_token, sep_token=sep_token, sp_model_kwargs=self.sp_model_kwargs, **kwargs)

    @property
    def vocab_size(self):
        """Number of pieces in the underlying SentencePiece model (added tokens excluded)."""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Return a token -> id mapping covering the SentencePiece pieces plus any added tokens."""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def __getstate__(self):
        # The SentencePieceProcessor is dropped from the pickled state;
        # `__setstate__` reloads it from `self.vocab_file`.
        state = self.__dict__.copy()
        state['sp_model'] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        # Backward compatibility with pickles created before `sp_model_kwargs` existed.
        if not hasattr(self, 'sp_model_kwargs'):
            self.sp_model_kwargs = {}
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_model.id_to_piece(index)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        return self.sp_model.decode(tokens)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str]=None) -> Tuple[str]:
        """
        Save the SentencePiece model file into `save_directory`.

        The original `vocab_file` is copied when it still exists on disk. Otherwise the model is
        re-serialized from the in-memory processor.

        Args:
            save_directory (`str`): Existing directory to write into.
            filename_prefix (`str`, *optional*): Prefix joined to the file name with a `-`.

        Returns:
            `Tuple[str]`: One-element tuple with the path of the saved file.

        Raises:
            ValueError: If `save_directory` is not a directory.
        """
        if not os.path.isdir(save_directory):
            raise ValueError(f'Vocabulary path ({save_directory}) should be a directory')
        out_vocab_file = os.path.join(save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'])
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # Source file is gone (e.g. loaded from a temp dir); dump the in-memory model instead.
            with open(out_vocab_file, 'wb') as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)