# Copyright (c) Meta Platforms, Inc. and affiliates.

import logging
from copy import copy
from pathlib import Path

from bytelatent.tokenizers.abstract_tokenizer import Tokenizer

# tiktoken is an optional dependency; `has_tiktoken` records whether it is
# importable so callers can fail gracefully when it is missing.
try:
    import tiktoken
    from tiktoken.load import load_tiktoken_bpe

    has_tiktoken = True
except ImportError:
    has_tiktoken = False
# Pre-tokenization regex used to split text before BPE: case-insensitive
# contractions, letter runs, 1-3 digit groups, punctuation, and whitespace.
DEFAULT_TIKTOKEN_PATTERN = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

# Named special tokens and their slots within a 256-entry special-token block;
# the remaining slots are filled with reserved placeholders in __init__.
DEFAULT_TIKTOKEN_SPECIAL_TOKENS = {
    "<|begin_of_text|>": 0,
    "<|end_of_text|>": 1,
    "<|fim_prefix|>": 2,
    "<|fim_middle|>": 3,
    "<|fim_end_fill|>": 253,
    "<|fim_pad|>": 254,
    "<|fim_suffix|>": 255,
}

# Cap on the number of characters handed to tiktoken in one call; longer
# inputs are split into chunks of this size in encode().
TIKTOKEN_MAX_ENCODE_CHARS = 400_000

logger = logging.getLogger(__name__)
class TikTokenTokenizer(Tokenizer):
    def __init__(self, model_path: str) -> None:
        mergeable_ranks = load_tiktoken_bpe(model_path)

        # Fill the unused slots of the 256-entry special-token block with
        # reserved placeholder tokens, then shift every special-token id past
        # the BPE vocabulary so the two id ranges never collide.
        all_special_tokens_with_ids = copy(DEFAULT_TIKTOKEN_SPECIAL_TOKENS)
        missing_ids = set(range(256)) - set(all_special_tokens_with_ids.values())
        for tok_id in missing_ids:
            all_special_tokens_with_ids[f"<|reserved_special_token_{tok_id}|>"] = tok_id
        for name in all_special_tokens_with_ids:
            all_special_tokens_with_ids[name] += len(mergeable_ranks)

        self.tkt_model = tiktoken.core.Encoding(
            name=Path(model_path).stem,
            pat_str=DEFAULT_TIKTOKEN_PATTERN,
            mergeable_ranks=mergeable_ranks,
            special_tokens=all_special_tokens_with_ids,
        )

        self.bos_id: int = self.tkt_model.encode_single_token("<|begin_of_text|>")
        self.eos_id: int = self.tkt_model.encode_single_token("<|end_of_text|>")
        self.n_words: int = self.tkt_model.n_vocab

        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
    def get_vocab_size(self) -> int:
        return self.n_words

    def encode(self, s: str, add_bos: bool, add_eos: bool) -> list[int]:
        assert isinstance(s, str)

        # Split the input into chunks tiktoken can handle, encode each chunk
        # without special-token handling, and flatten the results. The bools
        # act as 0/1 multipliers for the optional BOS/EOS tokens.
        subs = []
        for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS):
            subs.append(s[i : i + TIKTOKEN_MAX_ENCODE_CHARS])
        return (
            [self.bos_id] * add_bos
            + sum(self.tkt_model.encode_ordinary_batch(subs), start=[])
            + [self.eos_id] * add_eos
        )
    def decode(self, tokens: list[int]) -> str:
        return self.tkt_model.decode(tokens)
    def get_token_offsets(
        self, text: str, tokens: list[int] | None = None
    ) -> tuple[list[str], list[int]]:
        if tokens is not None:
            token_bytes = self.tkt_model.decode_tokens_bytes(tokens)
        else:
            token_bytes = self.tkt_model.decode_tokens_bytes(
                self.tkt_model.encode(text, allowed_special="all")
            )

        # Walk the tokens' raw bytes to recover the character offset at which
        # each token starts. UTF-8 continuation bytes lie in [0x80, 0xC0), so
        # counting only non-continuation bytes counts characters; if a token
        # begins mid-character (its first byte is a continuation byte), its
        # offset is pulled back to the character that contains it.
        text_len, offsets = 0, []
        for token in token_bytes:
            offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
            text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
        substrs = [text[s:e] for s, e in zip(offsets, offsets[1:] + [None])]
        return substrs, offsets
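
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): assumes a
# tiktoken-format BPE file is available locally; the path below is a
# hypothetical placeholder. Demonstrates encode/decode round-tripping and
# token-to-character offsets.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    if not has_tiktoken:
        raise SystemExit("tiktoken is not installed; run `pip install tiktoken`")

    tokenizer = TikTokenTokenizer("/path/to/tokenizer.model")  # hypothetical path

    text = "Hello, café!"  # includes a multibyte UTF-8 character
    tokens = tokenizer.encode(text, add_bos=True, add_eos=True)
    print(tokens)

    # Drop BOS/EOS before decoding; ordinary tokens round-trip exactly.
    assert tokenizer.decode(tokens[1:-1]) == text

    # Each offset is the character index in `text` where that token starts.
    substrs, offsets = tokenizer.get_token_offsets(text)
    print(list(zip(substrs, offsets)))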