"""Train a fresh 8K BPE on a FineWeb-edu sample.
This is the 50M-scale variant of the 1M project's 4K BPE. We bump the default
vocab to 8192 and the document count to 50000 (was 50000 in 1M, kept the same
because the 1M doc-count was already saturating BPE merge quality at 4K vocab
-- doubling vocab needs roughly the same training set, not 2x more).
We do NOT reuse any FANT tokenizer here -- the point of this experiment family
is a clean small recipe with no external dependencies.
Output: tokenizer.json in the working dir (or wherever specified).
"""
from __future__ import annotations
import argparse
import time
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel as BLPre
from tokenizers.decoders import ByteLevel as BLDec
from tokenizers.processors import ByteLevel as BLPost
from tokenizers.trainers import BpeTrainer
SPECIAL_TOKENS = [
"<|pad|>", # 0
"<|bos|>", # 1
"<|eos|>", # 2
"<|unk|>", # 3
"<|im_start|>", # 4 -- chat role open
"<|im_end|>", # 5 -- chat role close
]
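# The two chat tokens are intended for ChatML-style framing of turns. A sketch
# of one framed turn (how data.py actually assembles sequences is an
# assumption here, not something this file enforces):
#
#   <|im_start|>user
#   Hello there<|im_end|>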
def _iter_fineweb(n_docs: int):
"""Yield up to `n_docs` text strings from the FineWeb-edu streaming feed."""
from datasets import load_dataset
ds = load_dataset(
"HuggingFaceFW/fineweb-edu",
name="default",
split="train",
streaming=True,
)
n = 0
for ex in ds:
if n >= n_docs:
return
text = ex.get("text", "")
if isinstance(text, str) and text.strip():
n += 1
yield text
def train_tokenizer(
    out_path: str = "tokenizer.json",
    vocab_size: int = 8192,
    n_docs: int = 50000,
) -> str:
tok = Tokenizer(BPE(unk_token="<|unk|>"))
tok.pre_tokenizer = BLPre(add_prefix_space=False)
tok.decoder = BLDec()
tok.post_processor = BLPost(trim_offsets=False)
trainer = BpeTrainer(
vocab_size=vocab_size,
special_tokens=SPECIAL_TOKENS,
initial_alphabet=BLPre.alphabet(),
show_progress=False,
)
print(f"[tokenizer] streaming up to {n_docs} FineWeb-edu docs...")
t0 = time.time()
docs = list(_iter_fineweb(n_docs))
print(f"[tokenizer] collected {len(docs)} docs in {time.time() - t0:.1f}s")
print(f"[tokenizer] training BPE vocab_size={vocab_size}...")
t0 = time.time()
tok.train_from_iterator(docs, trainer=trainer)
print(f"[tokenizer] trained in {time.time() - t0:.1f}s; vocab={tok.get_vocab_size()}")
    out_dir = Path(out_path).parent
    out_dir.mkdir(parents=True, exist_ok=True)
tok.save(out_path)
print(f"[tokenizer] saved to {out_path}")
return out_path
def load_tokenizer(path: str = "tokenizer.json") -> Tokenizer:
return Tokenizer.from_file(path)
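# Quick round-trip sketch (assumes a tokenizer.json produced by
# train_tokenizer above; with the ByteLevel decoder the decoded text should
# reproduce the input exactly):
#
#   tok = load_tokenizer("tokenizer.json")
#   ids = tok.encode("hello world").ids
#   text = tok.decode(ids)  # -> "hello world"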
# Convenience accessors used by data.py / train.py
def special_token_id(tok: Tokenizer, name: str) -> int:
tid = tok.token_to_id(name)
assert tid is not None, f"{name} not in tokenizer"
return tid
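# Typical call sites (a sketch of how data.py / train.py are expected to use
# this accessor; BpeTrainer assigns the special tokens ids 0..5 in list order,
# which is what the index comments on SPECIAL_TOKENS refer to):
#
#   bos_id = special_token_id(tok, "<|bos|>")  # 1
#   eos_id = special_token_id(tok, "<|eos|>")  # 2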
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--out", default="tokenizer.json")
ap.add_argument("--vocab", type=int, default=8192)
ap.add_argument("--docs", type=int, default=50000)
args = ap.parse_args()
train_tokenizer(args.out, args.vocab, args.docs)
if __name__ == "__main__":
main()