"""

## adapt to transformers tokenizer

Monkey-patch sentencepiece.SentencePieceProcessor so it exposes the interface that transformers' slow tokenizers expect:

https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/tokenization_utils.py#L379

## usage

- grok (see the sketch below)
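
A minimal usage sketch (the patch-module name and model path are illustrative, not defined here):

```python
import sp_patch  # hypothetical name for this file; importing it applies the patch
import sentencepiece

sp = sentencepiece.SentencePieceProcessor(model_file="tokenizer.model")  # illustrative path
ids = sp.encode("hello world", add_special_tokens=False)  # hf-style kwargs are tolerated
tokens = sp.convert_ids_to_tokens(ids)
```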

## risk assessment

- This may interfere with normal use of sentencepiece.SentencePieceProcessor; for example, .vocab_size was originally a method and becomes a property after the patch (see the snippet below).
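
A short illustration of that change (model path is illustrative):

```python
import sentencepiece

sp = sentencepiece.SentencePieceProcessor(model_file="tokenizer.model")

# before importing this module: vocab_size is a method
#   sp.vocab_size()   -> int
# after importing this module: vocab_size is a property
#   sp.vocab_size     -> int
#   sp.vocab_size()   -> TypeError: 'int' object is not callable
```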


## TODO

Use a wrapper instead of a patch: common tokenizers usually wrap sentencepiece rather than patch it. A rough sketch of the wrapper approach is below.
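
The sketch (class and method signatures are only a suggestion, not an implementation in this repo):

```python
import sentencepiece


class SpTokenizerWrapper:
    # Wrap a SentencePieceProcessor instead of monkey-patching the class.
    def __init__(self, model_file):
        self.sp_model = sentencepiece.SentencePieceProcessor(model_file=model_file)

    @property
    def vocab_size(self):
        return self.sp_model.get_piece_size()

    def encode(self, text, add_special_tokens=False, **kwargs):
        return self.sp_model.Encode(text)

    def decode(self, ids, skip_special_tokens=False, **kwargs):
        return self.sp_model.Decode(ids)

    def convert_ids_to_tokens(self, ids):
        if isinstance(ids, int):
            return self.sp_model.IdToPiece(ids)
        return [self.sp_model.IdToPiece(int(i)) for i in ids]
```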
"""

import sentencepiece


@property
def vocab_size(self):
    """Returns vocab size"""
    return self.get_piece_size()


def get_vocab(self):
    """Returns vocab as a dict"""
    vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
    # vocab.update(self.added_tokens_encoder)
    return vocab


def _tokenize(self, text):
    """Returns a tokenized string."""
    return self.encode(text, out_type=str)


def _convert_token_to_id(self, token):
    """Converts a token (str) in an id using the vocab."""
    return self.piece_to_id(token)


def _convert_id_to_token(self, index):
    """Converts an index (integer) in a token (str) using the vocab."""
    token = self.IdToPiece(index)
    return token


def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
    """ copy from transformers.PreTrainedTokenizer
    Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
    added tokens.

    Args:
        ids (`int` or `List[int]`):
            The token id (or token ids) to convert to tokens.
        skip_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to remove special tokens in the decoding.

    Returns:
        `str` or `List[str]`: The decoded token(s).
    """
    if not hasattr(self, "_added_tokens_decoder"):
        # a raw SentencePieceProcessor has no added-tokens table, so provide an empty one
        self._added_tokens_decoder = {}
    if isinstance(ids, int):
        if ids in self._added_tokens_decoder:
            return self._added_tokens_decoder[ids].content
        else:
            return self._convert_id_to_token(ids)
    tokens = []
    for index in ids:
        index = int(index)
        # a raw SentencePieceProcessor has no all_special_ids, so fall back to an empty tuple
        if skip_special_tokens and index in getattr(self, "all_special_ids", ()):
            continue
        if index in self._added_tokens_decoder:
            tokens.append(self._added_tokens_decoder[index].content)
        else:
            tokens.append(self._convert_id_to_token(index))
    return tokens


def encode(self, *args, **kwargs):
    """
    add_special_token 是为了兼容 hf_tokenizer
    """
    kwargs.pop("add_special_tokens", None)
    kwargs.pop("allowed_special", None)
    return self.Encode(*args, **kwargs)


def decode(self, *args, **kwargs):
    """skip_special_tokens is popped for compatibility with the hf_tokenizer call signature."""
    kwargs.pop("skip_special_tokens", None)
    return self.Decode(*args, **kwargs)


sentencepiece.SentencePieceProcessor.vocab_size = vocab_size  # now a property, see the risk note above
sentencepiece.SentencePieceProcessor.get_vocab = get_vocab
sentencepiece.SentencePieceProcessor._convert_id_to_token = _convert_id_to_token
sentencepiece.SentencePieceProcessor.convert_ids_to_tokens = convert_ids_to_tokens
# sentencepiece.SentencePieceProcessor.tokenize = _tokenize
sentencepiece.SentencePieceProcessor.encode = encode
sentencepiece.SentencePieceProcessor.decode = decode
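

# Quick smoke test (a sketch, not part of the patch itself): assumes a local
# sentencepiece model at "tokenizer.model"; the path is illustrative.
if __name__ == "__main__":
    _sp = sentencepiece.SentencePieceProcessor(model_file="tokenizer.model")
    print(_sp.vocab_size)                                        # now a property
    _ids = _sp.encode("hello world", add_special_tokens=False)   # hf-style kwarg is ignored
    print(_ids)
    print(_sp.convert_ids_to_tokens(_ids))
    print(_sp.decode(_ids, skip_special_tokens=True))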