from xsbpe.base import Tokenizer, get_adjacent_pair_counts, merge_pairs


class BasicTokenizer(Tokenizer):
    """Byte-level BPE tokenizer without a regex pre-tokenization split."""

    def __init__(self):
        super().__init__()

    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # Start from the raw UTF-8 bytes of the training text.
        text_bytes = text.encode('utf-8')
        ids = list(text_bytes)

        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        for i in range(num_merges):
            # Merge the most frequent adjacent pair into a new token id.
            stats = get_adjacent_pair_counts(ids)
            pair = max(stats, key=stats.get)
            idx = 256 + i
            ids = merge_pairs(ids, pair, idx)
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges
        self.vocab = vocab

    def decode(self, ids):
        # Map each token id back to its byte sequence, then decode as UTF-8.
        text_bytes = b''.join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode('utf-8', errors='replace')
        return text

    def encode(self, text):
        text_bytes = text.encode('utf-8')
        ids = list(text_bytes)
        while len(ids) >= 2:
            # Greedily apply the earliest-learned merge present in the sequence.
            stats = get_adjacent_pair_counts(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if pair not in self.merges:
                # No remaining adjacent pair was learned during training.
                break
            idx = self.merges[pair]
            ids = merge_pairs(ids, pair, idx)
        return ids
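

# Minimal usage sketch (illustrative only; the training text and vocab size
# below are assumptions, not part of the original module):
if __name__ == '__main__':
    tokenizer = BasicTokenizer()
    tokenizer.train("the quick brown fox jumps over the lazy dog", vocab_size=260, verbose=True)
    ids = tokenizer.encode("the lazy dog")
    assert tokenizer.decode(ids) == "the lazy dog"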