import unicodedata
from collections import defaultdict
from itertools import pairwise


def get_adjacent_pair_counts(ids) -> defaultdict:
    """Count how often each adjacent pair of ids occurs in the sequence."""
    counts = defaultdict(int)
    for pair in pairwise(ids):
        counts[pair] += 1
    return counts


def merge_pairs(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the new token id `idx`."""
    newids = []
    i = 0
    n = len(ids)
    while i < n:
        if i < n - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


def replace_control_characters(s: str) -> str:
    """Escape control characters (e.g. newlines) so they don't distort printed output."""
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != 'C':
            chars.append(ch)
        else:
            chars.append(f'\\u{ord(ch):04x}')  # escape as \\uXXXX
    return ''.join(chars)


def render_token(t: bytes) -> str:
    """Pretty-print a token's bytes as a human-readable string."""
    s = t.decode('utf-8', errors='replace')
    s = replace_control_characters(s)
    return s


class Tokenizer:
    """Base class for Tokenizers"""

    def __init__(self):
        self.merges = {}          # (int, int) -> int
        self.pattern = ''         # optional split pattern
        self.special_tokens = {}  # str -> int
        self.vocab = self._build_vocab()  # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError

    def encode(self, text):
        raise NotImplementedError

    def decode(self, ids):
        raise NotImplementedError

    def _build_vocab(self):
        # the base vocabulary is the 256 raw byte values; merges and special
        # tokens are layered on top, in order
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode('utf-8')
        return vocab
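

# --- usage sketch (not part of the module above) ---
# A minimal, hypothetical illustration of how the two helpers combine into a
# single BPE merge step: count adjacent pairs, pick the most frequent one, and
# replace it with a newly minted token id. The actual training loop would live
# in a subclass's train(); the byte sequence below is only a toy example.
if __name__ == '__main__':
    ids = list('aaabdaaabac'.encode('utf-8'))  # toy byte sequence
    counts = get_adjacent_pair_counts(ids)     # (97, 97) occurs 4 times here
    top_pair = max(counts, key=counts.get)     # most frequent adjacent pair
    ids = merge_pairs(ids, top_pair, 256)      # assign token id 256 to that pair
    print(top_pair, ids)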