# Source: 360Zhinao2-7B-Chat-360K / tokenization_zhinao.py (Hugging Face Hub file view)
# Uploaded by zhaicunqi: "Upload folder using huggingface_hub"
# Commit fa39d8a (verified); file size 9.14 kB
import os
import torch
import base64
import tiktoken
from typing import Collection, Optional, Dict, List, Set, Tuple, Union
from transformers import PreTrainedTokenizer
from transformers.utils import PaddingStrategy
from transformers.tokenization_utils import PreTrainedTokenizer
# Split pattern handed to tiktoken.Encoding as `pat_str`. Alternatives, in
# order: English contraction suffixes (case-insensitive); a run of letters with
# one optional leading character that is not a letter, digit, CR or LF; a
# single digit; a run of non-space/non-letter/non-digit characters with an
# optional leading space and trailing CR/LF; newline runs with leading
# whitespace; whitespace not followed by a non-space character; any remaining
# whitespace.
# NOTE(review): \p{L} / \p{N} are Unicode property classes that the stdlib `re`
# module does not accept — presumably tiktoken compiles this with its own
# regex engine; confirm before reusing the pattern elsewhere.
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
class SPTokenizer:
    """BPE tokenizer built on ``tiktoken`` with a fixed list of special tokens.

    The vocabulary is read from a tiktoken-style file (one
    ``<base64 token> <rank>`` pair per line). Special tokens are assigned the
    ids that directly follow the last entry count of that file, in the order
    they appear in ``SPECIAL_TOKENS``.

    Before encoding, runs of 2, 3, 4 or 8 spaces are rewritten to the literal
    tokens ``[space2]`` / ``[space3]`` / ``[space4]`` / ``[space8]``; after
    decoding they are expanded back to spaces.
    """

    def __init__(self, model_path):
        """Load the vocabulary at ``model_path`` and build all lookup tables.

        Args:
            model_path: Path of the tiktoken BPE vocabulary file.
        """
        self.vocab_file = model_path
        self.pad_token = '<pad>'
        self.unk_token = '<unk>'
        self.mask_token = '<mask>'
        self.eod_token = '<eod>'
        self.eop_token = '<eop>'
        self.im_start_token = '<|im_start|>'
        self.im_end_token = '<|im_end|>'

        ## special_tokens
        # Order matters: ids are assigned by position in this tuple.
        self.SPECIAL_TOKENS = (
            self.pad_token,
            self.unk_token,
            self.mask_token,
            self.eod_token,
            self.eop_token,
            '[space2]', '[space3]', '[space4]', '[space8]',
            self.im_start_token, self.im_end_token
        )

        self.bulid_tokenizer()
        self.out = self.output_core_token()
        self._init_space_tokens()

        ## skip_special_tokens
        # The [spaceN] tokens are deliberately absent here: they carry text
        # content and must survive decoding so deconvert() can expand them.
        self.decode_skip_special_tokens = [
            self.pad_token,
            self.unk_token,
            self.mask_token,
            self.eod_token,
            self.eop_token,
            self.im_start_token,
            self.im_end_token]
        self.decode_skip_special_tokens_ids = [
            self.convert_token_to_id(token)
            for token in self.decode_skip_special_tokens
        ]

    def _init_space_tokens(self):
        """Build the tables mapping ``[spaceN]`` tokens to runs of N spaces.

        The space strings are generated with ``" " * n`` rather than written
        as literals so their lengths cannot be silently altered by tooling
        that collapses whitespace; with equal-length values the reverse map
        would shrink to a single entry and every lone space would be
        rewritten to a ``[spaceN]`` token.
        """
        self.token2strs = {f"[space{n}]": " " * n for n in (2, 3, 4, 8)}
        self.str2tokens = {v: k for k, v in self.token2strs.items()}
        # Longest run first so 8 spaces become one [space8], not four [space2].
        self.sorted_strs = sorted(self.str2tokens, key=len, reverse=True)

    def _load_tiktoken_bpe(self, tiktoken_bpe_file: str):
        """Read a tiktoken vocabulary file into a ``{token_bytes: rank}`` dict.

        Each non-empty line holds a base64-encoded token and its integer rank
        separated by whitespace.
        """
        with open(tiktoken_bpe_file, "rb") as f:
            contents = f.read()
        return {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in contents.splitlines() if line)
        }

    def bulid_tokenizer(self):
        """Create the tiktoken encoder and the id <-> token lookup tables.

        Sets ``tokenizer``, ``decoder`` (id -> token), ``decoder_token2id``
        (token -> id) and ``num_tokens``. Ordinary tokens are ``bytes`` keys,
        special tokens are ``str`` keys.
        (The misspelled method name is kept as-is for existing callers.)
        """
        mergeable_ranks = self._load_tiktoken_bpe(self.vocab_file)
        # Special-token ids start right after the number of BPE entries.
        special_tokens = {
            token: index
            for index, token in enumerate(
                self.SPECIAL_TOKENS, start=len(mergeable_ranks)
            )
        }
        encode = tiktoken.Encoding(
            "zhinao",
            pat_str=PAT_STR,
            mergeable_ranks=mergeable_ranks,
            special_tokens=special_tokens
        )
        decoder = {v: k for k, v in mergeable_ranks.items()}
        decoder.update({v: k for k, v in special_tokens.items()})
        decoder_token2id = {v: k for k, v in decoder.items()}

        self.tokenizer = encode
        self.decoder = decoder
        self.decoder_token2id = decoder_token2id
        self.num_tokens = len(mergeable_ranks) + len(self.SPECIAL_TOKENS)

    def output_core_token(self):
        """Return a ``{special_token: id}`` dict for every special token."""
        return {t: self.convert_token_to_id(t) for t in self.SPECIAL_TOKENS}

    def tokenize(
            self,
            text,
            allowed_special: Union[Set, str] = "all",
            disallowed_special: Union[Collection, str] = ()):
        """Split ``text`` into tokens (``bytes`` for BPE, ``str`` for special).

        ``allowed_special`` / ``disallowed_special`` are forwarded unchanged
        to ``tiktoken.Encoding.encode``.
        """
        ids = self.encode(
            text,
            allowed_special=allowed_special,
            disallowed_special=disallowed_special,
        )
        return [self.decoder[idx] for idx in ids]

    def encode(self, text, allowed_special="all", disallowed_special=()):
        """Convert text to a list of token ids (after ``convert``)."""
        text = self.convert(text)
        return self.tokenizer.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)

    def decode(self, ids, errors="replace"):
        """Convert token ids back to text (followed by ``deconvert``).

        ``errors`` is forwarded to ``tiktoken.Encoding.decode``.
        """
        text = self.tokenizer.decode(ids, errors=errors)
        return self.deconvert(text)

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join a sequence of tokens into a single string.

        Consecutive ``bytes`` tokens are concatenated before UTF-8 decoding so
        that a character split across several tokens is decoded as a whole;
        undecodable bytes are dropped. ``str`` tokens are appended verbatim.

        Raises:
            TypeError: If a token is neither ``bytes`` nor ``str``.
        """
        parts = []
        pending = b""
        for t in tokens:
            if isinstance(t, str):
                if pending:
                    parts.append(pending.decode("utf-8", errors="ignore"))
                    pending = b""
                parts.append(t)
            elif isinstance(t, bytes):
                pending += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if pending:
            parts.append(pending.decode("utf-8", errors="ignore"))
        return self.deconvert("".join(parts))

    def convert_id_to_token(self, idx):
        """Return the token (``bytes`` or ``str``) stored for id ``idx``."""
        return self.decoder[idx]

    def convert_token_to_id(self, token):
        """Return the id stored for ``token``; raises ``KeyError`` if unknown."""
        return self.decoder_token2id[token]

    def convert(self, text):
        """Rewrite special character sequences in ``text`` as special tokens.

        ``[br]`` and ``<br>`` become a newline, then runs of spaces are
        replaced by ``[spaceN]`` tokens, longest run first.
        """
        for k in ["[br]", "<br>"]:
            text = text.replace(k, "\n")
        for k in self.sorted_strs:
            text = text.replace(k, self.str2tokens[k])
        return text

    def deconvert(self, text):
        """Expand ``[spaceN]`` tokens in decoded text back to plain spaces."""
        for token, spaces in self.token2strs.items():
            text = text.replace(token, spaces)
        return text
class ZhinaoTokenizer(PreTrainedTokenizer):
    """``transformers`` tokenizer that delegates all work to ``SPTokenizer``.

    Tokens are ``bytes`` for ordinary BPE entries and ``str`` for special
    tokens. ``_tokenize`` is intentionally not implemented; ``tokenize`` is
    overridden directly instead.
    """

    vocab_files_names = {"vocab_file": "vocab/360.tiktoken"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
        """Build the tokenizer from a tiktoken vocabulary file.

        Args:
            vocab_file: Path of the tiktoken BPE vocabulary file.
            padding_side: Forwarded to the base class (defaults to ``"left"``).
            clean_up_tokenization_spaces: Forwarded to the base class.
            **kwargs: Forwarded to the base class, except ``eos_token``,
                ``pad_token`` and ``unk_token``, which are discarded.
        """
        self.name = "ZhinaoTokenizer"
        self.vocab_file = vocab_file
        # Created before super().__init__ — presumably because the base-class
        # constructor may already query the vocabulary; keep this order.
        self.tokenizer = SPTokenizer(model_path=vocab_file)

        # Drop each key independently: popping inside one shared try/except
        # stopped at the first missing key and left the rest in kwargs.
        for key in ('eos_token', 'pad_token', 'unk_token'):
            kwargs.pop(key, None)

        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)

        self.pad_token_id = self.tokenizer.convert_token_to_id(self.tokenizer.pad_token)
        self.eod_id = self.tokenizer.convert_token_to_id(self.tokenizer.eod_token)
        self.im_start_id = self.tokenizer.convert_token_to_id(self.tokenizer.im_start_token)
        self.im_end_id = self.tokenizer.convert_token_to_id(self.tokenizer.im_end_token)

    @property
    def eop_token(self) -> str:
        """The end-of-passage token string (``<eop>``)."""
        return self.tokenizer.eop_token

    @property
    def eop_token_id(self):
        """The id of the ``<eop>`` token."""
        return self.tokenizer.convert_token_to_id(self.tokenizer.eop_token)

    @property
    def vocab_size(self):
        """Number of BPE entries plus the number of special tokens."""
        return self.tokenizer.num_tokens

    def get_vocab(self):
        """Return the vocabulary as a ``{token: id}`` dict, incl. added tokens."""
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
        split_special_tokens=False,
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """Split ``text`` into tokens.

        Args:
            text: The string to tokenize.
            allowed_special: Forwarded to the underlying encoder.
            disallowed_special: Forwarded to the underlying encoder.
            split_special_tokens: Accepted for interface compatibility; unused.
            **kwargs: Accepted for compatibility with the base-class
                ``tokenize(text, **kwargs)`` signature; ignored.

        Returns:
            A list of tokens: ``bytes`` for BPE entries, ``str`` for special
            tokens.
        """
        ids = self.tokenizer.encode(
            text, allowed_special=allowed_special, disallowed_special=disallowed_special
        )
        return [self.tokenizer.decoder[t] for t in ids]

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: str = "ignore",
        **kwargs,
    ) -> str:
        """Convert one id or a list of ids to text.

        With ``skip_special_tokens`` the ids listed in
        ``decode_skip_special_tokens_ids`` are removed first; ``[spaceN]``
        tokens are not in that list and are still expanded to spaces.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            skipped = self.tokenizer.decode_skip_special_tokens_ids
            token_ids = [i for i in token_ids if i not in skipped]
        return self.tokenizer.decode(token_ids, errors=errors)

    def _tokenize(self, text, **kwargs):
        """Not supported; use :meth:`tokenize`."""
        raise NotImplementedError

    def _convert_token_to_id(self, token):
        """Convert a token to its id using the vocabulary."""
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Convert an id to its token using the vocabulary."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join a sequence of tokens into a single string."""
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Copy the vocabulary file to ``save_directory``.

        Args:
            save_directory: An existing directory (the file is written to
                ``<save_directory>/vocab/360.tiktoken``) or, otherwise, the
                full path of the file to write.
            filename_prefix: Accepted for interface compatibility; unused.

        Returns:
            A 1-tuple holding the path of the written file.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        # Create the directory that will actually contain the file. Always
        # creating "<save_directory>/vocab" broke the file-path branch by
        # turning the target path itself into a directory.
        target_dir = os.path.dirname(vocab_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)