Spaces:
Sleeping
Sleeping
import regex as re | |
from base_tokenizer import Tokenizer, get_stats, merge | |
# Pre-tokenization split pattern. Text is cut into chunks by this regex before
# any byte-pair merging, so merges never cross a chunk boundary.
# It uses \p{...} property classes, so it must be compiled with the `regex`
# module rather than the standard-library `re`.
# The pattern is assembled from adjacent raw-string literals (one per
# alternation branch); the resulting string contains no whitespace other than
# the literal optional leading space written as " ?".
DEVANAGARI_SPLIT_PATTERN = (
    # English contractions: 's 't 're 've 'm 'll 'd
    r"'s|'t|'re|'ve|'m|'ll|'d"
    # a run of numeric characters, optionally preceded by one space
    r"| ?\p{N}+"
    # one or more (base character + trailing marks) clusters, optionally
    # preceded by one space
    r"| ?(?:"
    # first class: code points from the Devanagari block (U+0900-U+097F) ...
    r"[\u0904-\u0939\u093d-\u093d\u0950-\u0950\u0958-\u0961\u0970-\u097f"
    # ... Devanagari Extended (U+A8E0-U+A8FF) and Extended-A (U+11B00-) ...
    r"\ua8f2-\ua8fe\U00011b00-\U00011b09"
    # ... and Vedic Extensions (U+1CD0-U+1CFF)
    r"\u1cd3-\u1cd3\u1ce9-\u1cec\u1cee-\u1cf3\u1cf5-\u1cf6\u1cfa-\u1cfa]"
    # second class (repeated zero or more times): the complementary code
    # points from the same three blocks that may follow a first-class character
    r"[\u0900-\u0903\u093a-\u093c\u093e-\u094f\u0951-\u0957\u0962-\u0963"
    r"\ua8e0-\ua8f1\ua8ff-\ua8ff"
    r"\u1cd0-\u1cd2\u1cd4-\u1ce8\u1ced-\u1ced\u1cf4-\u1cf4\u1cf7-\u1cf9]*"
    r")+"
    # a run of any other letters, optionally preceded by one space
    r"| ?\p{L}+"
    # a run of characters that are neither whitespace, letters nor numbers
    r"| ?[^\s\p{L}\p{N}]+"
    # whitespace that is not followed by a non-whitespace character
    r"|\s+(?!\S)"
    # any remaining whitespace
    r"|\s+"
)
class DevanagariTokenizer(Tokenizer):
    """Byte-level BPE tokenizer that pre-splits text with a regex pattern.

    Text is first cut into chunks with the split pattern; each chunk is
    UTF-8 encoded and byte pairs are merged within a chunk only, never
    across chunk boundaries. Special tokens can be registered separately
    and are matched as exact strings by encode().
    """

    def __init__(self, pattern=None):
        """
        - pattern: optional string to override the default split pattern
          (DEVANAGARI_SPLIT_PATTERN)

        Special tokens start out empty; add them afterwards with
        register_special_tokens().
        """
        super().__init__()
        self.pattern = DEVANAGARI_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}  # str -> int
        self.inverse_special_tokens = {}  # int -> str

    def train(self, text, vocab_size, verbose=False):
        """
        Learn up to (vocab_size - 256) merges from text.

        - text: training string
        - vocab_size: target vocabulary size, must be >= 256 (the raw bytes)
        - verbose: if True, print every merge as it is learned

        Stores the result in self.merges and self.vocab and prints the
        achieved compression ratio. Training stops early if the text runs
        out of adjacent token pairs before the requested number of merges
        has been reached.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # split the text up into text chunks
        text_chunks = re.findall(self.compiled_pattern, text)

        # input text preprocessing: one list of byte values per chunk
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        # total number of tokens (bytes) before any merge, used for the
        # compression ratio reported at the end
        input_len = sum(len(chunk_ids) for chunk_ids in ids)

        # iteratively merge the most common pairs to create new tokens
        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # idx -> bytes
        for i in range(num_merges):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            if not stats:
                # every chunk is down to a single token (or the text is
                # empty): there is no pair left, and max() would raise
                if verbose:
                    print(f"no more pairs to merge, stopping after {i}/{num_merges} merges")
                break
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        output_len = sum(len(chunk_ids) for chunk_ids in ids)
        # output_len is 0 only when the text produced no tokens at all;
        # report a neutral ratio instead of dividing by zero
        ratio = input_len / output_len if output_len else 1.0
        print(f"input_len: {input_len}, output_len: {output_len} compression ratio: {ratio:.2f}X")

        # save class variables
        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()

    def register_special_tokens(self, special_tokens):
        """
        Replace the set of known special tokens.

        - special_tokens: str -> int dictionary of special tokens
          example: {"<|endoftext|>": 100257}
        """
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        """
        Given ids (list of integers), return a Python string.

        Raises ValueError for an id that is neither in the vocab nor a
        registered special token. Byte sequences that are not valid UTF-8
        are decoded with the replacement character rather than raising.
        """
        part_bytes = []
        # get the bytes for the corresponding token from vocab
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        return text_bytes.decode("utf-8", errors="replace")

    def _encode_chunk(self, text_bytes):
        """Return the token ids for the raw bytes of a single text chunk."""
        # let's begin. first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            ids = merge(ids, pair, self.merges[pair])
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = re.findall(self.compiled_pattern, text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")  # raw bytes
            ids.extend(self._encode_chunk(chunk_bytes))
        return ids

    def encode(self, text, allowed_special="none_raise"):
        """
        Unlike encode_ordinary, this function handles special tokens.
        allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
        if none_raise, then an AssertionError is raised if any special token is encountered in text
        this is the default tiktoken behavior right now as well
        any other behavior is either annoying, or a major footgun

        Raises ValueError if allowed_special is none of the above.
        """
        # decode the user desire w.r.t. handling of special tokens
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            # raise explicitly instead of using an `assert` statement: asserts
            # are removed under `python -O`, which would silently disable the
            # check; the exception type stays AssertionError for callers
            for token in self.special_tokens:
                if token in text:
                    raise AssertionError(
                        f"special token {token!r} found in text but allowed_special='none_raise'"
                    )
        elif isinstance(allowed_special, (set, frozenset)):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)
        # otherwise, we have to be careful with potential special tokens in text
        # we handle special tokens by splitting the text
        # based on the occurrence of any exact match with any of the special tokens
        # we can use re.split for this. note that surrounding the pattern with ()
        # makes it into a capturing group, so the special tokens will be included
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        special_chunks = re.split(special_pattern, text)
        # now all the special characters are separated from the rest of the text
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for part in special_chunks:
            if part in special:
                # this is a special token, encode it separately as a special case
                ids.append(special[part])
            else:
                # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
        return ids