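"""Patch tiktoken.Encoding with an hf_tokenizer-compatible interface.

Importing this module monkey-patches tiktoken.Encoding in place: decode becomes
fault-tolerant, encode accepts add_special_tokens, and convert_ids_to_tokens /
get_vocab are added.
"""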

from tiktoken import Encoding
from utils.log_util import logger

def decode(self, tokens, errors="replace", skip_special_tokens=False):
    r"""
    The default decode may raise errors; see decode_test.py for details.
    skip_special_tokens is accepted (and ignored) for compatibility with hf_tokenizer.

    errors:
        decoded bytes are not guaranteed to be valid UTF-8.
        "strict"             Raise UnicodeError
        "ignore"             Ignore and continue
        "replace"            Replace with replacement character
        "backslashreplace"   Replace with backslashed escape sequence
        "xmlcharrefreplace"  Replace with XML character reference
        "namereplace"        Replace with \N{...} (named unicode character)
    """
    try:
        decode_str = self._core_bpe.decode_bytes(tokens).decode("utf-8", errors=errors)
    except Exception as e:
        logger.error(f"{e} -> return 'null'")
        decode_str = "null"
    return decode_str


def convert_ids_to_tokens(self, tokens, skip_special_tokens=False):
    """
    tiktoken does not provide this method; it is added here for compatibility
    with hf_tokenizer. skip_special_tokens is accepted (and ignored) for the same reason.
    """
    try:
        return self.decode_tokens_bytes(tokens)
    except Exception as e:
        # Why return None? See zh_util.py.
        # 16 unused ids: 100256, 100261-100275
        logger.error(e)
        return [None for _ in tokens]


def get_vocab(self, token_type="str"):
    """Returns the vocab as a dict
    :param token_type: one of ["str", "byte"]
    :return:
    """
    vocab = {}
    key_error_list = []
    unicode_decode_error_list = []
    for i in range(self.vocab_size):
        try:
            token_byte = self.convert_ids_to_tokens([i])[0]
            if token_byte is None:  # unused id
                continue
            token = token_byte.decode("utf-8") if token_type == "str" else token_byte
            vocab[token] = i

        except UnicodeDecodeError:  # 773 UnicodeDecodeError
            unicode_decode_error_list.append((i, str(token_byte)))
            vocab[token_byte] = i  # fall back to the raw bytes as the key

    # vocab.update(self.added_tokens_encoder)
    logger.info(f"{self.name} {len(key_error_list)} KeyError: {key_error_list}")
    logger.info(f"{self.name} {len(unicode_decode_error_list)} UnicodeDecodeError: {unicode_decode_error_list[:5]}")
    return vocab


def encode(self, *args, **kwargs):
    """
    add_special_tokens is accepted (and dropped) for compatibility with hf_tokenizer.
    """
    kwargs.pop("add_special_tokens", None)
    kwargs["allowed_special"] = "all"
    return self._encode(*args, **kwargs)


# tiktoken patch: attach the hf_tokenizer-compatible methods to Encoding.
# Keep a handle on the original encode; the hasattr guard prevents a double patch,
# which would make _encode point at the patched encode and recurse forever.
if not hasattr(Encoding, "_encode"):
    Encoding._encode = Encoding.encode
Encoding.encode = encode
Encoding.decode = decode
Encoding.convert_ids_to_tokens = convert_ids_to_tokens
Encoding.get_vocab = get_vocab
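

# Minimal usage sketch; assumes tiktoken is installed and the patch above has been
# applied by importing this module. "cl100k_base" is just an example encoding, not
# something this module requires.
if __name__ == "__main__":
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    ids = enc.encode("hello world", add_special_tokens=False)  # hf-style kwarg is dropped
    print(ids, "->", enc.decode(ids))

    # If the emoji spans several tokens, the first token alone is not valid UTF-8:
    # errors="replace" (the default) yields a replacement character, while
    # errors="strict" raises internally, logs the error, and returns "null".
    broken = enc.encode("🤖")[:1]
    print(enc.decode(broken))
    print(enc.decode(broken, errors="strict"))

    print(enc.convert_ids_to_tokens(ids))         # per-token bytes
    print(len(enc.get_vocab(token_type="byte")))  # vocab keyed by raw bytes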