Atom2.7m / tokenizer_utils.py
ucr-max's picture
Upload Atom2.7m model
271e253 verified
Raw
History Blame Contribute Delete
11.3 kB
"""Shared construction and loading helpers for the project's tokenizer."""
from __future__ import annotations
from dataclasses import dataclass, field
import json
from pathlib import Path
import re
from typing import Any, Iterable
# Reserved control tokens, handed to the BPE trainer in this order.
# validate_tokenizer() requires "<|endoftext|>" to carry the id EOT_ID, i.e. its
# position in this list.
SPECIAL_TOKENS = ["<|pad|>", "<|bos|>", "<|eos|>", "<|unk|>", "<|endoftext|>"]
EOT_ID = SPECIAL_TOKENS.index("<|endoftext|>")

# Single-character operators that must exist as atomic vocabulary entries.
ARITHMETIC_TOKENS = ("+", "-", "*", "/", "=", "(", ")")

# Place-id stream: 0 for non-digit tokens, 1..MAX_PLACE_ID for the 1-based
# position of a digit inside its digit run, PLACE_OVERFLOW_ID beyond that.
MAX_PLACE_ID = 64
PLACE_OVERFLOW_ID = MAX_PLACE_ID + 1
PLACE_VOCAB_SIZE = PLACE_OVERFLOW_ID + 1

# Role-id stream: 0 = no role, 1..MAX_OPERAND_ROLES = n-th operand digit run to
# the left of "=", RESULT_ROLE_ID = digit runs to the right of "=",
# SPACE_ROLE_ID = whitespace inside a recognised equation span.
MAX_OPERAND_ROLES = 9
RESULT_ROLE_ID = MAX_OPERAND_ROLES + 1
SPACE_ROLE_ID = RESULT_ROLE_ID + 1
ROLE_VOCAB_SIZE = SPACE_ROLE_ID + 1
@dataclass(frozen=True)
class FusionEncoding:
    """Token ids plus the aligned place-id and role-id side streams.

    ``ids``, ``place_ids`` and ``role_ids`` describe the same token positions
    and must have equal lengths (checked on construction). ``tokens`` carries
    the token strings when available and is not length-checked. Iterating or
    taking ``len()`` of an encoding acts on ``ids``.
    """

    ids: list[int]
    place_ids: list[int]
    role_ids: list[int]
    tokens: list[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Reject misaligned streams up front.
        if len(self.ids) != len(self.place_ids) or len(self.place_ids) != len(self.role_ids):
            raise ValueError("Fusion tokenizer streams must have equal length")

    @property
    def input_ids(self) -> list[int]:
        """Alias of ``ids`` for callers expecting an ``input_ids`` attribute."""
        return self.ids

    def __len__(self) -> int:
        return len(self.ids)

    def __iter__(self):
        return iter(self.ids)
def build_tokenizer() -> Any:
    """Build a byte-level BPE tokenizer with explicit lossless boundaries.

    Pre-tokenization runs in two stages:

    1. A regex split that isolates whitespace runs, every single digit, each
       ``+ - * / = ( )`` character, and runs of any other characters.
    2. Byte-level mapping of each piece, without a prefix space and without the
       byte-level regex.

    Decoding uses the matching byte-level decoder. The returned tokenizer is
    untrained; ``<|unk|>`` is configured as the BPE unknown token.
    """
    from tokenizers import Regex, Tokenizer, decoders, models, pre_tokenizers

    bpe_tokenizer = Tokenizer(models.BPE(unk_token="<|unk|>"))
    boundary_split = pre_tokenizers.Split(
        Regex(r"\s+|\d|[+\-*/=()]|[^\s\d+\-*/=()]+"),
        behavior="isolated",
    )
    byte_mapping = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
    bpe_tokenizer.pre_tokenizer = pre_tokenizers.Sequence([boundary_split, byte_mapping])
    bpe_tokenizer.decoder = decoders.ByteLevel()
    return bpe_tokenizer
class FusionTokenizer:
    """Runtime wrapper adding LSD-first digit streams to a trained BPE tokenizer.

    :meth:`encode` reverses every run of ASCII digits in the input text, so
    numbers are tokenized least-significant digit first. :meth:`decode`
    reverses each run of consecutive digit tokens back. Every encoding carries
    two side streams aligned with the token ids:

    * ``place_ids``: for a digit token, its 1-based position inside its run of
      digit tokens, capped at ``PLACE_OVERFLOW_ID``; 0 for every other token.
    * ``role_ids``: these are set only inside a maximal span of digit,
      arithmetic-symbol and non-newline whitespace tokens that contains exactly
      one ``=``. Within such a span:
      - operand digit runs left of ``=`` get roles 1..N (N <= ``MAX_OPERAND_ROLES``);
      - digit runs right of ``=`` get ``RESULT_ROLE_ID``;
      - whitespace gets ``SPACE_ROLE_ID``.
      Everything else stays 0.

    Attributes not defined on this class are looked up on the wrapped tokenizer.
    """

    # ASCII digits only: decode() can only re-reverse runs of the atomic "0"-"9"
    # tokens. Also reversing other Unicode decimal digits (which ``\d`` matches
    # in a str pattern) would make encode -> decode lossy.
    _digit_span_re = re.compile(r"[0-9]+")

    def __init__(self, tokenizer: Any):
        """Wrap *tokenizer* and cache the ids of the digits, ``=`` and special tokens.

        Raises:
            ValueError: if the vocabulary lacks any atomic digit token 0-9 or
                an atomic ``=`` token.
        """
        self.tokenizer = tokenizer
        # Look each digit up once; both the id set and the id -> text map used
        # by decode() come from the same lookup.
        digit_ids = {
            digit: token_id
            for digit in "0123456789"
            if (token_id := tokenizer.token_to_id(digit)) is not None
        }
        self._digit_token_ids = frozenset(digit_ids.values())
        self._digit_id_to_text = {
            int(token_id): digit for digit, token_id in digit_ids.items()
        }
        self._equals_id = tokenizer.token_to_id("=")
        self._special_token_ids = frozenset(
            token_id
            for token in SPECIAL_TOKENS
            if (token_id := tokenizer.token_to_id(token)) is not None
        )
        if len(self._digit_token_ids) != 10:
            raise ValueError("Tokenizer vocabulary must contain atomic digit tokens 0-9")
        if self._equals_id is None:
            raise ValueError("Tokenizer vocabulary must contain an atomic '=' token")

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute lookups to the wrapped tokenizer.

        Python only calls this after normal lookup fails. Dunder names are never
        forwarded, and ``tokenizer`` is read from ``__dict__`` directly. The
        reason: ``copy`` and ``pickle`` probe hooks such as ``__setstate__`` on
        a freshly created, still-empty instance. There, ``self.tokenizer`` would
        re-enter this method and recurse until ``RecursionError``.
        """
        missing = f"{type(self).__name__!r} object has no attribute {name!r}"
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(missing)
        try:
            tokenizer = self.__dict__["tokenizer"]
        except KeyError:
            raise AttributeError(missing) from None
        return getattr(tokenizer, name)

    @property
    def digit_token_ids(self) -> frozenset[int]:
        """Ids of the ten atomic digit tokens."""
        return self._digit_token_ids

    @property
    def special_token_ids(self) -> frozenset[int]:
        """Ids of the entries of ``SPECIAL_TOKENS`` present in the vocabulary."""
        return self._special_token_ids

    def get_vocab_size(self, with_added_tokens: bool = True) -> int:
        return int(self.tokenizer.get_vocab_size(with_added_tokens=with_added_tokens))

    def get_vocab(self, with_added_tokens: bool = True) -> dict[str, int]:
        return self.tokenizer.get_vocab(with_added_tokens=with_added_tokens)

    def token_to_id(self, token: str) -> int | None:
        return self.tokenizer.token_to_id(token)

    def id_to_token(self, token_id: int) -> str | None:
        return self.tokenizer.id_to_token(int(token_id))

    @classmethod
    def _reverse_digit_spans(cls, text: str) -> str:
        """Return *text* with every maximal run of ASCII digits reversed in place."""
        return cls._digit_span_re.sub(lambda match: match.group(0)[::-1], text)

    def _decode_token_piece(self, token_id: int) -> str:
        """Decode a single token (special tokens kept) to classify its text."""
        return self.tokenizer.decode([int(token_id)], skip_special_tokens=False)

    @staticmethod
    def _is_equation_whitespace(piece: str) -> bool:
        """True for a non-empty all-whitespace piece without ``\\n`` or ``\\r``.

        Line breaks are excluded so an equation span never crosses lines.
        """
        return bool(piece) and piece.isspace() and "\n" not in piece and "\r" not in piece

    def _is_equation_piece(self, token_id: int, piece: str) -> bool:
        """True if this token may appear inside an equation span.

        Such tokens are digits, non-newline whitespace, or a single arithmetic
        symbol. Special tokens never qualify.
        """
        if token_id in self._special_token_ids:
            return False
        if token_id in self._digit_token_ids:
            return True
        if self._is_equation_whitespace(piece):
            return True
        # Direct tuple membership; avoids rebuilding a set on every call.
        return len(piece) == 1 and piece in ARITHMETIC_TOKENS

    def _annotate_equation_span(
        self,
        ids: list[int],
        pieces: list[str],
        start: int,
        end: int,
        role_ids: list[int],
    ) -> None:
        """Write role ids in place for ``ids[start:end]`` if it looks like one equation.

        The span is annotated only if all of these hold:
        - it contains exactly one ``=`` token;
        - it has at least one digit run before the ``=`` and one after it;
        - it has at most ``MAX_OPERAND_ROLES`` digit runs before the ``=``.

        Otherwise *role_ids* is left untouched.
        """
        equals_positions = [
            index
            for index in range(start, end)
            if ids[index] == self._equals_id
        ]
        if len(equals_positions) != 1:
            return
        equals_position = equals_positions[0]
        # Collect maximal [run_start, run_end) runs of consecutive digit tokens.
        digit_runs: list[tuple[int, int]] = []
        index = start
        while index < end:
            if ids[index] not in self._digit_token_ids:
                index += 1
                continue
            run_start = index
            while index < end and ids[index] in self._digit_token_ids:
                index += 1
            digit_runs.append((run_start, index))
        # A digit run never contains "=", so each run lies wholly on one side.
        operand_runs = [(a, b) for a, b in digit_runs if b <= equals_position]
        result_runs = [(a, b) for a, b in digit_runs if a > equals_position]
        if not operand_runs or not result_runs or len(operand_runs) > MAX_OPERAND_ROLES:
            return
        for index in range(start, end):
            if self._is_equation_whitespace(pieces[index]):
                role_ids[index] = SPACE_ROLE_ID
        for role, (run_start, run_end) in enumerate(operand_runs, start=1):
            for index in range(run_start, run_end):
                role_ids[index] = role
        for run_start, run_end in result_runs:
            for index in range(run_start, run_end):
                role_ids[index] = RESULT_ROLE_ID

    def annotate_ids(self, ids: Iterable[int]) -> tuple[list[int], list[int]]:
        """Return ``(place_ids, role_ids)`` aligned with *ids*.

        *ids* are expected to come from :meth:`encode`, i.e. with digit runs
        already reversed. Each digit token's place id is therefore counted from
        the least-significant digit.
        """
        input_ids = [int(token_id) for token_id in ids]
        place_ids = [0] * len(input_ids)
        role_ids = [0] * len(input_ids)
        pieces = [self._decode_token_piece(token_id) for token_id in input_ids]
        # Place ids: 1-based offset inside each run of digit tokens, capped.
        index = 0
        while index < len(input_ids):
            if input_ids[index] not in self._digit_token_ids:
                index += 1
                continue
            run_start = index
            while index < len(input_ids) and input_ids[index] in self._digit_token_ids:
                offset = index - run_start + 1
                place_ids[index] = min(offset, PLACE_OVERFLOW_ID)
                index += 1
        # Role ids: annotate every maximal span of equation-like tokens.
        span_start: int | None = None
        for index, (token_id, piece) in enumerate(zip(input_ids, pieces, strict=True)):
            if self._is_equation_piece(token_id, piece):
                if span_start is None:
                    span_start = index
                continue
            if span_start is not None:
                self._annotate_equation_span(input_ids, pieces, span_start, index, role_ids)
                span_start = None
        if span_start is not None:
            self._annotate_equation_span(input_ids, pieces, span_start, len(input_ids), role_ids)
        return place_ids, role_ids

    def encode(self, text: str, *args, **kwargs) -> FusionEncoding:
        """Encode *text* with ASCII digit runs reversed and attach place/role ids.

        Extra positional and keyword arguments are passed to the wrapped
        tokenizer's ``encode``.
        """
        transformed = self._reverse_digit_spans(text)
        encoding = self.tokenizer.encode(transformed, *args, **kwargs)
        ids = [int(token_id) for token_id in encoding.ids]
        place_ids, role_ids = self.annotate_ids(ids)
        return FusionEncoding(
            ids=ids,
            place_ids=place_ids,
            role_ids=role_ids,
            tokens=list(getattr(encoding, "tokens", [])),
        )

    def encode_batch(self, texts: list[str], *args, **kwargs) -> list[FusionEncoding]:
        """Encode each text independently with :meth:`encode`."""
        return [self.encode(text, *args, **kwargs) for text in texts]

    def decode(
        self,
        token_ids: Iterable[int],
        skip_special_tokens: bool = True,
    ) -> str:
        """Decode *token_ids*, restoring normal digit order.

        Each run of consecutive digit tokens is emitted in reverse, undoing the
        LSD-first order. The tokens between digit runs are decoded in
        contiguous chunks by the wrapped tokenizer, which receives
        *skip_special_tokens*.
        """
        pieces: list[str] = []
        text_ids: list[int] = []
        digit_buffer: list[str] = []

        def flush_text() -> None:
            if text_ids:
                pieces.append(
                    self.tokenizer.decode(
                        text_ids,
                        skip_special_tokens=skip_special_tokens,
                    )
                )
                text_ids.clear()

        def flush_digits() -> None:
            if digit_buffer:
                pieces.extend(reversed(digit_buffer))
                digit_buffer.clear()

        for raw_id in token_ids:
            token_id = int(raw_id)
            if token_id in self._digit_token_ids:
                flush_text()
                digit_buffer.append(self._digit_id_to_text[token_id])
                continue
            flush_digits()
            text_ids.append(token_id)
        flush_text()
        flush_digits()
        return "".join(pieces)
def build_trainer(vocab_size: int, min_frequency: int) -> Any:
    """Create the BPE trainer used with :func:`build_tokenizer`.

    ``SPECIAL_TOKENS`` are registered as trainer special tokens. The initial
    alphabet is seeded with the byte-level pre-tokenizer's alphabet, so every
    byte-level symbol is in the vocabulary regardless of corpus frequency.
    """
    from tokenizers import pre_tokenizers, trainers

    trainer_options = dict(
        vocab_size=vocab_size,
        min_frequency=min_frequency,
        special_tokens=SPECIAL_TOKENS,
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
    return trainers.BpeTrainer(**trainer_options)
def tokenizer_files(tokenizer_dir: Path) -> tuple[Path, Path, Path]:
    """Return the ``(tokenizer.json, vocab.json, merges.txt)`` paths inside *tokenizer_dir*."""
    tokenizer_json, vocab_json, merges_txt = (
        tokenizer_dir / filename
        for filename in ("tokenizer.json", "vocab.json", "merges.txt")
    )
    return tokenizer_json, vocab_json, merges_txt
def validate_tokenizer(tokenizer_dir: Path) -> None:
    """Check that *tokenizer_dir* contains a tokenizer this project can use.

    The vocabulary is read from ``vocab.json`` when that file exists, otherwise
    from the ``model.vocab`` mapping embedded in ``tokenizer.json``.

    Raises:
        FileNotFoundError: if ``tokenizer.json`` is missing, or if there is no
            ``vocab.json`` and ``tokenizer.json`` embeds no vocabulary mapping.
        ValueError: if any of the following holds:
            - the vocabulary is empty or not a token -> id mapping;
            - its largest id does not fit in uint16;
            - ``<|endoftext|>`` does not have id ``EOT_ID``;
            - an atomic digit or arithmetic token is missing.
    """
    tokenizer_json, vocab_path, _ = tokenizer_files(tokenizer_dir)
    if not tokenizer_json.exists():
        raise FileNotFoundError(
            f"Missing {tokenizer_json}. Retrain with train_tokenizer.py so the "
            "whitespace and digit boundary rules are preserved."
        )
    if vocab_path.exists():
        with vocab_path.open("r", encoding="utf-8") as f:
            vocab = json.load(f)
        vocab_source = vocab_path
    else:
        with tokenizer_json.open("r", encoding="utf-8") as f:
            tokenizer_data = json.load(f)
        # A non-object root or "model": null is reported as "no embedded vocab"
        # rather than crashing with AttributeError.
        model = tokenizer_data.get("model") if isinstance(tokenizer_data, dict) else None
        vocab = model.get("vocab") if isinstance(model, dict) else None
        if not isinstance(vocab, dict):
            raise FileNotFoundError(f"Missing vocab.json and no embedded vocab in {tokenizer_json}")
        vocab_source = tokenizer_json
    # Check before max(): an empty mapping would otherwise fail with the opaque
    # "max() arg is an empty sequence", and a non-dict vocab.json with AttributeError.
    if not isinstance(vocab, dict) or not vocab:
        raise ValueError(
            f"Tokenizer vocabulary in {vocab_source} must be a non-empty token -> id mapping"
        )
    max_id = max(vocab.values())
    if max_id > 65_535:
        raise ValueError(f"Tokenizer max id {max_id} does not fit in uint16")
    if vocab.get("<|endoftext|>") != EOT_ID:
        raise ValueError(
            f"Expected <|endoftext|> id {EOT_ID}, "
            f"got {vocab.get('<|endoftext|>')}"
        )
    required = [*(str(digit) for digit in range(10)), *ARITHMETIC_TOKENS]
    missing = [token for token in required if token not in vocab]
    if missing:
        raise ValueError(f"Tokenizer missing required atomic tokens: {missing}")
def load_tokenizer(tokenizer_dir: Path) -> Any:
    """Validate *tokenizer_dir*, load its ``tokenizer.json`` and wrap it.

    Any error raised by :func:`validate_tokenizer` propagates before the file
    is loaded. Returns a :class:`FusionTokenizer` around the loaded tokenizer.
    """
    from tokenizers import Tokenizer

    validate_tokenizer(tokenizer_dir)
    tokenizer_json = tokenizer_files(tokenizer_dir)[0]
    base_tokenizer = Tokenizer.from_file(str(tokenizer_json))
    return FusionTokenizer(base_tokenizer)