File size: 1,774 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/usr/bin/env python3
"""scripts/prepare_tinystories.py — pack TinyStories text into uint8 .bin shards.

Reads ``data/tinystories/TinyStories-train.txt`` and ``TinyStories-valid.txt``,
encodes them with the byte tokenizer (no BPE), and writes flat uint8 arrays
to ``train.bin`` / ``valid.bin`` next to the input. Reports token counts.

The trainer memmaps these files, so for a ~2 GB train shard we never load
the whole thing into RAM.
"""
from __future__ import annotations

import argparse
import time
from pathlib import Path

import numpy as np


def pack_text_file(in_path: Path, out_path: Path, chunk_bytes: int = 64 * 1024 * 1024) -> int:
    """Copy ``in_path`` to ``out_path`` as a flat uint8 array, chunk by chunk.

    The input is read in binary mode and every byte is written out as one
    uint8 element, so the output holds exactly as many elements as the input
    has bytes. At most ``chunk_bytes`` bytes are requested per read, and a
    progress line is printed after each chunk.

    Args:
        in_path: File to read.
        out_path: File to write; it is created or truncated.
        chunk_bytes: Maximum number of bytes requested per read. Must be
            positive.

    Returns:
        Total number of bytes (uint8 elements) written.

    Raises:
        ValueError: If ``chunk_bytes`` is not positive, or if ``in_path`` and
            ``out_path`` resolve to the same path.
    """
    # read(0) returns b"" immediately, which would silently produce an empty
    # output; a negative size makes read() return the whole file in one call.
    if chunk_bytes <= 0:
        raise ValueError(f"chunk_bytes must be positive, got {chunk_bytes}")
    # Opening the output with "wb" truncates it, so packing a file onto itself
    # would wipe the input before a single byte is read.
    if in_path.resolve() == out_path.resolve():
        raise ValueError(f"input and output are the same file: {in_path}")

    n = 0
    # Monotonic clock: the elapsed time shown is unaffected by system clock
    # adjustments during a long run.
    t0 = time.monotonic()
    with in_path.open("rb") as fin, out_path.open("wb") as fout:
        while True:
            chunk = fin.read(chunk_bytes)
            if not chunk:
                break
            # View the raw bytes as uint8 without copying, then append them.
            arr = np.frombuffer(chunk, dtype=np.uint8)
            arr.tofile(fout)
            n += arr.size
            mb = n / (1024 * 1024)
            elapsed = time.monotonic() - t0
            print(f"  {mb:>8.1f} MiB packed  ({elapsed:.1f}s)")
    return n


def main() -> None:
    """Pack the TinyStories train and valid text files found in ``--data-dir``.

    For each split the ``.txt`` file is packed into a ``.bin`` file in the
    same directory, train first and then valid. The run exits with a
    ``missing input`` message as soon as a split's text file is not found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-dir", type=Path, default=Path("data/tinystories"))
    data_dir = parser.parse_args().data_dir

    # Maps each input text file to its output shard; insertion order is the
    # processing order.
    shard_names = {
        "TinyStories-train.txt": "train.bin",
        "TinyStories-valid.txt": "valid.bin",
    }
    for text_name, bin_name in shard_names.items():
        source = data_dir / text_name
        target = data_dir / bin_name
        if source.exists():
            print(f"packing {source} -> {target}")
            total = pack_text_file(source, target)
            print(f"  done. {total:,} bytes / tokens")
        else:
            raise SystemExit(f"missing input: {source}")


# Run the packing only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()