feather-a10g-large-runtime/overlay/scripts/build_token_cache.py
"""Fast parallel token cache builder.
Reads parquet shards DIRECTLY via pyarrow (no HF streaming overhead),
tokenizes with multiprocessing.Pool, writes packed (T+1) int32 rows.
Uses the pre-downloaded shards in ~/.cache/huggingface/hub/ — no network.
Usage: python scripts/build_token_cache.py [--gb 2] [--workers 8]
"""
from __future__ import annotations
import argparse
import glob
import os
import sys
import time
from pathlib import Path
from multiprocessing import Pool
sys.stdout.reconfigure(line_buffering=True)
import numpy as np
import pyarrow.parquet as pq
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from prepare import Tokenizer
HF_HUB_CACHE = os.path.expanduser("~/.cache/huggingface/hub")
# Which column each dataset uses for text
TEXT_COLS: dict[str, list[str]] = {
"fineweb-edu": ["text"],
"fineweb": ["text"],
"stack-v2": ["text", "content"],
"nemotron-math": ["text"],
"nemotron-specialized": ["text"],
"wikipedia": ["text"],
"cosmopedia": ["text"],
}
# Dataset repo → cache dir mapping
REPO_DIRS = {
"fineweb-edu": "datasets--HuggingFaceFW--fineweb-edu",
"fineweb": "datasets--HuggingFaceFW--fineweb",
"stack-v2": "datasets--OpenCoder-LLM--opc-fineweb-code-corpus",
"nemotron-math": "datasets--nvidia--Nemotron-CC-Math-v1",
"nemotron-specialized": "datasets--nvidia--Nemotron-Pretraining-Specialized-v1.1",
"wikipedia": "datasets--wikimedia--wikipedia",
"cosmopedia": "datasets--HuggingFaceTB--cosmopedia",
}
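# Cached shards live under
#   ~/.cache/huggingface/hub/datasets--<org>--<name>/snapshots/<revision>/**/*.parquet
# which is the layout find_parquet_files() walks below.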
def find_parquet_files() -> list[tuple[str, str]]:
"""Return [(dataset_name, parquet_path), ...] for all cached shards."""
results = []
for name, dirname in REPO_DIRS.items():
base = os.path.join(HF_HUB_CACHE, dirname, "snapshots")
if not os.path.isdir(base):
continue
for snap in os.listdir(base):
snap_dir = os.path.join(base, snap)
for root, _, files in os.walk(snap_dir):
for f in files:
if f.endswith(".parquet"):
results.append((name, os.path.join(root, f)))
return results
# Tokenizer loaded once per worker process
_WORKER_TOKENIZER = None
_WORKER_BOS = None
def _worker_init():
global _WORKER_TOKENIZER, _WORKER_BOS
_WORKER_TOKENIZER = Tokenizer.from_directory()
_WORKER_BOS = _WORKER_TOKENIZER.get_bos_token_id()
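# Pool(initializer=_worker_init) runs the function above once per worker process,
# so each worker builds its own tokenizer instead of pickling one across processes.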
def _tokenize_batch(args: tuple[list[str], int]) -> list[list[int]]:
"""Tokenize a batch of text strings. Returns list of token-id lists."""
texts, _ = args
return _WORKER_TOKENIZER.encode(texts, prepend=_WORKER_BOS)
def iter_text_from_parquet(name: str, path: str, batch_size: int = 512):
"""Stream text batches from one parquet file."""
cols = TEXT_COLS.get(name, ["text"])
try:
pf = pq.ParquetFile(path)
except Exception as e:
print(f" [skip] {path}: {e}", flush=True)
return
# Find which column exists
schema_names = set(pf.schema_arrow.names)
col = next((c for c in cols if c in schema_names), None)
if col is None:
return
for batch in pf.iter_batches(batch_size=batch_size, columns=[col]):
texts = batch.column(col).to_pylist()
texts = [t for t in texts if t]
if texts:
yield texts
def pack_rows(token_lists: list[list[int]], row_capacity: int) -> np.ndarray:
"""Pack variable-length token sequences into (N, row_capacity) rows using simple greedy concat."""
rows = []
current = []
for doc in token_lists:
if len(current) + len(doc) > row_capacity:
# Flush current row (pad with 0)
if len(current) >= row_capacity // 2: # skip too-short trailing bits
row = current[:row_capacity]
if len(row) < row_capacity:
row = row + [0] * (row_capacity - len(row))
rows.append(row)
# Start new row with this doc (truncate if too long)
current = doc[:row_capacity]
else:
current.extend(doc)
# Emit full rows as we fill up
while len(current) >= row_capacity:
rows.append(current[:row_capacity])
current = current[row_capacity:]
if not rows:
return np.empty((0, row_capacity), dtype=np.int32)
return np.asarray(rows, dtype=np.int32)
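# Illustrative trace with row_capacity=8: pack_rows([[1,2,3], [4,5,6,7], [8,9,10,11,12]], 8)
# packs the first two docs into one row [1,2,3,4,5,6,7,0] (padded with a single 0);
# the last doc starts a fresh row but is dropped as an unfinished trailer.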
def main() -> None:
ap = argparse.ArgumentParser()
ap.add_argument("--gb", type=float, default=2.0)
ap.add_argument("--seq-len", type=int, default=512)
ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 4) - 2))
ap.add_argument("--batch-size", type=int, default=512, help="docs per tokenizer call")
args = ap.parse_args()
T = args.seq_len
row_capacity = T + 1
target_bytes = int(args.gb * 1024**3)
target_rows = target_bytes // (row_capacity * 4)
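    # e.g. with the defaults (--gb 2, T=512): one row is 513 int32s = 2052 bytes,
    # so 2 GiB / 2052 B ~= 1.05M rows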
# Load tokenizer in main process for vocab size
tok = Tokenizer.from_directory()
V = tok.get_vocab_size()
cache_path = os.path.expanduser(
f"~/.cache/autoresearch/packed_tokens_v1_T{T}_V{V}_train.bin"
)
tmp_path = cache_path + ".tmp"
print(f"[cache-build] target: {args.gb:.1f} GB = {target_rows} rows of (T+1)={row_capacity} int32", flush=True)
print(f"[cache-build] workers: {args.workers}", flush=True)
parquet_files = find_parquet_files()
print(f"[cache-build] found {len(parquet_files)} parquet shards", flush=True)
for name, path in parquet_files:
sz = os.path.getsize(path) / 1024**2
print(f" [{name}] {path.split('/blobs/')[-1]} ({sz:.0f} MB)", flush=True)
if not parquet_files:
print("[cache-build] no shards found — run predownload first", flush=True)
sys.exit(1)
t_start = time.time()
rows_written = 0
# Single-batch tokenize function using the pool
pool = Pool(processes=args.workers, initializer=_worker_init)
pending_batches = [] # batches of texts waiting to be tokenized
PENDING_LIMIT = args.workers * 4
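    # cap the backlog at a few batches per worker: enough to keep every process busy
    # without buffering unbounded text in the parent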
def flush_to_tokenize():
"""Submit pending batches to pool, write results as they come."""
nonlocal rows_written
if not pending_batches:
return
batch_args = [(b, 0) for b in pending_batches]
# Use imap_unordered for streaming results
for token_lists in pool.imap_unordered(_tokenize_batch, batch_args, chunksize=1):
rows = pack_rows(token_lists, row_capacity)
if len(rows) > 0:
fout.write(rows.tobytes())
rows_written += len(rows)
if rows_written >= target_rows:
return
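            # progress line roughly every 8192 rows: the modulo dips below
            # len(rows) exactly when a multiple of 8192 was just crossed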
if rows_written % 8192 < len(rows):
elapsed = time.time() - t_start
bw = rows_written * row_capacity * 4 / 1024**3
mbps = bw * 1024 / max(elapsed, 0.001)
pct = 100 * rows_written / target_rows
print(f" {rows_written:>8} rows {bw:.2f} GB {pct:5.1f}% {mbps:.1f} MB/s t={elapsed:.0f}s", flush=True)
pending_batches.clear()
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(tmp_path, "wb") as fout:
try:
done = False
# Round-robin across datasets to get diverse blend
iterators = []
for name, path in parquet_files:
iterators.append((name, iter_text_from_parquet(name, path, args.batch_size)))
while iterators and not done:
for i in range(len(iterators) - 1, -1, -1):
name, it = iterators[i]
try:
texts = next(it)
except StopIteration:
iterators.pop(i)
continue
pending_batches.append(texts)
if len(pending_batches) >= PENDING_LIMIT:
flush_to_tokenize()
if rows_written >= target_rows:
done = True
break
# Final flush
if not done and pending_batches:
flush_to_tokenize()
    finally:
        # terminate rather than close: any tokenize tasks still in flight after an
        # early stop are abandoned instead of waited on
        pool.terminate()
        pool.join()
os.replace(tmp_path, cache_path)
elapsed = time.time() - t_start
total_bytes = rows_written * row_capacity * 4
print(f"\n[cache-build] DONE — {rows_written} rows, {total_bytes/1024**3:.2f} GB in {elapsed:.0f}s ({total_bytes/1024**2/elapsed:.1f} MB/s)", flush=True)
print(f"[cache-build] cache: {cache_path}", flush=True)
if __name__ == "__main__":
main()
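# Reading the cache back (a minimal sketch; the training-side loader lives elsewhere):
# the file is a flat stream of int32 rows of length T+1, so it can be memory-mapped
# and reshaped, e.g.
#   rows = np.memmap(cache_path, dtype=np.int32, mode="r").reshape(-1, T + 1)
#   x, y = rows[:, :-1], rows[:, 1:]   # inputs and next-token targets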