import string
from tokenizers import (
    Tokenizer as HFTokenizer,
    normalizers,
    pre_tokenizers,
    models,
    decoders,
)
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast


class LingLongTokenizerFast(PreTrainedTokenizerFast):
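    """Fast tokenizer for LingLong models, backed by a WordPiece model.

    Combines BERT-style normalization and pre-tokenization (Chinese characters,
    digits, and punctuation are split individually) with ``<|startoftext|>`` /
    ``<|endoftext|>`` special tokens, and decodes without inserting spaces
    between Chinese characters.
    """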
    vocab_files_names = {'vocab_file': 'tokenizer.txt', 'tokenizer_file': 'tokenizer.json'}
    model_input_names = ['input_ids', 'attention_mask']
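
    # Custom decoder for WordPiece output: '##' continuation pieces are glued back
    # onto the previous token, and the space normally inserted before a token is
    # dropped when the token does not start with an ASCII letter (e.g. between
    # Chinese characters) or when it is the first token.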
    class CustomDecoder:

        @staticmethod
        def decode_chain(tokens: list[str]) -> list[str]:
            new_tokens = []
            for token in tokens:
                if token.startswith('##'):
                    new_tokens.append(token[2:])
                else:
                    new_tokens.append(' ' + token)
            # Remove whitespaces between Chinese characters.
            # TODO: This will remove whitespaces between some English words as well. Need fix.
            alphabet_set = set(string.ascii_letters)
            for i in range(len(new_tokens)):
                if new_tokens[i][0] == ' ':
                    if new_tokens[i][1] not in alphabet_set or i == 0:
                        new_tokens[i] = new_tokens[i][1:]
            return new_tokens

    def __init__(
        self,
        vocab_file: str | None = None,
        tokenizer_file: str | None = None,
        do_lower_case: bool = True,
        do_basic_tokenize: bool = True,
        unk_token: str = '<unk>',
        sep_token: str = '<sep>',
        pad_token: str = '<pad>',
        cls_token: str = '<cls>',
        mask_token: str = '<mask>',
        bos_token: str = '<|startoftext|>',
        eos_token: str = '<|endoftext|>',
        tokenize_chinese_chars: bool = True,
        strip_accents: bool | None = None,
        **kwargs,
    ):
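        """Build a WordPiece backend from ``vocab_file`` when no serialized
        ``tokenizer_file`` is supplied; otherwise the parent class loads the
        serialized tokenizer directly.
        """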
        backend_tokenizer = None
        if tokenizer_file is None:
            backend_tokenizer = HFTokenizer(
                models.WordPiece.from_file(
                    vocab=vocab_file,
                    unk_token=unk_token,
                    max_input_chars_per_word=100,
                ),
            )
            backend_tokenizer.add_special_tokens(
                [unk_token, sep_token, pad_token, cls_token, mask_token, bos_token, eos_token],
            )
            normalizer_sequence = [normalizers.Replace('\n', sep_token)]
            if do_basic_tokenize:
                normalizer_sequence.append(
                    normalizers.BertNormalizer(
                        handle_chinese_chars=tokenize_chinese_chars,
                        strip_accents=strip_accents,
                        lowercase=do_lower_case,
                    ),
                )
            backend_tokenizer.normalizer = normalizers.Sequence(normalizer_sequence)
            backend_tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
                pre_tokenizers.Digits(individual_digits=True),
                pre_tokenizers.Punctuation(),
                pre_tokenizers.WhitespaceSplit(),
            ])
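        # Hand either the freshly built backend or the serialized tokenizer_file to
        # PreTrainedTokenizerFast, which registers the special tokens and settings.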
        super().__init__(
            tokenizer_file=tokenizer_file,
            tokenizer_object=backend_tokenizer,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            bos_token=bos_token,
            eos_token=eos_token,
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )
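        # Replace the default decoder with the whitespace-aware custom decoder above.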
        self._tokenizer.decoder = decoders.Decoder.custom(self.CustomDecoder())
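        # Reserve <unused1> through <unused10> as additional special tokens.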
        self.add_special_tokens({'additional_special_tokens': [f'<unused{i}>' for i in range(1, 11)]})
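        # Default chat template: '问题:' ("Question:") + last user message + '<unused1>答案:' ("Answer:").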
        self.chat_template = '{{ bos_token }}{{ "问题:" }}{{ messages[-1]["content"] }}{{ "<unused1>答案:" }}'

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
        files = self.backend_tokenizer.model.save(save_directory, name=filename_prefix)
        return tuple(files)

    def save_pretrained(self, *args, **kwargs) -> tuple[str]:
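        # The custom Python decoder cannot be serialized into tokenizer.json, so swap
        # in the stock WordPiece decoder before delegating to the parent implementation.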
        self._tokenizer.decoder = decoders.WordPiece()
        return super().save_pretrained(*args, **kwargs)
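

if __name__ == '__main__':
    # Minimal usage sketch (illustration only): 'tokenizer.txt' is a placeholder and
    # assumes a LingLong-style WordPiece vocabulary file is present on disk.
    tokenizer = LingLongTokenizerFast(vocab_file='tokenizer.txt')
    encoded = tokenizer('你好，世界！')
    print(encoded['input_ids'])
    print(tokenizer.decode(encoded['input_ids']))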