import string

from tokenizers import (
    Tokenizer as HFTokenizer,
    normalizers,
    pre_tokenizers,
    models,
    decoders,
)
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast


class LingLongTokenizerFast(PreTrainedTokenizerFast):
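    """WordPiece-based fast tokenizer for LingLong models.

    When no prebuilt ``tokenizer.json`` is given, the backend tokenizer is
    assembled from a plain-text vocabulary file with (optional) BERT-style
    normalization, digit/punctuation/whitespace pre-tokenization, and a custom
    decoder that avoids inserting spaces between Chinese characters.

    Example (illustrative):

        tokenizer = LingLongTokenizerFast(vocab_file='tokenizer.txt')
        ids = tokenizer('你好,世界!')['input_ids']
        text = tokenizer.decode(ids)
    """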
    vocab_files_names = {'vocab_file': 'tokenizer.txt', 'tokenizer_file': 'tokenizer.json'}
    model_input_names = ['input_ids', 'attention_mask']

    class CustomDecoder:
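        """Decoder for the ``tokenizers`` backend: strips the WordPiece ``##``
        continuation prefix and joins tokens, omitting the space before tokens
        that do not start with an ASCII letter (e.g. Chinese characters)."""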

        @staticmethod
        def decode_chain(tokens: list[str]) -> list[str]:
            new_tokens = []
            for token in tokens:
                if token.startswith('##'):
                    new_tokens.append(token[2:])
                else:
                    new_tokens.append(' ' + token)

            # Remove whitespaces between Chinese characters.
            # TODO: This will remove whitespaces between some English words as well. Need fix.
            alphabet_set = set(string.ascii_letters)
            for i, token in enumerate(new_tokens):
                # Drop the leading space for the first token and for tokens that
                # do not start with an ASCII letter (e.g. Chinese characters).
                if token.startswith(' ') and (i == 0 or len(token) < 2 or token[1] not in alphabet_set):
                    new_tokens[i] = token[1:]
            return new_tokens

    def __init__(
            self,
            vocab_file: str | None = None,
            tokenizer_file: str | None = None,
            do_lower_case: bool = True,
            do_basic_tokenize: bool = True,
            unk_token: str = '<unk>',
            sep_token: str = '<sep>',
            pad_token: str = '<pad>',
            cls_token: str = '<cls>',
            mask_token: str = '<mask>',
            bos_token: str = '<|startoftext|>',
            eos_token: str = '<|endoftext|>',
            tokenize_chinese_chars: bool = True,
            strip_accents: bool | None = None,
            **kwargs,
    ):
        backend_tokenizer = None
        if tokenizer_file is None:
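            # No prebuilt tokenizer.json was provided: build the backend from
            # the plain-text WordPiece vocabulary instead.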
            backend_tokenizer = HFTokenizer(
                models.WordPiece.from_file(
                    vocab=vocab_file,
                    unk_token=unk_token,
                    max_input_chars_per_word=100,
                ),
            )
            backend_tokenizer.add_special_tokens(
                [unk_token, sep_token, pad_token, cls_token, mask_token, bos_token, eos_token],
            )
            normalizer_sequence = [normalizers.Replace('\n', sep_token)]
            if do_basic_tokenize:
                normalizer_sequence.append(
                    normalizers.BertNormalizer(
                        handle_chinese_chars=tokenize_chinese_chars,
                        strip_accents=strip_accents,
                        lowercase=do_lower_case,
                    ),
                )
            backend_tokenizer.normalizer = normalizers.Sequence(normalizer_sequence)
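            # Split digits into individual tokens, then split on punctuation and
            # whitespace before the WordPiece model is applied.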
            backend_tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
                pre_tokenizers.Digits(individual_digits=True),
                pre_tokenizers.Punctuation(),
                pre_tokenizers.WhitespaceSplit(),
            ])
        super().__init__(
            tokenizer_file=tokenizer_file,
            tokenizer_object=backend_tokenizer,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            bos_token=bos_token,
            eos_token=eos_token,
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )
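        # Attach the custom decoder (it cannot be serialized; see save_pretrained),
        # reserve <unused1>..<unused10> as additional special tokens, and set a
        # simple question-answer chat template.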
        self._tokenizer.decoder = decoders.Decoder.custom(self.CustomDecoder())
        self.add_special_tokens({'additional_special_tokens': [f'<unused{i}>' for i in range(1, 11)]})
        self.chat_template = '{{ bos_token }}{{ "问题:" }}{{ messages[-1]["content"] }}{{ "<unused1>答案:" }}'

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str, ...]:
        files = self.backend_tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    def save_pretrained(self, *args, **kwargs) -> tuple[str, ...]:
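        # A custom Python decoder cannot be serialized into tokenizer.json, so
        # switch to the standard WordPiece decoder before saving. Note that this
        # instance keeps the WordPiece decoder afterwards.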
        self._tokenizer.decoder = decoders.WordPiece()
        return super().save_pretrained(*args, **kwargs)