Add Lichess PGN -> PAWN Parquet extraction pipeline (#4)
Browse files- .dockerignore +2 -0
- Dockerfile +7 -0
- deploy/entrypoint-lichess-parquet.sh +62 -0
- engine/python/chess_engine/__init__.py +6 -0
- engine/src/lib.rs +264 -0
- engine/src/pgn.rs +628 -0
- scripts/extract_lichess_parquet.py +603 -0
- tests/test_enriched_pgn.py +393 -0
.dockerignore
CHANGED
|
@@ -14,6 +14,8 @@ deploy/
|
|
| 14 |
!deploy/entrypoint-extract.sh
|
| 15 |
!deploy/entrypoint-lc0.sh
|
| 16 |
!deploy/entrypoint-rosa-sweep.sh
|
|
|
|
|
|
|
| 17 |
*.so
|
| 18 |
CLAUDE.md
|
| 19 |
docs/
|
|
|
|
| 14 |
!deploy/entrypoint-extract.sh
|
| 15 |
!deploy/entrypoint-lc0.sh
|
| 16 |
!deploy/entrypoint-rosa-sweep.sh
|
| 17 |
+
!deploy/entrypoint-lc0-selfplay.sh
|
| 18 |
+
!deploy/entrypoint-lichess-parquet.sh
|
| 19 |
*.so
|
| 20 |
CLAUDE.md
|
| 21 |
docs/
|
Dockerfile
CHANGED
|
@@ -92,6 +92,13 @@ COPY deploy/entrypoint-rosa-sweep.sh /entrypoint-rosa-sweep.sh
|
|
| 92 |
RUN chmod +x /entrypoint-rosa-sweep.sh
|
| 93 |
ENTRYPOINT ["/entrypoint-rosa-sweep.sh"]
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
# ── Interactive (default) — SSH + Jupyter, stays alive ───────────────
|
| 96 |
FROM runtime-base AS interactive
|
| 97 |
# Inherits /start.sh entrypoint from Runpod base image
|
|
|
|
| 92 |
RUN chmod +x /entrypoint-rosa-sweep.sh
|
| 93 |
ENTRYPOINT ["/entrypoint-rosa-sweep.sh"]
|
| 94 |
|
| 95 |
+
# ── Lichess extract — downloads PGN, writes Parquet, pushes to HF ───
|
| 96 |
+
FROM runtime-base AS lichess-extract
|
| 97 |
+
RUN pip install --no-cache-dir zstandard
|
| 98 |
+
COPY deploy/entrypoint-lichess-parquet.sh /entrypoint-lichess-parquet.sh
|
| 99 |
+
RUN chmod +x /entrypoint-lichess-parquet.sh
|
| 100 |
+
ENTRYPOINT ["/entrypoint-lichess-parquet.sh"]
|
| 101 |
+
|
| 102 |
# ── Interactive (default) — SSH + Jupyter, stays alive ───────────────
|
| 103 |
FROM runtime-base AS interactive
|
| 104 |
# Inherits /start.sh entrypoint from Runpod base image
|
deploy/entrypoint-lichess-parquet.sh
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Lichess PGN -> PAWN Parquet extraction entrypoint.
# Downloads monthly database dumps, parses via Rust engine, writes sharded
# Parquet with train/val/test splits, and optionally pushes to HuggingFace.
#
# Required env vars:
#   MONTHS        — space-separated training months (e.g., "2025-01 2025-02 2025-03")
#
# Optional env vars:
#   HF_TOKEN      — HuggingFace token (for pushing dataset)
#   HF_REPO       — HuggingFace dataset repo (e.g., "thomas-schweich/pawn-lichess-full")
#   HOLDOUT_MONTH — month for val/test (e.g., "2023-12")
#   HOLDOUT_GAMES — games per split from holdout month (default: 50000)
#   BATCH_SIZE    — games per parsing batch (default: 500000)
#   SHARD_SIZE    — games per output shard (default: 1000000)
#   MAX_GAMES     — stop after this many training games (for testing)
#   OUTPUT_DIR    — output directory (default: /workspace/lichess-parquet)
#   SEED          — random seed for holdout sampling (default: 42)
set -euo pipefail

echo "=== Lichess Parquet Extraction ==="
echo "  Training months:     ${MONTHS:?MONTHS env var is required}"
echo "  Holdout month:       ${HOLDOUT_MONTH:-none}"
echo "  Holdout games/split: ${HOLDOUT_GAMES:-50000}"
echo "  HF Repo:             ${HF_REPO:-none}"
echo "  Batch size:          ${BATCH_SIZE:-500000}"
echo "  Shard size:          ${SHARD_SIZE:-1000000}"
echo ""

# Persist HF token if provided (printf avoids echo's flag/backslash pitfalls).
if [ -n "${HF_TOKEN:-}" ]; then
  mkdir -p ~/.cache/huggingface
  printf '%s' "$HF_TOKEN" > ~/.cache/huggingface/token
  echo "HF token persisted"
fi

# Install zstandard if not available (needed for streaming decompression)
python3 -c "import zstandard" 2>/dev/null || pip install --no-cache-dir zstandard

# MONTHS is deliberately space-separated; split it into an array explicitly
# rather than relying on an unquoted $MONTHS expansion, which would also be
# subject to pathname globbing.
read -ra month_args <<< "$MONTHS"

# Build the command as an array to avoid shell injection
CMD=(python3 /opt/pawn/scripts/extract_lichess_parquet.py
  --months "${month_args[@]}"
  --output "${OUTPUT_DIR:-/workspace/lichess-parquet}"
  --batch-size "${BATCH_SIZE:-500000}"
  --shard-size "${SHARD_SIZE:-1000000}"
  --seed "${SEED:-42}"
)

if [ -n "${HOLDOUT_MONTH:-}" ]; then
  CMD+=(--holdout-month "$HOLDOUT_MONTH")
  CMD+=(--holdout-games "${HOLDOUT_GAMES:-50000}")
fi
if [ -n "${HF_REPO:-}" ]; then
  CMD+=(--hf-repo "$HF_REPO")
fi
if [ -n "${MAX_GAMES:-}" ]; then
  CMD+=(--max-games "$MAX_GAMES")
fi

echo "Running: ${CMD[*]}"
echo ""
exec "${CMD[@]}"
|
engine/python/chess_engine/__init__.py
CHANGED
|
@@ -26,6 +26,9 @@ from chess_engine._engine import (
|
|
| 26 |
# PGN parsing
|
| 27 |
parse_pgn_file,
|
| 28 |
pgn_to_tokens,
|
|
|
|
|
|
|
|
|
|
| 29 |
# UCI parsing
|
| 30 |
parse_uci_file,
|
| 31 |
uci_to_tokens,
|
|
@@ -60,6 +63,9 @@ __all__ = [
|
|
| 60 |
"validate_games",
|
| 61 |
"parse_pgn_file",
|
| 62 |
"pgn_to_tokens",
|
|
|
|
|
|
|
|
|
|
| 63 |
"parse_uci_file",
|
| 64 |
"uci_to_tokens",
|
| 65 |
"pgn_to_uci",
|
|
|
|
| 26 |
# PGN parsing
|
| 27 |
parse_pgn_file,
|
| 28 |
pgn_to_tokens,
|
| 29 |
+
parse_pgn_enriched,
|
| 30 |
+
count_pgn_games_in_date_range,
|
| 31 |
+
parse_pgn_sampled,
|
| 32 |
# UCI parsing
|
| 33 |
parse_uci_file,
|
| 34 |
uci_to_tokens,
|
|
|
|
| 63 |
"validate_games",
|
| 64 |
"parse_pgn_file",
|
| 65 |
"pgn_to_tokens",
|
| 66 |
+
"parse_pgn_enriched",
|
| 67 |
+
"count_pgn_games_in_date_range",
|
| 68 |
+
"parse_pgn_sampled",
|
| 69 |
"parse_uci_file",
|
| 70 |
"uci_to_tokens",
|
| 71 |
"pgn_to_uci",
|
engine/src/lib.rs
CHANGED
|
@@ -809,6 +809,267 @@ fn pgn_to_uci(py: Python<'_>, games: Vec<Vec<String>>) -> PyResult<Vec<Vec<Strin
|
|
| 809 |
Ok(results.into_iter().map(|(uci, _)| uci).collect())
|
| 810 |
}
|
| 811 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
// ---------------------------------------------------------------------------
|
| 813 |
// UCI engine self-play generation
|
| 814 |
// ---------------------------------------------------------------------------
|
|
@@ -1172,6 +1433,9 @@ fn _engine(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
| 1172 |
m.add_function(wrap_pyfunction!(parse_uci_file, m)?)?;
|
| 1173 |
m.add_function(wrap_pyfunction!(uci_to_tokens, m)?)?;
|
| 1174 |
m.add_function(wrap_pyfunction!(pgn_to_uci, m)?)?;
|
|
|
|
|
|
|
|
|
|
| 1175 |
m.add_function(wrap_pyfunction!(generate_engine_games_py, m)?)?;
|
| 1176 |
m.add_function(wrap_pyfunction!(compute_accuracy_ceiling_py, m)?)?;
|
| 1177 |
Ok(())
|
|
|
|
| 809 |
Ok(results.into_iter().map(|(uci, _)| uci).collect())
|
| 810 |
}
|
| 811 |
|
| 812 |
+
// ---------------------------------------------------------------------------
|
| 813 |
+
// Enriched PGN parsing (for dataset construction)
|
| 814 |
+
// ---------------------------------------------------------------------------
|
| 815 |
+
|
| 816 |
+
/// Parse PGN text with full annotation extraction for dataset building.
|
| 817 |
+
///
|
| 818 |
+
/// Extracts move tokens, clock annotations, eval annotations, and all PGN
|
| 819 |
+
/// headers in a single pass. Designed for streaming: Python passes chunks of
|
| 820 |
+
/// PGN text (containing complete games), Rust returns structured columns.
|
| 821 |
+
///
|
| 822 |
+
/// Returns a dict with:
|
| 823 |
+
/// tokens: ndarray[i16, (N, max_ply)] — PAWN token IDs, 0-padded
|
| 824 |
+
/// clocks: ndarray[u16, (N, max_ply)] — seconds remaining, 0-padded
|
| 825 |
+
/// evals: ndarray[i16, (N, max_ply)] — centipawns, 0-padded (i16::MIN = no annotation)
|
| 826 |
+
/// game_lengths: ndarray[u16, (N,)] — number of plies per game
|
| 827 |
+
/// white_elo: ndarray[u16, (N,)] — white Elo (0 if missing)
|
| 828 |
+
/// black_elo: ndarray[u16, (N,)] — black Elo (0 if missing)
|
| 829 |
+
/// white_rating_diff: ndarray[i16, (N,)] — white rating change (0 if missing)
|
| 830 |
+
/// black_rating_diff: ndarray[i16, (N,)] — black rating change (0 if missing)
|
| 831 |
+
/// result: list[str] — "1-0", "0-1", "1/2-1/2", or ""
|
| 832 |
+
/// white: list[str] — white player name
|
| 833 |
+
/// black: list[str] — black player name
|
| 834 |
+
/// eco: list[str] — ECO code
|
| 835 |
+
/// opening: list[str] — opening name
|
| 836 |
+
/// time_control: list[str] — time control string
|
| 837 |
+
/// termination: list[str] — termination reason
|
| 838 |
+
/// date_time: list[str] — "YYYY.MM.DD HH:MM:SS" UTC
|
| 839 |
+
/// site: list[str] — game URL
|
| 840 |
+
#[pyfunction]
|
| 841 |
+
#[pyo3(signature = (content, max_ply=255, max_games=1_000_000, min_ply=1))]
|
| 842 |
+
fn parse_pgn_enriched<'py>(
|
| 843 |
+
py: Python<'py>,
|
| 844 |
+
content: &str,
|
| 845 |
+
max_ply: usize,
|
| 846 |
+
max_games: usize,
|
| 847 |
+
min_ply: usize,
|
| 848 |
+
) -> PyResult<PyObject> {
|
| 849 |
+
let games = py.allow_threads(|| {
|
| 850 |
+
pgn::parse_pgn_enriched(content, max_ply, max_games, min_ply)
|
| 851 |
+
});
|
| 852 |
+
|
| 853 |
+
let n = games.len();
|
| 854 |
+
let dict = PyDict::new(py);
|
| 855 |
+
|
| 856 |
+
// Flat 0-padded arrays for tokens, clocks, evals (N * max_ply)
|
| 857 |
+
let mut flat_tokens = vec![0i16; n * max_ply];
|
| 858 |
+
let mut flat_clocks = vec![0u16; n * max_ply];
|
| 859 |
+
let mut flat_evals = vec![0i16; n * max_ply];
|
| 860 |
+
|
| 861 |
+
// Scalar arrays
|
| 862 |
+
let mut lengths_out = Vec::with_capacity(n);
|
| 863 |
+
let mut white_elo_out = Vec::with_capacity(n);
|
| 864 |
+
let mut black_elo_out = Vec::with_capacity(n);
|
| 865 |
+
let mut white_rd_out = Vec::with_capacity(n);
|
| 866 |
+
let mut black_rd_out = Vec::with_capacity(n);
|
| 867 |
+
|
| 868 |
+
// String lists
|
| 869 |
+
let mut result_out = Vec::with_capacity(n);
|
| 870 |
+
let mut white_out = Vec::with_capacity(n);
|
| 871 |
+
let mut black_out = Vec::with_capacity(n);
|
| 872 |
+
let mut eco_out = Vec::with_capacity(n);
|
| 873 |
+
let mut opening_out = Vec::with_capacity(n);
|
| 874 |
+
let mut tc_out = Vec::with_capacity(n);
|
| 875 |
+
let mut term_out = Vec::with_capacity(n);
|
| 876 |
+
let mut datetime_out = Vec::with_capacity(n);
|
| 877 |
+
let mut site_out = Vec::with_capacity(n);
|
| 878 |
+
|
| 879 |
+
for (gi, g) in games.iter().enumerate() {
|
| 880 |
+
let offset = gi * max_ply;
|
| 881 |
+
let len = g.game_length.min(max_ply);
|
| 882 |
+
for t in 0..len {
|
| 883 |
+
flat_tokens[offset + t] = g.tokens[t] as i16;
|
| 884 |
+
flat_clocks[offset + t] = g.clocks[t];
|
| 885 |
+
flat_evals[offset + t] = g.evals[t];
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
lengths_out.push(g.game_length as u16);
|
| 889 |
+
|
| 890 |
+
let h = &g.headers;
|
| 891 |
+
white_elo_out.push(h.get("WhiteElo").and_then(|s| s.parse::<u16>().ok()).unwrap_or(0));
|
| 892 |
+
black_elo_out.push(h.get("BlackElo").and_then(|s| s.parse::<u16>().ok()).unwrap_or(0));
|
| 893 |
+
white_rd_out.push(h.get("WhiteRatingDiff").and_then(|s| s.parse::<i16>().ok()).unwrap_or(0));
|
| 894 |
+
black_rd_out.push(h.get("BlackRatingDiff").and_then(|s| s.parse::<i16>().ok()).unwrap_or(0));
|
| 895 |
+
|
| 896 |
+
result_out.push(h.get("Result").cloned().unwrap_or_default());
|
| 897 |
+
white_out.push(h.get("White").cloned().unwrap_or_default());
|
| 898 |
+
black_out.push(h.get("Black").cloned().unwrap_or_default());
|
| 899 |
+
eco_out.push(h.get("ECO").cloned().unwrap_or_default());
|
| 900 |
+
opening_out.push(h.get("Opening").cloned().unwrap_or_default());
|
| 901 |
+
tc_out.push(h.get("TimeControl").cloned().unwrap_or_default());
|
| 902 |
+
term_out.push(h.get("Termination").cloned().unwrap_or_default());
|
| 903 |
+
site_out.push(h.get("Site").cloned().unwrap_or_default());
|
| 904 |
+
|
| 905 |
+
let date = h.get("UTCDate").cloned().unwrap_or_default();
|
| 906 |
+
let time = h.get("UTCTime").cloned().unwrap_or_default();
|
| 907 |
+
if !date.is_empty() && !time.is_empty() {
|
| 908 |
+
datetime_out.push(format!("{} {}", date, time));
|
| 909 |
+
} else {
|
| 910 |
+
datetime_out.push(date);
|
| 911 |
+
}
|
| 912 |
+
}
|
| 913 |
+
|
| 914 |
+
// 2D numpy arrays: (N, max_ply)
|
| 915 |
+
let tokens_arr = numpy::PyArray::from_vec(py, flat_tokens).reshape([n, max_ply])?;
|
| 916 |
+
let clocks_arr = numpy::PyArray::from_vec(py, flat_clocks).reshape([n, max_ply])?;
|
| 917 |
+
let evals_arr = numpy::PyArray::from_vec(py, flat_evals).reshape([n, max_ply])?;
|
| 918 |
+
|
| 919 |
+
// 1D numpy arrays
|
| 920 |
+
let lengths_arr = numpy::PyArray::from_vec(py, lengths_out);
|
| 921 |
+
let white_elo_arr = numpy::PyArray::from_vec(py, white_elo_out);
|
| 922 |
+
let black_elo_arr = numpy::PyArray::from_vec(py, black_elo_out);
|
| 923 |
+
let white_rd_arr = numpy::PyArray::from_vec(py, white_rd_out);
|
| 924 |
+
let black_rd_arr = numpy::PyArray::from_vec(py, black_rd_out);
|
| 925 |
+
|
| 926 |
+
dict.set_item("tokens", tokens_arr)?;
|
| 927 |
+
dict.set_item("clocks", clocks_arr)?;
|
| 928 |
+
dict.set_item("evals", evals_arr)?;
|
| 929 |
+
dict.set_item("game_lengths", lengths_arr)?;
|
| 930 |
+
dict.set_item("white_elo", white_elo_arr)?;
|
| 931 |
+
dict.set_item("black_elo", black_elo_arr)?;
|
| 932 |
+
dict.set_item("white_rating_diff", white_rd_arr)?;
|
| 933 |
+
dict.set_item("black_rating_diff", black_rd_arr)?;
|
| 934 |
+
dict.set_item("result", result_out)?;
|
| 935 |
+
dict.set_item("white", white_out)?;
|
| 936 |
+
dict.set_item("black", black_out)?;
|
| 937 |
+
dict.set_item("eco", eco_out)?;
|
| 938 |
+
dict.set_item("opening", opening_out)?;
|
| 939 |
+
dict.set_item("time_control", tc_out)?;
|
| 940 |
+
dict.set_item("termination", term_out)?;
|
| 941 |
+
dict.set_item("date_time", datetime_out)?;
|
| 942 |
+
dict.set_item("site", site_out)?;
|
| 943 |
+
|
| 944 |
+
Ok(dict.into())
|
| 945 |
+
}
|
| 946 |
+
|
| 947 |
+
/// Count games in a PGN string whose UTCDate falls within [date_start, date_end].
|
| 948 |
+
/// Header-only scan — no tokenization. Very fast.
|
| 949 |
+
#[pyfunction]
|
| 950 |
+
fn count_pgn_games_in_date_range(
|
| 951 |
+
py: Python<'_>,
|
| 952 |
+
content: &str,
|
| 953 |
+
date_start: &str,
|
| 954 |
+
date_end: &str,
|
| 955 |
+
) -> PyResult<usize> {
|
| 956 |
+
let count = py.allow_threads(|| {
|
| 957 |
+
pgn::count_games_in_date_range(content, date_start, date_end)
|
| 958 |
+
});
|
| 959 |
+
Ok(count)
|
| 960 |
+
}
|
| 961 |
+
|
| 962 |
+
/// Parse only specific games (by index within a date range) from a PGN string.
|
| 963 |
+
///
|
| 964 |
+
/// Used for uniform random sampling: call count_pgn_games_in_date_range first
|
| 965 |
+
/// to get the total, generate random indices in Python, then call this to
|
| 966 |
+
/// parse only those games. `game_offset` is the cumulative count of
|
| 967 |
+
/// date-matching games from previous chunks.
|
| 968 |
+
///
|
| 969 |
+
/// Returns the same dict format as parse_pgn_enriched.
|
| 970 |
+
#[pyfunction]
|
| 971 |
+
#[pyo3(signature = (content, indices, date_start, date_end, game_offset=0, max_ply=255, min_ply=1))]
|
| 972 |
+
fn parse_pgn_sampled<'py>(
|
| 973 |
+
py: Python<'py>,
|
| 974 |
+
content: &str,
|
| 975 |
+
indices: Vec<usize>,
|
| 976 |
+
date_start: &str,
|
| 977 |
+
date_end: &str,
|
| 978 |
+
game_offset: usize,
|
| 979 |
+
max_ply: usize,
|
| 980 |
+
min_ply: usize,
|
| 981 |
+
) -> PyResult<PyObject> {
|
| 982 |
+
let index_set: std::collections::HashSet<usize> = indices.into_iter().collect();
|
| 983 |
+
|
| 984 |
+
let games = py.allow_threads(|| {
|
| 985 |
+
pgn::parse_pgn_enriched_sampled(
|
| 986 |
+
content, max_ply, min_ply, date_start, date_end, &index_set, game_offset,
|
| 987 |
+
)
|
| 988 |
+
});
|
| 989 |
+
|
| 990 |
+
// Reuse the same dict-building logic as parse_pgn_enriched
|
| 991 |
+
let n = games.len();
|
| 992 |
+
let dict = PyDict::new(py);
|
| 993 |
+
|
| 994 |
+
let mut flat_tokens = vec![0i16; n * max_ply];
|
| 995 |
+
let mut flat_clocks = vec![0u16; n * max_ply];
|
| 996 |
+
let mut flat_evals = vec![0i16; n * max_ply];
|
| 997 |
+
let mut lengths_out = Vec::with_capacity(n);
|
| 998 |
+
let mut white_elo_out = Vec::with_capacity(n);
|
| 999 |
+
let mut black_elo_out = Vec::with_capacity(n);
|
| 1000 |
+
let mut white_rd_out = Vec::with_capacity(n);
|
| 1001 |
+
let mut black_rd_out = Vec::with_capacity(n);
|
| 1002 |
+
let mut result_out = Vec::with_capacity(n);
|
| 1003 |
+
let mut white_out = Vec::with_capacity(n);
|
| 1004 |
+
let mut black_out = Vec::with_capacity(n);
|
| 1005 |
+
let mut eco_out = Vec::with_capacity(n);
|
| 1006 |
+
let mut opening_out = Vec::with_capacity(n);
|
| 1007 |
+
let mut tc_out = Vec::with_capacity(n);
|
| 1008 |
+
let mut term_out = Vec::with_capacity(n);
|
| 1009 |
+
let mut datetime_out = Vec::with_capacity(n);
|
| 1010 |
+
let mut site_out = Vec::with_capacity(n);
|
| 1011 |
+
|
| 1012 |
+
for (gi, g) in games.iter().enumerate() {
|
| 1013 |
+
let offset = gi * max_ply;
|
| 1014 |
+
let len = g.game_length.min(max_ply);
|
| 1015 |
+
for t in 0..len {
|
| 1016 |
+
flat_tokens[offset + t] = g.tokens[t] as i16;
|
| 1017 |
+
flat_clocks[offset + t] = g.clocks[t];
|
| 1018 |
+
flat_evals[offset + t] = g.evals[t];
|
| 1019 |
+
}
|
| 1020 |
+
lengths_out.push(g.game_length as u16);
|
| 1021 |
+
let h = &g.headers;
|
| 1022 |
+
white_elo_out.push(h.get("WhiteElo").and_then(|s| s.parse::<u16>().ok()).unwrap_or(0));
|
| 1023 |
+
black_elo_out.push(h.get("BlackElo").and_then(|s| s.parse::<u16>().ok()).unwrap_or(0));
|
| 1024 |
+
white_rd_out.push(h.get("WhiteRatingDiff").and_then(|s| s.parse::<i16>().ok()).unwrap_or(0));
|
| 1025 |
+
black_rd_out.push(h.get("BlackRatingDiff").and_then(|s| s.parse::<i16>().ok()).unwrap_or(0));
|
| 1026 |
+
result_out.push(h.get("Result").cloned().unwrap_or_default());
|
| 1027 |
+
white_out.push(h.get("White").cloned().unwrap_or_default());
|
| 1028 |
+
black_out.push(h.get("Black").cloned().unwrap_or_default());
|
| 1029 |
+
eco_out.push(h.get("ECO").cloned().unwrap_or_default());
|
| 1030 |
+
opening_out.push(h.get("Opening").cloned().unwrap_or_default());
|
| 1031 |
+
tc_out.push(h.get("TimeControl").cloned().unwrap_or_default());
|
| 1032 |
+
term_out.push(h.get("Termination").cloned().unwrap_or_default());
|
| 1033 |
+
site_out.push(h.get("Site").cloned().unwrap_or_default());
|
| 1034 |
+
let date = h.get("UTCDate").cloned().unwrap_or_default();
|
| 1035 |
+
let time = h.get("UTCTime").cloned().unwrap_or_default();
|
| 1036 |
+
if !date.is_empty() && !time.is_empty() {
|
| 1037 |
+
datetime_out.push(format!("{} {}", date, time));
|
| 1038 |
+
} else {
|
| 1039 |
+
datetime_out.push(date);
|
| 1040 |
+
}
|
| 1041 |
+
}
|
| 1042 |
+
|
| 1043 |
+
let tokens_arr = numpy::PyArray::from_vec(py, flat_tokens).reshape([n, max_ply])?;
|
| 1044 |
+
let clocks_arr = numpy::PyArray::from_vec(py, flat_clocks).reshape([n, max_ply])?;
|
| 1045 |
+
let evals_arr = numpy::PyArray::from_vec(py, flat_evals).reshape([n, max_ply])?;
|
| 1046 |
+
let lengths_arr = numpy::PyArray::from_vec(py, lengths_out);
|
| 1047 |
+
let white_elo_arr = numpy::PyArray::from_vec(py, white_elo_out);
|
| 1048 |
+
let black_elo_arr = numpy::PyArray::from_vec(py, black_elo_out);
|
| 1049 |
+
let white_rd_arr = numpy::PyArray::from_vec(py, white_rd_out);
|
| 1050 |
+
let black_rd_arr = numpy::PyArray::from_vec(py, black_rd_out);
|
| 1051 |
+
|
| 1052 |
+
dict.set_item("tokens", tokens_arr)?;
|
| 1053 |
+
dict.set_item("clocks", clocks_arr)?;
|
| 1054 |
+
dict.set_item("evals", evals_arr)?;
|
| 1055 |
+
dict.set_item("game_lengths", lengths_arr)?;
|
| 1056 |
+
dict.set_item("white_elo", white_elo_arr)?;
|
| 1057 |
+
dict.set_item("black_elo", black_elo_arr)?;
|
| 1058 |
+
dict.set_item("white_rating_diff", white_rd_arr)?;
|
| 1059 |
+
dict.set_item("black_rating_diff", black_rd_arr)?;
|
| 1060 |
+
dict.set_item("result", result_out)?;
|
| 1061 |
+
dict.set_item("white", white_out)?;
|
| 1062 |
+
dict.set_item("black", black_out)?;
|
| 1063 |
+
dict.set_item("eco", eco_out)?;
|
| 1064 |
+
dict.set_item("opening", opening_out)?;
|
| 1065 |
+
dict.set_item("time_control", tc_out)?;
|
| 1066 |
+
dict.set_item("termination", term_out)?;
|
| 1067 |
+
dict.set_item("date_time", datetime_out)?;
|
| 1068 |
+
dict.set_item("site", site_out)?;
|
| 1069 |
+
|
| 1070 |
+
Ok(dict.into())
|
| 1071 |
+
}
|
| 1072 |
+
|
| 1073 |
// ---------------------------------------------------------------------------
|
| 1074 |
// UCI engine self-play generation
|
| 1075 |
// ---------------------------------------------------------------------------
|
|
|
|
| 1433 |
m.add_function(wrap_pyfunction!(parse_uci_file, m)?)?;
|
| 1434 |
m.add_function(wrap_pyfunction!(uci_to_tokens, m)?)?;
|
| 1435 |
m.add_function(wrap_pyfunction!(pgn_to_uci, m)?)?;
|
| 1436 |
+
m.add_function(wrap_pyfunction!(parse_pgn_enriched, m)?)?;
|
| 1437 |
+
m.add_function(wrap_pyfunction!(count_pgn_games_in_date_range, m)?)?;
|
| 1438 |
+
m.add_function(wrap_pyfunction!(parse_pgn_sampled, m)?)?;
|
| 1439 |
m.add_function(wrap_pyfunction!(generate_engine_games_py, m)?)?;
|
| 1440 |
m.add_function(wrap_pyfunction!(compute_accuracy_ceiling_py, m)?)?;
|
| 1441 |
Ok(())
|
engine/src/pgn.rs
CHANGED
|
@@ -3,7 +3,11 @@
|
|
| 3 |
//! Full pipeline in Rust: reads PGN files, extracts SAN move strings,
|
| 4 |
//! converts to PAWN tokens via shakmaty. Uses rayon for parallel
|
| 5 |
//! token conversion.
|
|
|
|
|
|
|
|
|
|
| 6 |
|
|
|
|
| 7 |
use std::fs;
|
| 8 |
use rayon::prelude::*;
|
| 9 |
use shakmaty::{Chess, Position};
|
|
@@ -11,6 +15,420 @@ use shakmaty::san::San;
|
|
| 11 |
|
| 12 |
use crate::board::move_to_token;
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
/// Convert a sequence of SAN move strings to PAWN token indices.
|
| 15 |
///
|
| 16 |
/// Returns (tokens, n_valid) where tokens has length up to max_ply,
|
|
@@ -277,4 +695,214 @@ mod tests {
|
|
| 277 |
|
| 278 |
fs::remove_file(path).ok();
|
| 279 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
}
|
|
|
|
| 3 |
//! Full pipeline in Rust: reads PGN files, extracts SAN move strings,
|
| 4 |
//! converts to PAWN tokens via shakmaty. Uses rayon for parallel
|
| 5 |
//! token conversion.
|
| 6 |
+
//!
|
| 7 |
+
//! Also provides enriched parsing that extracts clock annotations,
|
| 8 |
+
//! eval annotations, and PGN headers for dataset construction.
|
| 9 |
|
| 10 |
+
use std::collections::{HashMap, HashSet};
|
| 11 |
use std::fs;
|
| 12 |
use rayon::prelude::*;
|
| 13 |
use shakmaty::{Chess, Position};
|
|
|
|
| 15 |
|
| 16 |
use crate::board::move_to_token;
|
| 17 |
|
| 18 |
+
// ---------------------------------------------------------------------------
|
| 19 |
+
// Enriched PGN parsing — extracts moves, clocks, evals, and headers
|
| 20 |
+
// ---------------------------------------------------------------------------
|
| 21 |
+
|
| 22 |
+
/// A fully parsed game with move tokens, annotations, and metadata.
|
| 23 |
+
pub struct EnrichedGame {
|
| 24 |
+
/// PAWN token indices for each ply (not padded).
|
| 25 |
+
pub tokens: Vec<u16>,
|
| 26 |
+
/// Seconds remaining on clock after each ply (0 = no annotation).
|
| 27 |
+
pub clocks: Vec<u16>,
|
| 28 |
+
/// Centipawns from white's perspective after each ply.
|
| 29 |
+
/// Mate scores: ±(32767-N). No annotation: 0x8000 (-32768 as i16).
|
| 30 |
+
pub evals: Vec<i16>,
|
| 31 |
+
/// Number of valid plies.
|
| 32 |
+
pub game_length: usize,
|
| 33 |
+
/// PGN header fields (e.g., "White" -> "alice", "WhiteElo" -> "1873").
|
| 34 |
+
pub headers: HashMap<String, String>,
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
/// Parse a PGN string into enriched games.
|
| 38 |
+
///
|
| 39 |
+
/// Extracts SAN moves (tokenized), `[%clk h:mm:ss]` annotations,
|
| 40 |
+
/// `[%eval ±N.NN]` / `[%eval #±N]` annotations, and all PGN headers.
|
| 41 |
+
/// Tokenization uses shakmaty and is parallelized with rayon.
|
| 42 |
+
pub fn parse_pgn_enriched(
|
| 43 |
+
content: &str,
|
| 44 |
+
max_ply: usize,
|
| 45 |
+
max_games: usize,
|
| 46 |
+
min_ply: usize,
|
| 47 |
+
) -> Vec<EnrichedGame> {
|
| 48 |
+
let raw_games = parse_raw_games(content, max_games, None, None);
|
| 49 |
+
|
| 50 |
+
// Phase 2: parallel tokenization + annotation extraction
|
| 51 |
+
raw_games
|
| 52 |
+
.into_par_iter()
|
| 53 |
+
.filter_map(|raw| {
|
| 54 |
+
let (san_moves, clocks_raw, evals_raw) = extract_moves_and_annotations(&raw.movetext);
|
| 55 |
+
if san_moves.len() < min_ply {
|
| 56 |
+
return None;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
// Tokenize SAN moves via shakmaty
|
| 60 |
+
let refs: Vec<&str> = san_moves.iter().map(|s| s.as_str()).collect();
|
| 61 |
+
let (tokens, n_valid) = san_moves_to_tokens(&refs, max_ply);
|
| 62 |
+
if n_valid < min_ply {
|
| 63 |
+
return None;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// Trim annotations to match token count (moves may have failed to parse).
|
| 67 |
+
let clocks = clocks_raw.into_iter().take(n_valid).collect();
|
| 68 |
+
let evals = evals_raw.into_iter().take(n_valid).collect();
|
| 69 |
+
|
| 70 |
+
Some(EnrichedGame {
|
| 71 |
+
tokens,
|
| 72 |
+
clocks,
|
| 73 |
+
evals,
|
| 74 |
+
game_length: n_valid,
|
| 75 |
+
headers: raw.headers,
|
| 76 |
+
})
|
| 77 |
+
})
|
| 78 |
+
.collect()
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
/// Count games in a PGN string whose UTCDate falls within [start, end].
|
| 82 |
+
///
|
| 83 |
+
/// Header-only scan — no movetext parsing, no tokenization.
|
| 84 |
+
/// Returns (count_in_range, offset) where offset is the running game index
|
| 85 |
+
/// that should be passed to the next chunk for correct global indexing.
|
| 86 |
+
pub fn count_games_in_date_range(
|
| 87 |
+
content: &str,
|
| 88 |
+
date_start: &str,
|
| 89 |
+
date_end: &str,
|
| 90 |
+
) -> usize {
|
| 91 |
+
let mut count = 0;
|
| 92 |
+
let mut current_date: Option<String> = None;
|
| 93 |
+
let mut in_movetext = false;
|
| 94 |
+
|
| 95 |
+
for line in content.lines() {
|
| 96 |
+
let line = line.trim();
|
| 97 |
+
if line.is_empty() {
|
| 98 |
+
if in_movetext {
|
| 99 |
+
// End of game — check if the date was in range
|
| 100 |
+
if let Some(ref d) = current_date {
|
| 101 |
+
if d.as_str() >= date_start && d.as_str() <= date_end {
|
| 102 |
+
count += 1;
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
current_date = None;
|
| 106 |
+
in_movetext = false;
|
| 107 |
+
}
|
| 108 |
+
continue;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
if line.starts_with('[') && line.ends_with(']') {
|
| 112 |
+
if let Some((key, value)) = parse_header_line(line) {
|
| 113 |
+
if key == "UTCDate" {
|
| 114 |
+
current_date = Some(value);
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
in_movetext = false;
|
| 118 |
+
} else {
|
| 119 |
+
in_movetext = true;
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
// Handle last game
|
| 124 |
+
if in_movetext {
|
| 125 |
+
if let Some(ref d) = current_date {
|
| 126 |
+
if d.as_str() >= date_start && d.as_str() <= date_end {
|
| 127 |
+
count += 1;
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
count
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
/// Parse a PGN string, but only tokenize games at specific indices within a
/// date range. Used for uniform random sampling: Python counts games in the
/// date range (via `count_games_in_date_range`), generates a random index
/// set, then calls this to parse only those games.
///
/// `indices` are 0-based within the date-range-matching games of this chunk.
/// `game_offset` is the number of date-matching games seen in previous chunks,
/// so global index = game_offset + local_index.
///
/// Games with fewer than `min_ply` parsed moves (before or after SAN
/// tokenization) are dropped; token sequences are truncated at `max_ply`.
pub fn parse_pgn_enriched_sampled(
    content: &str,
    max_ply: usize,
    min_ply: usize,
    date_start: &str,
    date_end: &str,
    indices: &HashSet<usize>,
    game_offset: usize,
) -> Vec<EnrichedGame> {
    // Phase 1: single-threaded scan, filtered by date range + sampled indices.
    let raw_games = parse_raw_games(content, usize::MAX, Some((date_start, date_end)), Some((indices, game_offset)));

    // Phase 2: parallel tokenization + annotation extraction (rayon).
    raw_games
        .into_par_iter()
        .filter_map(|raw| {
            let (san_moves, clocks_raw, evals_raw) = extract_moves_and_annotations(&raw.movetext);
            // Cheap pre-filter before the (costly) shakmaty tokenization.
            if san_moves.len() < min_ply {
                return None;
            }

            let refs: Vec<&str> = san_moves.iter().map(|s| s.as_str()).collect();
            let (tokens, n_valid) = san_moves_to_tokens(&refs, max_ply);
            // Re-check after tokenization: some SAN moves may fail to parse.
            if n_valid < min_ply {
                return None;
            }

            // Trim annotations so they stay aligned with the token count.
            let clocks = clocks_raw.into_iter().take(n_valid).collect();
            let evals = evals_raw.into_iter().take(n_valid).collect();

            Some(EnrichedGame {
                tokens,
                clocks,
                evals,
                game_length: n_valid,
                headers: raw.headers,
            })
        })
        .collect()
}
|
| 181 |
+
|
| 182 |
+
/// Raw game data before tokenization.
struct RawGame {
    /// PGN tag pairs, e.g. "White", "WhiteElo", "UTCDate".
    headers: HashMap<String, String>,
    /// All non-header lines of the game joined with single spaces
    /// (moves, comments, and the result marker).
    movetext: String,
}
|
| 187 |
+
|
| 188 |
+
/// Single-threaded PGN line scanner. Extracts headers and raw movetext.
|
| 189 |
+
///
|
| 190 |
+
/// If `date_range` is Some((start, end)), only games whose UTCDate falls
|
| 191 |
+
/// within [start, end] are included. If `sample` is Some((indices, offset)),
|
| 192 |
+
/// only games whose (offset + local_index) is in the index set are kept.
|
| 193 |
+
fn parse_raw_games(
|
| 194 |
+
content: &str,
|
| 195 |
+
max_games: usize,
|
| 196 |
+
date_range: Option<(&str, &str)>,
|
| 197 |
+
sample: Option<(&HashSet<usize>, usize)>,
|
| 198 |
+
) -> Vec<RawGame> {
|
| 199 |
+
let mut games = Vec::new();
|
| 200 |
+
let mut headers: HashMap<String, String> = HashMap::new();
|
| 201 |
+
let mut movetext_lines: Vec<&str> = Vec::new();
|
| 202 |
+
let mut in_movetext = false;
|
| 203 |
+
let mut date_excluded = false; // UTCDate outside date_range
|
| 204 |
+
let mut has_utc_date = false; // saw a UTCDate header for this game
|
| 205 |
+
let mut date_matched_idx = 0usize; // count of date-matching games seen
|
| 206 |
+
|
| 207 |
+
for line in content.lines() {
|
| 208 |
+
let line = line.trim();
|
| 209 |
+
if line.is_empty() {
|
| 210 |
+
if in_movetext {
|
| 211 |
+
// End of movetext — game boundary.
|
| 212 |
+
// Exclude if date is out of range OR if date_range is active
|
| 213 |
+
// but no UTCDate header was found (consistent with count_games_in_date_range).
|
| 214 |
+
let excluded = date_excluded || (date_range.is_some() && !has_utc_date);
|
| 215 |
+
if !excluded && !movetext_lines.is_empty() {
|
| 216 |
+
// Game passed date filter. Check sample if present.
|
| 217 |
+
let keep = match sample {
|
| 218 |
+
Some((indices, offset)) => indices.contains(&(offset + date_matched_idx)),
|
| 219 |
+
None => true,
|
| 220 |
+
};
|
| 221 |
+
date_matched_idx += 1;
|
| 222 |
+
|
| 223 |
+
if keep {
|
| 224 |
+
games.push(RawGame {
|
| 225 |
+
headers: std::mem::take(&mut headers),
|
| 226 |
+
movetext: movetext_lines.join(" "),
|
| 227 |
+
});
|
| 228 |
+
if games.len() >= max_games {
|
| 229 |
+
break;
|
| 230 |
+
}
|
| 231 |
+
}
|
| 232 |
+
}
|
| 233 |
+
movetext_lines.clear();
|
| 234 |
+
headers.clear();
|
| 235 |
+
in_movetext = false;
|
| 236 |
+
date_excluded = false;
|
| 237 |
+
has_utc_date = false;
|
| 238 |
+
}
|
| 239 |
+
// Blank line between headers and movetext: don't reset state
|
| 240 |
+
continue;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
// Header line: [Key "Value"]
|
| 244 |
+
if line.starts_with('[') && line.ends_with(']') {
|
| 245 |
+
if let Some((key, value)) = parse_header_line(line) {
|
| 246 |
+
if key == "UTCDate" {
|
| 247 |
+
has_utc_date = true;
|
| 248 |
+
if let Some((start, end)) = date_range {
|
| 249 |
+
if value.as_str() < start || value.as_str() > end {
|
| 250 |
+
date_excluded = true;
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
if !date_excluded {
|
| 255 |
+
headers.insert(key, value);
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
in_movetext = false;
|
| 259 |
+
continue;
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
if !date_excluded {
|
| 263 |
+
in_movetext = true;
|
| 264 |
+
movetext_lines.push(line);
|
| 265 |
+
}
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
// Handle last game
|
| 269 |
+
let last_excluded = date_excluded || (date_range.is_some() && !has_utc_date);
|
| 270 |
+
if in_movetext && !last_excluded && !movetext_lines.is_empty() && games.len() < max_games {
|
| 271 |
+
let keep = match sample {
|
| 272 |
+
Some((indices, offset)) => indices.contains(&(offset + date_matched_idx)),
|
| 273 |
+
None => true,
|
| 274 |
+
};
|
| 275 |
+
if keep {
|
| 276 |
+
games.push(RawGame {
|
| 277 |
+
headers: std::mem::take(&mut headers),
|
| 278 |
+
movetext: movetext_lines.join(" "),
|
| 279 |
+
});
|
| 280 |
+
}
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
games
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
/// Parse a PGN header line like `[White "alice"]` into ("White", "alice").
///
/// Returns None when the line lacks surrounding brackets or a key/value
/// separator. Surrounding quotes are stripped only when both are present;
/// otherwise the raw value is kept as-is.
fn parse_header_line(line: &str) -> Option<(String, String)> {
    // Peel off the enclosing brackets, then split at the first space.
    let body = line.strip_prefix('[')?.strip_suffix(']')?.trim();
    let (tag, rest) = body.split_once(' ')?;
    let raw = rest.trim();
    // Drop a matched pair of surrounding quotes; keep unquoted values intact.
    let unquoted = raw
        .strip_prefix('"')
        .and_then(|v| v.strip_suffix('"'))
        .unwrap_or(raw);
    Some((tag.to_string(), unquoted.to_string()))
}
|
| 301 |
+
|
| 302 |
+
/// Sentinel for "no clock annotation" (0x8000 as u16 = 32768).
/// Cannot collide with real values: `parse_clock` caps clocks at 0x7FFF.
const CLOCK_NONE: u16 = 0x8000;
/// Sentinel for "no eval annotation" (0x8000 as i16 = -32768).
/// Cannot collide with real values: `parse_eval` emits at most ±32766.
const EVAL_NONE: i16 = -0x8000; // i16::MIN
|
| 306 |
+
|
| 307 |
+
/// Extract SAN moves, clock annotations, and eval annotations from movetext.
///
/// Returns (san_moves, clocks, evals) where clocks[i] is the clock after
/// move i (CLOCK_NONE if no annotation) and evals[i] is centipawns after
/// move i (EVAL_NONE if no annotation). All three vectors always have the
/// same length — sentinels are pushed alongside every move.
fn extract_moves_and_annotations(text: &str) -> (Vec<String>, Vec<u16>, Vec<i16>) {
    let mut moves = Vec::new();
    let mut clocks = Vec::new();
    let mut evals = Vec::new();

    // Byte-level scan; comment slices are re-borrowed from `text` by offset.
    let bytes = text.as_bytes();
    let len = bytes.len();

    // Lichess format: move { comment } move { comment } ...
    // The comment annotates the move immediately before it.
    let mut i = 0;
    while i < len {
        if bytes[i].is_ascii_whitespace() {
            i += 1;
            continue;
        }

        // Comment: { ... } — applies to the last pushed move
        if bytes[i] == b'{' {
            i += 1;
            let start = i;
            while i < len && bytes[i] != b'}' {
                i += 1;
            }
            let comment = &text[start..i];
            // Skip the closing '}' (may be absent at end of input).
            if i < len { i += 1; }

            // Apply to last move. A comment before any move is ignored.
            if let Some(last_clk) = clocks.last_mut() {
                let mut clk = CLOCK_NONE;
                let mut ev = EVAL_NONE;
                parse_comment(comment, &mut clk, &mut ev);
                // Only overwrite when an annotation was actually found, so a
                // comment without [%clk]/[%eval] keeps the sentinel values.
                if clk != CLOCK_NONE { *last_clk = clk; }
                if ev != EVAL_NONE {
                    // clocks/evals grow in lockstep, so last_mut() here
                    // addresses the same move index.
                    if let Some(last_ev) = evals.last_mut() {
                        *last_ev = ev;
                    }
                }
            }
            continue;
        }

        // Bare token: runs until whitespace or the start of a comment.
        let start = i;
        while i < len && !bytes[i].is_ascii_whitespace() && bytes[i] != b'{' {
            i += 1;
        }
        let token = &text[start..i];

        // NAG annotations like "$1" are skipped.
        if token.starts_with('$') { continue; }
        // Result markers terminate the movetext.
        if token == "1-0" || token == "0-1" || token == "1/2-1/2" || token == "*" { break; }

        // Move numbers ("3." / "3...") reduce to pure digits once trailing
        // dots are stripped — skip those.
        let stripped = token.trim_end_matches('.');
        if !stripped.is_empty() && stripped.bytes().all(|b| b.is_ascii_digit()) { continue; }

        moves.push(token.to_string());
        clocks.push(CLOCK_NONE);
        evals.push(EVAL_NONE);
    }

    (moves, clocks, evals)
}
|
| 373 |
+
|
| 374 |
+
/// Parse a PGN comment body for clock and eval annotations.
|
| 375 |
+
///
|
| 376 |
+
/// Lichess format: `[%clk 0:03:00]` and `[%eval 1.23]` or `[%eval #-3]`.
|
| 377 |
+
fn parse_comment(comment: &str, clock: &mut u16, eval: &mut i16) {
|
| 378 |
+
// Clock: [%clk H:MM:SS]
|
| 379 |
+
if let Some(pos) = comment.find("[%clk ") {
|
| 380 |
+
let rest = &comment[pos + 6..];
|
| 381 |
+
if let Some(end) = rest.find(']') {
|
| 382 |
+
let clk_str = rest[..end].trim();
|
| 383 |
+
if let Some(secs) = parse_clock(clk_str) {
|
| 384 |
+
*clock = secs;
|
| 385 |
+
}
|
| 386 |
+
}
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
// Eval: [%eval 1.23] or [%eval #-3]
|
| 390 |
+
if let Some(pos) = comment.find("[%eval ") {
|
| 391 |
+
let rest = &comment[pos + 7..];
|
| 392 |
+
if let Some(end) = rest.find(']') {
|
| 393 |
+
let eval_str = rest[..end].trim();
|
| 394 |
+
if let Some(cp) = parse_eval(eval_str) {
|
| 395 |
+
*eval = cp;
|
| 396 |
+
}
|
| 397 |
+
}
|
| 398 |
+
}
|
| 399 |
+
}
|
| 400 |
+
|
| 401 |
+
/// Parse "H:MM:SS" into total seconds as u16.
///
/// Exactly three colon-separated numeric fields are required. The result
/// saturates at 0x7FFF (32767) so it can never collide with the
/// CLOCK_NONE sentinel (0x8000).
fn parse_clock(s: &str) -> Option<u16> {
    let mut fields = s.split(':');
    let hours: u32 = fields.next()?.parse().ok()?;
    let minutes: u32 = fields.next()?.parse().ok()?;
    let seconds: u32 = fields.next()?.parse().ok()?;
    // Reject strings with more than three components.
    if fields.next().is_some() {
        return None;
    }
    let total = hours * 3600 + minutes * 60 + seconds;
    Some(total.min(0x7FFF) as u16)
}
|
| 412 |
+
|
| 413 |
+
/// Parse eval string into centipawns (i16).
/// "1.23" → 123, "-0.50" → -50.
/// Mate scores: "#N" → 32767-N, "#-N" → -(32767-N).
/// Bit 14 is always set for mates, making them detectable via bitmask.
/// Centipawn values are clamped to ±16383 to avoid overlap with the mate range.
fn parse_eval(s: &str) -> Option<i16> {
    match s.strip_prefix('#') {
        Some(mate) => {
            let n: i32 = mate.parse().ok()?;
            // Distance-to-mate; "#0" is treated as mate-in-1.
            let dist = n.unsigned_abs().max(1) as i16;
            let magnitude = 32767 - dist;
            Some(if n > 0 { magnitude } else { -magnitude })
        }
        None => {
            let pawns: f64 = s.parse().ok()?;
            let centipawns = (pawns * 100.0).round() as i32;
            Some(centipawns.clamp(-16383, 16383) as i16)
        }
    }
}
|
| 431 |
+
|
| 432 |
/// Convert a sequence of SAN move strings to PAWN token indices.
|
| 433 |
///
|
| 434 |
/// Returns (tokens, n_valid) where tokens has length up to max_ply,
|
|
|
|
| 695 |
|
| 696 |
fs::remove_file(path).ok();
|
| 697 |
}
|
| 698 |
+
|
| 699 |
+
// --- Enriched parsing tests ---

// "H:MM:SS" → whole seconds; malformed strings → None.
#[test]
fn test_parse_clock() {
    assert_eq!(parse_clock("0:10:00"), Some(600));
    assert_eq!(parse_clock("1:30:00"), Some(5400));
    assert_eq!(parse_clock("0:00:05"), Some(5));
    assert_eq!(parse_clock("0:03:00"), Some(180));
    assert_eq!(parse_clock("bad"), None);
}

// Pawn-unit eval strings → centipawns; mate scores use the 32767-N encoding.
#[test]
fn test_parse_eval() {
    assert_eq!(parse_eval("0.23"), Some(23));
    assert_eq!(parse_eval("-1.50"), Some(-150));
    assert_eq!(parse_eval("0.00"), Some(0));
    // Mate scores: 32767 - N
    assert_eq!(parse_eval("#1"), Some(32766));
    assert_eq!(parse_eval("#-1"), Some(-32766));
    assert_eq!(parse_eval("#3"), Some(32764));
    assert_eq!(parse_eval("#-3"), Some(-32764));
    assert_eq!(parse_eval("#10"), Some(32757));
    // Bit 14 (0x4000 = 16384) is set for all mate values
    assert!(parse_eval("#1").unwrap() & 0x4000 != 0);
    assert!(parse_eval("#100").unwrap() & 0x4000 != 0);
    // Centipawns clamped to ±16383 to avoid mate range
    assert_eq!(parse_eval("200.00"), Some(16383));
    assert_eq!(parse_eval("-200.00"), Some(-16383));
}

// Quoted values lose their surrounding quotes; inner spaces/colons survive.
#[test]
fn test_parse_header_line() {
    assert_eq!(
        parse_header_line(r#"[White "alice"]"#),
        Some(("White".to_string(), "alice".to_string()))
    );
    assert_eq!(
        parse_header_line(r#"[WhiteElo "1873"]"#),
        Some(("WhiteElo".to_string(), "1873".to_string()))
    );
    assert_eq!(
        parse_header_line(r#"[Opening "Bird Opening: Dutch Variation"]"#),
        Some(("Opening".to_string(), "Bird Opening: Dutch Variation".to_string()))
    );
}
|
| 744 |
+
|
| 745 |
+
// A comment annotates the move immediately before it; a clk-only comment
// leaves that move's eval at the EVAL_NONE sentinel.
#[test]
fn test_extract_moves_and_annotations() {
    let text = r#"1. e4 { [%clk 0:10:00] [%eval 0.23] } 1... e5 { [%clk 0:09:58] [%eval 0.31] } 2. Nf3 { [%clk 0:09:55] } 1-0"#;
    let (moves, clocks, evals) = extract_moves_and_annotations(text);
    assert_eq!(moves, vec!["e4", "e5", "Nf3"]);
    assert_eq!(clocks, vec![600, 598, 595]);
    assert_eq!(evals, vec![23, 31, EVAL_NONE]);
}

// Plain movetext: every annotation slot carries the missing-value sentinel.
#[test]
fn test_extract_moves_no_annotations() {
    let text = "1. e4 e5 2. Nf3 Nc6 1-0";
    let (moves, clocks, evals) = extract_moves_and_annotations(text);
    assert_eq!(moves, vec!["e4", "e5", "Nf3", "Nc6"]);
    assert_eq!(clocks, vec![CLOCK_NONE, CLOCK_NONE, CLOCK_NONE, CLOCK_NONE]);
    assert_eq!(evals, vec![EVAL_NONE; 4]);
}

// Mate-in-N evals round-trip as ±(32767 - N).
#[test]
fn test_extract_moves_mate_eval() {
    let text = r#"1. e4 { [%eval 0.23] } 1... e5 { [%eval #-3] } 1-0"#;
    let (moves, _clocks, evals) = extract_moves_and_annotations(text);
    assert_eq!(moves, vec!["e4", "e5"]);
    assert_eq!(evals, vec![23, -32764]);
}
|
| 770 |
+
|
| 771 |
+
// End-to-end: headers, clocks, and evals all survive the enriched pipeline.
#[test]
fn test_enriched_full_game() {
    let pgn = r#"[Event "Rated Rapid game"]
[Site "https://lichess.org/abc123"]
[White "alice"]
[Black "bob"]
[Result "1-0"]
[WhiteElo "1873"]
[BlackElo "1844"]
[WhiteRatingDiff "+6"]
[BlackRatingDiff "-26"]
[ECO "C20"]
[Opening "King's Pawn Game"]
[TimeControl "600+0"]
[Termination "Normal"]
[UTCDate "2025.01.15"]
[UTCTime "12:30:00"]

1. e4 { [%clk 0:10:00] [%eval 0.23] } 1... e5 { [%clk 0:09:58] [%eval 0.31] } 2. Nf3 { [%clk 0:09:50] [%eval 0.25] } 2... Nc6 { [%clk 0:09:45] [%eval 0.30] } 1-0
"#;
    let games = parse_pgn_enriched(pgn, 256, 100, 2);
    assert_eq!(games.len(), 1);
    let g = &games[0];
    assert_eq!(g.game_length, 4);
    assert_eq!(g.clocks, vec![600, 598, 590, 585]);
    assert_eq!(g.evals, vec![23, 31, 25, 30]);
    assert_eq!(g.headers.get("White").unwrap(), "alice");
    assert_eq!(g.headers.get("WhiteElo").unwrap(), "1873");
    assert_eq!(g.headers.get("Site").unwrap(), "https://lichess.org/abc123");
    assert_eq!(g.headers.get("ECO").unwrap(), "C20");
    assert_eq!(g.headers.get("TimeControl").unwrap(), "600+0");
}

// Enriched parsing must produce bit-identical tokens to the legacy pipeline.
#[test]
fn test_enriched_tokens_match_legacy() {
    // Enriched parsing should produce the same tokens as the legacy pipeline
    let pgn = r#"[Event "Test"]

1. e4 { [%clk 0:10:00] } 1... e5 { [%clk 0:09:58] } 2. Nf3 { [%clk 0:09:50] } 2... Nc6 { [%clk 0:09:45] } 1-0
"#;
    let enriched = parse_pgn_enriched(pgn, 256, 100, 2);
    let legacy = parse_pgn_to_san(pgn, 100);

    assert_eq!(enriched.len(), 1);
    assert_eq!(legacy.len(), 1);

    // Convert legacy SAN to tokens for comparison
    let refs: Vec<&str> = legacy[0].iter().map(|s| s.as_str()).collect();
    let (legacy_tokens, legacy_n) = san_moves_to_tokens(&refs, 256);

    assert_eq!(enriched[0].tokens, legacy_tokens);
    assert_eq!(enriched[0].game_length, legacy_n);
}
|
| 824 |
+
|
| 825 |
+
// Header-only date counting: inclusive bounds, lexicographic comparison.
#[test]
fn test_count_games_in_date_range() {
    let pgn = r#"[Event "Game 1"]
[UTCDate "2023.12.05"]

1. e4 e5 1-0

[Event "Game 2"]
[UTCDate "2023.12.20"]

1. d4 d5 0-1

[Event "Game 3"]
[UTCDate "2025.01.15"]

1. e4 c5 1-0
"#;
    assert_eq!(count_games_in_date_range(pgn, "2023.12.01", "2023.12.31"), 2);
    assert_eq!(count_games_in_date_range(pgn, "2023.12.01", "2023.12.14"), 1);
    assert_eq!(count_games_in_date_range(pgn, "2023.12.15", "2023.12.31"), 1);
    assert_eq!(count_games_in_date_range(pgn, "2025.01.01", "2025.01.31"), 1);
    assert_eq!(count_games_in_date_range(pgn, "2024.01.01", "2024.12.31"), 0);
}

// Sampled parsing: indices are local to the date-matching games of a chunk,
// shifted by `game_offset` to form global indices across chunks.
#[test]
fn test_sampled_parsing() {
    let pgn = r#"[Event "Game 1"]
[UTCDate "2023.12.05"]

1. e4 e5 1-0

[Event "Game 2"]
[UTCDate "2023.12.10"]

1. d4 d5 0-1

[Event "Game 3"]
[UTCDate "2023.12.20"]

1. e4 c5 1-0

[Event "Game 4"]
[UTCDate "2025.01.15"]

1. Nf3 d5 1-0
"#;
    // 3 games match Dec 2023 (indices 0, 1, 2 within the date range)
    assert_eq!(count_games_in_date_range(pgn, "2023.12.01", "2023.12.31"), 3);

    // Sample only index 1 (Game 2)
    let indices: HashSet<usize> = HashSet::from([1]);
    let sampled = parse_pgn_enriched_sampled(
        pgn, 256, 2, "2023.12.01", "2023.12.31", &indices, 0,
    );
    assert_eq!(sampled.len(), 1);
    assert_eq!(sampled[0].headers.get("UTCDate").unwrap(), "2023.12.10");

    // Sample indices 0 and 2 (Game 1 and Game 3)
    let indices: HashSet<usize> = HashSet::from([0, 2]);
    let sampled = parse_pgn_enriched_sampled(
        pgn, 256, 2, "2023.12.01", "2023.12.31", &indices, 0,
    );
    assert_eq!(sampled.len(), 2);
    assert_eq!(sampled[0].headers.get("UTCDate").unwrap(), "2023.12.05");
    assert_eq!(sampled[1].headers.get("UTCDate").unwrap(), "2023.12.20");

    // Sample with offset: simulating a second chunk where previous chunk had 1 match.
    // Global index 2 = offset 1 + local index 1 => selects Game 2 (local idx 1).
    let indices: HashSet<usize> = HashSet::from([2]);
    let sampled = parse_pgn_enriched_sampled(
        pgn, 256, 2, "2023.12.01", "2023.12.31", &indices, 1,
    );
    assert_eq!(sampled.len(), 1);
    assert_eq!(sampled[0].headers.get("UTCDate").unwrap(), "2023.12.10");

    // Offset that skips all local games: offset=3 means local indices are 3,4,5
    // but we only ask for global index 0, which isn't in this chunk.
    let indices: HashSet<usize> = HashSet::from([0]);
    let sampled = parse_pgn_enriched_sampled(
        pgn, 256, 2, "2023.12.01", "2023.12.31", &indices, 3,
    );
    assert_eq!(sampled.len(), 0);
}
|
| 908 |
}
|
scripts/extract_lichess_parquet.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Extract Lichess monthly PGN database dumps into PAWN-compatible Parquet.
|
| 3 |
+
|
| 4 |
+
Downloads a zstd-compressed PGN from database.lichess.org, parses games via
|
| 5 |
+
the Rust chess engine (tokens, clocks, evals, headers), builds a Polars
|
| 6 |
+
DataFrame, and writes sharded Parquet to disk with train/val/test splits.
|
| 7 |
+
|
| 8 |
+
The output schema stores pre-tokenized move sequences as list[int16],
|
| 9 |
+
clock annotations as list[uint16] (seconds remaining, 0x8000=missing),
|
| 10 |
+
eval annotations as list[int16] (centipawns, mate=±(32767-N),
|
| 11 |
+
0x8000=missing), and metadata columns. Player usernames are hashed to
|
| 12 |
+
uint64 via Polars xxHash64 (deterministic within a Polars version).
|
| 13 |
+
|
| 14 |
+
Training months are written as chronologically-ordered shards. Holdout
|
| 15 |
+
val/test data is uniformly sampled from a separate month via two-pass
|
| 16 |
+
Rust-side date filtering.
|
| 17 |
+
|
| 18 |
+
Designed to run on a CPU pod with the pawn Docker image.
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
python scripts/extract_lichess_parquet.py \\
|
| 22 |
+
--months 2025-01 2025-02 2025-03 \\
|
| 23 |
+
--output /workspace/lichess-parquet \\
|
| 24 |
+
--hf-repo thomas-schweich/lichess-pawn \\
|
| 25 |
+
--batch-size 500000
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import io
|
| 30 |
+
import os
|
| 31 |
+
import sys
|
| 32 |
+
import time
|
| 33 |
+
import urllib.request
|
| 34 |
+
from datetime import datetime
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
import chess_engine
|
| 39 |
+
import polars as pl
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# URL pattern for monthly dumps, e.g. year_month="2025-01".
LICHESS_URL_TEMPLATE = (
    "https://database.lichess.org/standard/"
    "lichess_db_standard_rated_{year_month}.pgn.zst"
)
MAX_PLY = 255  # Max plies per game (token sequence = outcome + 255 plies)
EVAL_MISSING = -32768  # i16::MIN sentinel for missing eval (matches Rust EVAL_NONE)
SHARD_TARGET_GAMES = 1_000_000  # Target games per shard
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def log(msg: str) -> None:
    """Emit a progress line to stdout, prefixed with HH:MM:SS wall time."""
    now = datetime.now()
    print("[{}] {}".format(now.strftime("%H:%M:%S"), msg), flush=True)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# PGN streaming
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
| 64 |
+
def stream_pgn_games(fileobj, batch_size: int):
    """Yield batches of complete PGN game strings from a text stream.

    Each batch is a single string containing `batch_size` complete games
    (delimited by blank lines between the last movetext and next header).

    Args:
        fileobj: binary file-like object containing zstd-compressed PGN.
        batch_size: number of complete games per yielded batch.

    Yields:
        (pgn_text, n_games) tuples; the final batch may be smaller and may
        end with a trailing partial game's header lines.
    """
    import zstandard as zstd  # local import: zstd only needed when streaming

    dctx = zstd.ZstdDecompressor()
    reader = dctx.stream_reader(fileobj)
    # latin-1 maps every byte, so decoding never raises on odd player names.
    text_reader = io.TextIOWrapper(reader, encoding="latin-1", errors="replace")

    buf = []              # raw lines of the current (partial) batch
    game_count = 0        # complete games accumulated in buf
    in_movetext = False   # True while inside a game's movetext section

    for line in text_reader:
        stripped = line.strip()

        if not stripped:
            if in_movetext:
                # End of movetext — game boundary
                game_count += 1
                in_movetext = False
                buf.append(line)
                if game_count >= batch_size:
                    yield "".join(buf), game_count
                    buf.clear()
                    game_count = 0
                continue
            # Blank line between headers and movetext: keep it in the batch.
            buf.append(line)
            continue

        if stripped.startswith("["):
            in_movetext = False
        else:
            in_movetext = True

        buf.append(line)

    # Final batch
    if buf:
        yield "".join(buf), game_count
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def download_zst(year_month: str, output_dir: Path) -> Path:
    """Download a Lichess zstd PGN dump to disk. Returns path to .zst file.

    Skips the download when a previously-completed file already exists.
    Data is written to a ``.tmp`` path and renamed only on completion, so an
    interrupted run never leaves a truncated ``.zst`` behind.
    """
    url = LICHESS_URL_TEMPLATE.format(year_month=year_month)
    zst_path = output_dir / f"lichess_{year_month}.pgn.zst"
    if zst_path.exists():
        log(f" Using cached {zst_path} ({zst_path.stat().st_size / 1e9:.1f} GB)")
        return zst_path

    log(f" Downloading {url}")
    req = urllib.request.Request(req_url := url)
    req.add_header("User-Agent", "pawn-lichess-extract/1.0")

    # Write to a temp file and rename on completion to avoid partial downloads
    tmp_path = zst_path.with_suffix(".zst.tmp")
    t0 = time.monotonic()
    downloaded = 0
    try:
        # `with` closes the HTTP response on every exit path; the original
        # only called response.close() on the success path, leaking the
        # connection when the download failed partway.
        with urllib.request.urlopen(req) as response, open(tmp_path, "wb") as f:
            while True:
                chunk = response.read(8 * 1024 * 1024)  # 8 MB
                if not chunk:
                    break
                f.write(chunk)
                downloaded += len(chunk)
                elapsed = time.monotonic() - t0
                rate_mb = (downloaded / 1e6) / elapsed if elapsed > 0 else 0
                print(f"\r Downloaded {downloaded / 1e9:.2f} GB ({rate_mb:.0f} MB/s)", end="", flush=True)
        print(flush=True)
        tmp_path.rename(zst_path)
    except BaseException:
        # Remove the partial file on any failure (including KeyboardInterrupt).
        tmp_path.unlink(missing_ok=True)
        raise

    log(f" Saved {zst_path} ({downloaded / 1e9:.2f} GB in {time.monotonic() - t0:.0f}s)")
    return zst_path
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def stream_pgn_from_zst(zst_path: Path, batch_size: int):
    """Yield PGN text batches from a local .zst file.

    Thin convenience wrapper around `stream_pgn_games`; yields the same
    (pgn_text, n_games) tuples and closes the file when exhausted.
    """
    with open(zst_path, "rb") as f:
        yield from stream_pgn_games(f, batch_size)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def download_month(year_month: str, output_dir: Path, batch_size: int):
    """Download and parse a single month's PGN dump, yielding parsed batches.

    Generator: downloads (or reuses) the month's .zst dump, then streams it
    in text batches, parsing each through the Rust extension and yielding
    the parsed dict per batch.
    """
    zst_path = download_zst(year_month, output_dir)
    total_games = 0
    batch_num = 0

    for pgn_text, _n_games_in_chunk in stream_pgn_from_zst(zst_path, batch_size):
        if not pgn_text.strip():
            continue

        parse_start = time.monotonic()
        parsed = chess_engine.parse_pgn_enriched(
            pgn_text, max_ply=MAX_PLY, max_games=batch_size * 2, min_ply=1
        )
        dt = time.monotonic() - parse_start

        n = parsed["tokens"].shape[0]
        batch_num += 1
        total_games += n
        rate = (n / dt) if dt > 0 else 0
        log(f"  [{year_month}] batch {batch_num}: {n:,} games parsed in {dt:.1f}s ({rate:,.0f} games/s) | total: {total_games:,}")

        yield parsed

    log(f"  [{year_month}] Done — {total_games:,} games total")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def sample_holdout(
    zst_path: Path,
    date_start: str,
    date_end: str,
    n_games: int,
    batch_size: int,
    seed: int,
) -> pl.DataFrame:
    """Uniformly sample n_games from a date range within a .zst PGN dump.

    Two-pass approach:
      1. Count games in the date range (header-only scan, no tokenization)
      2. Generate random indices, parse only those games

    Args:
        zst_path: Local .zst PGN dump.
        date_start / date_end: Inclusive UTCDate bounds, "YYYY.MM.DD".
        n_games: Target sample size (capped at the number available).
        batch_size: Games per streamed text batch.
        seed: RNG seed for reproducible sampling.

    Returns:
        A DataFrame of the sampled games, or an empty DataFrame if the
        range contains no games or none parsed.
    """
    # Pass 1: count
    log(f"  Pass 1: counting games in [{date_start}, {date_end}]")
    total_in_range = 0
    t0 = time.monotonic()
    for pgn_text, _ in stream_pgn_from_zst(zst_path, batch_size):
        if not pgn_text.strip():
            continue
        total_in_range += chess_engine.count_pgn_games_in_date_range(
            pgn_text, date_start, date_end
        )
    dt = time.monotonic() - t0
    log(f"  Found {total_in_range:,} games in range ({dt:.1f}s)")

    if total_in_range == 0:
        return pl.DataFrame()

    # Generate random sample indices (global index over date-matching games)
    actual_n = min(n_games, total_in_range)
    rng = np.random.default_rng(seed)
    remaining = set(rng.choice(total_in_range, size=actual_n, replace=False).tolist())
    log(f"  Sampling {actual_n:,} of {total_in_range:,} games (seed={seed})")

    # Pass 2: parse only sampled games
    log(f"  Pass 2: parsing sampled games")
    frames = []
    game_offset = 0
    t0 = time.monotonic()
    for pgn_text, _ in stream_pgn_from_zst(zst_path, batch_size):
        if not pgn_text.strip():
            continue

        # Count how many date-matching games are in this chunk
        chunk_count = chess_engine.count_pgn_games_in_date_range(
            pgn_text, date_start, date_end
        )
        chunk_end = game_offset + chunk_count

        # PERF: filter the (shrinking) sample set instead of scanning the
        # whole chunk index range — O(|remaining|) per chunk rather than
        # O(chunk_count). Sorted to match the original ascending order.
        chunk_indices = sorted(
            i for i in remaining if game_offset <= i < chunk_end
        )

        if chunk_indices:
            parsed = chess_engine.parse_pgn_sampled(
                pgn_text,
                chunk_indices,
                date_start,
                date_end,
                game_offset=game_offset,
                max_ply=MAX_PLY,
                min_ply=1,
            )
            n = parsed["tokens"].shape[0]
            if n > 0:
                frames.append(batch_to_dataframe(parsed))
            remaining.difference_update(chunk_indices)

        game_offset = chunk_end

        # PERF: once every sampled index has been consumed there is nothing
        # left to find — stop decompressing the rest of the dump.
        if not remaining:
            break

    dt = time.monotonic() - t0
    if not frames:
        log(f"  No games parsed")
        return pl.DataFrame()

    result = pl.concat(frames)
    log(f"  Parsed {len(result):,} games ({dt:.1f}s)")
    return result
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# ---------------------------------------------------------------------------
|
| 264 |
+
# DataFrame construction
|
| 265 |
+
# ---------------------------------------------------------------------------
|
| 266 |
+
|
| 267 |
+
def numpy_rows_to_list_series(
    arr: np.ndarray, lengths: np.ndarray, name: str, inner_dtype: pl.DataType
) -> pl.Series:
    """Convert a 0-padded (N, max_ply) numpy array to a Polars List series,
    trimming each row to its actual game length."""
    trimmed = [row[:length].tolist() for row, length in zip(arr, lengths)]
    return pl.Series(name, trimmed, dtype=pl.List(inner_dtype))
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def batch_to_dataframe(parsed: dict) -> pl.DataFrame:
    """Convert a parsed batch dict from Rust into a Polars DataFrame.

    Rust returns numpy arrays: tokens/clocks/evals as (N, max_ply),
    scalar fields as (N,) arrays, and strings as Python lists.

    Per-game arrays are trimmed to each game's true length via
    ``game_lengths`` (padding beyond the length is dropped), header strings
    become Utf8 columns, and the two player-name columns are replaced by
    deterministic uint64 hashes (see NOTE below). Returns an empty
    DataFrame when the batch contains no games.
    """
    tokens: np.ndarray = parsed["tokens"]  # (N, max_ply) i16
    n = tokens.shape[0]
    if n == 0:
        return pl.DataFrame()

    lengths: np.ndarray = parsed["game_lengths"]  # (N,) u16

    # Parse datetime strings -> proper datetime
    # Format: "YYYY.MM.DD HH:MM:SS"
    # Fallback: date-only "YYYY.MM.DD" (time portion malformed/absent);
    # anything else becomes a null datetime rather than failing the batch.
    datetimes = []
    for dt_str in parsed["date_time"]:
        if dt_str and len(dt_str) >= 10:
            try:
                datetimes.append(datetime.strptime(dt_str, "%Y.%m.%d %H:%M:%S"))
            except ValueError:
                try:
                    datetimes.append(datetime.strptime(dt_str[:10], "%Y.%m.%d"))
                except ValueError:
                    datetimes.append(None)
        else:
            datetimes.append(None)

    # Column dtypes mirror the Rust-side widths (i16 tokens/evals,
    # u16 clocks/lengths/elo, i16 rating diffs).
    df = pl.DataFrame({
        "tokens": numpy_rows_to_list_series(tokens, lengths, "tokens", pl.Int16),
        "clock": numpy_rows_to_list_series(parsed["clocks"], lengths, "clock", pl.UInt16),
        "eval": numpy_rows_to_list_series(parsed["evals"], lengths, "eval", pl.Int16),
        "game_length": pl.Series("game_length", parsed["game_lengths"], dtype=pl.UInt16),
        "result": pl.Series("result", parsed["result"], dtype=pl.Utf8),
        "white_player": pl.Series("white_player", parsed["white"], dtype=pl.Utf8),
        "black_player": pl.Series("black_player", parsed["black"], dtype=pl.Utf8),
        "white_elo": pl.Series("white_elo", parsed["white_elo"], dtype=pl.UInt16),
        "black_elo": pl.Series("black_elo", parsed["black_elo"], dtype=pl.UInt16),
        "white_rating_diff": pl.Series("white_rating_diff", parsed["white_rating_diff"], dtype=pl.Int16),
        "black_rating_diff": pl.Series("black_rating_diff", parsed["black_rating_diff"], dtype=pl.Int16),
        "eco": pl.Series("eco", parsed["eco"], dtype=pl.Utf8),
        "opening": pl.Series("opening", parsed["opening"], dtype=pl.Utf8),
        "time_control": pl.Series("time_control", parsed["time_control"], dtype=pl.Utf8),
        "termination": pl.Series("termination", parsed["termination"], dtype=pl.Utf8),
        "date": pl.Series("date", datetimes, dtype=pl.Datetime("ms")),
        "site": pl.Series("site", parsed["site"], dtype=pl.Utf8),
    })

    # Hash usernames: vectorized xxHash64 via Polars.
    # NOTE: hash() output is deterministic within a Polars version but the
    # algorithm is not guaranteed stable across major versions. Originally
    # recorded with Polars 1.39.3. Pin Polars version (via uv.lock) and
    # tag the repo to ensure reproducibility. See test_enriched_pgn.py
    # TestPlayerHashRegression for the snapshot test.
    df = df.with_columns(
        pl.col("white_player").hash().alias("white_player"),
        pl.col("black_player").hash().alias("black_player"),
    )

    return df
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# ---------------------------------------------------------------------------
|
| 339 |
+
# Shard writing
|
| 340 |
+
# ---------------------------------------------------------------------------
|
| 341 |
+
|
| 342 |
+
def write_shard(
    df: pl.DataFrame,
    output_dir: Path,
    split: str,
    shard_idx: int,
    total_shards: int,
) -> Path:
    """Write a single Parquet shard with HF-compatible naming."""
    name = f"{split}-{shard_idx:05d}-of-{total_shards:05d}.parquet"
    data_dir = output_dir / "data"
    data_dir.mkdir(parents=True, exist_ok=True)
    path = data_dir / name
    df.write_parquet(path, compression="zstd", compression_level=3)
    size_mb = path.stat().st_size / 1024 / 1024
    log(f"  Wrote {name}: {len(df):,} games, {size_mb:.1f} MB")
    return path
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# ---------------------------------------------------------------------------
|
| 360 |
+
# HuggingFace upload
|
| 361 |
+
# ---------------------------------------------------------------------------
|
| 362 |
+
|
| 363 |
+
def upload_to_hf(output_dir: Path, hf_repo: str) -> None:
    """Upload the output directory to HuggingFace as a dataset."""
    from huggingface_hub import HfApi

    client = HfApi()
    log(f"Uploading to HuggingFace: {hf_repo}")

    # Idempotent: creating an existing dataset repo is a no-op.
    client.create_repo(hf_repo, repo_type="dataset", exist_ok=True)

    # Push the whole output tree (data/ shards and anything alongside them).
    client.upload_folder(
        repo_id=hf_repo,
        folder_path=str(output_dir),
        repo_type="dataset",
    )
    log(f"Upload complete: https://huggingface.co/datasets/{hf_repo}")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
# ---------------------------------------------------------------------------
|
| 383 |
+
# Main
|
| 384 |
+
# ---------------------------------------------------------------------------
|
| 385 |
+
|
| 386 |
+
class SplitBuffer:
    """Accumulates DataFrames for a single split and flushes to Parquet shards.

    Shards are first written under temporary names (the total shard count is
    unknown while streaming); ``rename_shards`` assigns the final
    HF-compatible ``{split}-{i:05d}-of-{n:05d}.parquet`` names afterwards.
    """

    def __init__(self, split: str, shard_size: int, output_dir: Path):
        self.split = split
        self.shard_size = shard_size
        self.output_dir = output_dir / "data"
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.frames: list[pl.DataFrame] = []
        self.buffered = 0          # rows currently held in self.frames
        self.total_games = 0       # rows ever added (never decremented)
        self.shard_paths: list[Path] = []
        self.shard_idx = 0

    def add(self, df: pl.DataFrame) -> None:
        """Buffer a batch; write out full shards as soon as enough rows exist."""
        if df.is_empty():
            return
        n_rows = len(df)
        self.frames.append(df)
        self.buffered += n_rows
        self.total_games += n_rows
        self._flush_full()

    def _flush_full(self) -> None:
        """Drain complete shards of exactly shard_size rows from the buffer."""
        while self.buffered >= self.shard_size:
            merged = pl.concat(self.frames)
            self._write_shard(merged.head(self.shard_size))
            rest = merged.slice(self.shard_size)
            if len(rest) > 0:
                self.frames = [rest]
                self.buffered = len(rest)
            else:
                self.frames = []
                self.buffered = 0

    def flush_remaining(self) -> None:
        """Write whatever is left as one final (possibly undersized) shard."""
        if self.frames:
            merged = pl.concat(self.frames)
            if len(merged) > 0:
                self._write_shard(merged)
            self.frames.clear()
            self.buffered = 0

    def _write_shard(self, df: pl.DataFrame) -> None:
        # Write with placeholder name; rename after all shards are counted
        tmp_name = f"{self.split}-temp-{self.shard_idx:05d}.parquet"
        path = self.output_dir / tmp_name
        df.write_parquet(path, compression="zstd", compression_level=3)
        size_mb = path.stat().st_size / 1024 / 1024
        log(f"  [{self.split}] shard {self.shard_idx}: {len(df):,} games, {size_mb:.1f} MB")
        self.shard_paths.append(path)
        self.shard_idx += 1

    def rename_shards(self) -> list[Path]:
        """Rename temp shards to HF-compatible names with correct total count."""
        n = len(self.shard_paths)
        renamed: list[Path] = []
        for i, old_path in enumerate(self.shard_paths):
            new_name = f"{self.split}-{i:05d}-of-{n:05d}.parquet"
            target = old_path.parent / new_name
            old_path.rename(target)
            renamed.append(target)
            log(f"  {old_path.name} -> {new_name}")
        self.shard_paths = renamed
        return renamed
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def main():
    """CLI entry point: stream Lichess dumps into train/validation/test
    Parquet shards, then optionally push the output tree to HuggingFace.

    Phases:
      1. (optional) holdout month -> uniformly sampled validation/test splits
      2. training months -> train split, streamed batch by batch
      3. flush partial shards, rename all shards to final HF names
      4. (optional) upload to a HF dataset repo
    """
    parser = argparse.ArgumentParser(
        description="Extract Lichess PGN dumps to PAWN-compatible Parquet"
    )
    parser.add_argument(
        "--months", nargs="+", required=True,
        help="Training month(s) to download, e.g. 2025-01 2025-02 2025-03"
    )
    parser.add_argument(
        "--output", type=Path, default=Path("/workspace/lichess-parquet"),
        help="Output directory for Parquet shards"
    )
    parser.add_argument(
        "--hf-repo", type=str, default=None,
        help="HuggingFace dataset repo to push to (e.g. thomas-schweich/pawn-lichess-full)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=500_000,
        help="Games per batch during parsing (controls memory usage)"
    )
    parser.add_argument(
        "--shard-size", type=int, default=SHARD_TARGET_GAMES,
        help="Target games per output shard"
    )
    parser.add_argument(
        "--holdout-month", type=str, default=None,
        help="Month to use for val/test (e.g. 2023-12). First half of month "
             "-> val, second half -> test. Randomly samples --holdout-games "
             "from each half."
    )
    parser.add_argument(
        "--holdout-games", type=int, default=50_000,
        help="Number of games to sample for each of val and test (default: 50000)"
    )
    parser.add_argument(
        "--max-games", type=int, default=None,
        help="Stop after this many training games (for testing)"
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for holdout sampling (default: 42)"
    )
    args = parser.parse_args()

    log("=== Lichess Parquet Extraction ===")
    log(f"Training months: {args.months}")
    log(f"Output: {args.output}")
    log(f"Batch size: {args.batch_size:,}")
    log(f"Shard size: {args.shard_size:,}")
    if args.holdout_month:
        log(f"Holdout month: {args.holdout_month}")
        log(f"Holdout games per split: {args.holdout_games:,}")
        log(f"Holdout seed: {args.seed}")
    if args.max_games:
        log(f"Max training games: {args.max_games:,}")
    log("")

    args.output.mkdir(parents=True, exist_ok=True)

    # ── Phase 1: Process holdout month (val/test) ──────────────────────
    # One buffer per split; each writes its own shard sequence.
    buffers = {
        "train": SplitBuffer("train", args.shard_size, args.output),
        "validation": SplitBuffer("validation", args.shard_size, args.output),
        "test": SplitBuffer("test", args.shard_size, args.output),
    }

    if args.holdout_month:
        log(f"\n=== Processing holdout month {args.holdout_month} ===")

        # Download the zstd file (reused for both count and parse passes)
        zst_path = download_zst(args.holdout_month, args.output)

        # Date ranges: first half of month -> val, second half -> test
        # UTCDate format is "YYYY.MM.DD"
        # NOTE: ".31" upper bound is safe for shorter months — it only needs
        # to be >= the last real day of the month.
        year, mon = args.holdout_month.split("-")
        val_start = f"{year}.{mon}.01"
        val_end = f"{year}.{mon}.14"
        test_start = f"{year}.{mon}.15"
        test_end = f"{year}.{mon}.31"

        log(f"  Val date range: [{val_start}, {val_end}]")
        val_df = sample_holdout(
            zst_path, val_start, val_end,
            args.holdout_games, args.batch_size, args.seed,
        )
        if not val_df.is_empty():
            buffers["validation"].add(val_df)

        # seed+1 so val and test draws are independent.
        log(f"  Test date range: [{test_start}, {test_end}]")
        test_df = sample_holdout(
            zst_path, test_start, test_end,
            args.holdout_games, args.batch_size, args.seed + 1,
        )
        if not test_df.is_empty():
            buffers["test"].add(test_df)

    # ── Phase 2: Process training months ───────────────────────────────
    total_train = 0
    stop = False  # set when --max-games cap is reached; breaks the outer loop

    for month in args.months:
        if stop:
            break
        log(f"\n=== Processing {month} (train) ===")

        for parsed in download_month(month, args.output, args.batch_size):
            df = batch_to_dataframe(parsed)
            if df.is_empty():
                continue

            # Trim the last batch so the cap is hit exactly.
            if args.max_games:
                remaining = args.max_games - total_train
                if remaining <= 0:
                    stop = True
                    break
                if len(df) > remaining:
                    df = df.head(remaining)

            total_train += len(df)
            buffers["train"].add(df)

            if args.max_games and total_train >= args.max_games:
                stop = True
                break

    # Flush remaining data in each buffer
    for buf in buffers.values():
        buf.flush_remaining()

    # Rename temp shards now that per-split totals are known.
    log(f"\n=== Renaming shards ===")
    final_paths = []
    for buf in buffers.values():
        if buf.shard_paths:
            final_paths.extend(buf.rename_shards())

    # Summary
    log(f"\n=== Summary ===")
    total_games = sum(buf.total_games for buf in buffers.values())
    log(f"Total games: {total_games:,}")
    for name, buf in buffers.items():
        if buf.total_games > 0:
            log(f"  {name}: {buf.total_games:,} games, {len(buf.shard_paths)} shards")

    if final_paths:
        total_size = sum(p.stat().st_size for p in final_paths)
        log(f"Total size: {total_size / 1024 / 1024 / 1024:.2f} GB")

    # Upload to HuggingFace
    if args.hf_repo:
        upload_to_hf(args.output, args.hf_repo)

    log("\nDone!")
|
tests/test_enriched_pgn.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for enriched PGN parsing and dataset extraction pipeline."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import polars as pl
|
| 5 |
+
import pytest
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import chess_engine
|
| 10 |
+
|
| 11 |
+
# Import extraction helper
|
| 12 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "scripts"))
|
| 13 |
+
from extract_lichess_parquet import batch_to_dataframe
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
# Test PGN data — each game has distinct moves, metadata, and annotations
|
| 18 |
+
# ---------------------------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
# Fixture PGNs keyed by "<white>_v_<black>". The three players (alice, bob,
# xavier) each appear in multiple games, as both colors, with otherwise fully
# distinct metadata — moves, Elo, time controls, dates, terminations — so the
# hashing tests can verify player hashes depend only on the name string.
# "xavier_v_alice" deliberately omits [%eval ...] annotations to exercise the
# missing-eval sentinel path.
PGNS = {
    "alice_v_bob": """\
[Event "Rated Rapid game"]
[Site "https://lichess.org/game001"]
[White "alice"]
[Black "bob"]
[Result "1-0"]
[WhiteElo "1873"]
[BlackElo "1844"]
[WhiteRatingDiff "+6"]
[BlackRatingDiff "-26"]
[ECO "C20"]
[Opening "King's Pawn Game"]
[TimeControl "600+0"]
[Termination "Normal"]
[UTCDate "2025.01.10"]
[UTCTime "10:00:00"]

1. e4 { [%clk 0:10:00] [%eval 0.23] } 1... e5 { [%clk 0:09:58] [%eval 0.31] } 2. Nf3 { [%clk 0:09:50] [%eval 0.25] } 2... Nc6 { [%clk 0:09:45] [%eval 0.30] } 1-0
""",
    "bob_v_alice": """\
[Event "Rated Blitz game"]
[Site "https://lichess.org/game002"]
[White "bob"]
[Black "alice"]
[Result "0-1"]
[WhiteElo "1850"]
[BlackElo "1880"]
[WhiteRatingDiff "-5"]
[BlackRatingDiff "+5"]
[ECO "B20"]
[Opening "Sicilian Defense"]
[TimeControl "300+3"]
[Termination "Time forfeit"]
[UTCDate "2025.02.14"]
[UTCTime "20:00:00"]

1. e4 { [%clk 0:05:00] [%eval 0.20] } 1... c5 { [%clk 0:04:55] [%eval 0.25] } 2. d4 { [%clk 0:04:48] [%eval 0.40] } 0-1
""",
    "alice_v_xavier": """\
[Event "Rated Classical game"]
[Site "https://lichess.org/game003"]
[White "alice"]
[Black "xavier"]
[Result "1/2-1/2"]
[WhiteElo "1900"]
[BlackElo "2100"]
[WhiteRatingDiff "+3"]
[BlackRatingDiff "-1"]
[ECO "D30"]
[Opening "Queen's Gambit Declined"]
[TimeControl "1800+30"]
[Termination "Normal"]
[UTCDate "2025.03.01"]
[UTCTime "15:30:00"]

1. d4 { [%clk 0:30:00] [%eval 0.10] } 1... d5 { [%clk 0:29:50] [%eval 0.15] } 2. c4 { [%clk 0:29:40] [%eval 0.20] } 2... e6 { [%clk 0:29:30] [%eval 0.18] } 3. Nf3 { [%clk 0:29:20] [%eval 0.22] } 1/2-1/2
""",
    "xavier_v_alice": """\
[Event "Rated Rapid game"]
[Site "https://lichess.org/game004"]
[White "xavier"]
[Black "alice"]
[Result "1-0"]
[WhiteElo "2105"]
[BlackElo "1895"]
[WhiteRatingDiff "+2"]
[BlackRatingDiff "-4"]
[ECO "A45"]
[Opening "Trompowsky Attack"]
[TimeControl "900+10"]
[Termination "Normal"]
[UTCDate "2025.01.20"]
[UTCTime "18:00:00"]

1. d4 { [%clk 0:15:00] } 1... Nf6 { [%clk 0:14:55] } 2. Bg5 { [%clk 0:14:48] } 1-0
""",
    "bob_v_xavier": """\
[Event "Rated Bullet game"]
[Site "https://lichess.org/game005"]
[White "bob"]
[Black "xavier"]
[Result "0-1"]
[WhiteElo "1840"]
[BlackElo "2110"]
[WhiteRatingDiff "-8"]
[BlackRatingDiff "+3"]
[ECO "C50"]
[Opening "Italian Game"]
[TimeControl "60+0"]
[Termination "Normal"]
[UTCDate "2025.02.28"]
[UTCTime "23:59:00"]

1. e4 { [%clk 0:01:00] [%eval 0.20] } 1... e5 { [%clk 0:00:59] [%eval 0.25] } 2. Nf3 { [%clk 0:00:55] [%eval 0.30] } 2... Nc6 { [%clk 0:00:53] [%eval 0.28] } 3. Bc4 { [%clk 0:00:50] [%eval 0.35] } 3... Bc5 { [%clk 0:00:48] [%eval 0.30] } 0-1
""",
    "xavier_v_bob": """\
[Event "Rated Rapid game"]
[Site "https://lichess.org/game006"]
[White "xavier"]
[Black "bob"]
[Result "1-0"]
[WhiteElo "2115"]
[BlackElo "1835"]
[WhiteRatingDiff "+1"]
[BlackRatingDiff "-7"]
[ECO "E00"]
[Opening "Queen's Pawn Game"]
[TimeControl "600+5"]
[Termination "Normal"]
[UTCDate "2025.03.15"]
[UTCTime "09:00:00"]

1. d4 { [%clk 0:10:00] [%eval 0.15] } 1... Nf6 { [%clk 0:09:55] [%eval 0.20] } 2. c4 { [%clk 0:09:48] [%eval 0.25] } 2... e6 { [%clk 0:09:40] [%eval 0.22] } 3. Nc3 { [%clk 0:09:35] [%eval 0.28] } 3... Bb4 { [%clk 0:09:28] [%eval 0.30] } 4. Qc2 { [%clk 0:09:20] [%eval 0.32] } 1-0
""",
}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# ---------------------------------------------------------------------------
|
| 139 |
+
# Tests
|
| 140 |
+
# ---------------------------------------------------------------------------
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class TestEnrichedParsing:
    """Test the Rust parse_pgn_enriched function."""

    def test_basic_parsing(self):
        # Default max_ply appears to be 255 (shape second dim) — the 4-ply
        # game is left-aligned in the padded arrays.
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        assert r["tokens"].shape == (1, 255)
        assert r["clocks"].shape == (1, 255)
        assert r["evals"].shape == (1, 255)
        assert r["game_lengths"].shape == (1,)
        assert r["game_lengths"][0] == 4

    def test_return_types(self):
        # Numeric per-game and per-ply fields come back as numpy arrays;
        # header strings come back as plain Python lists.
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        for key in ("tokens", "clocks", "evals"):
            assert isinstance(r[key], np.ndarray), f"{key} should be ndarray"
        for key in ("game_lengths", "white_elo", "black_elo",
                    "white_rating_diff", "black_rating_diff"):
            assert isinstance(r[key], np.ndarray), f"{key} should be ndarray"
        for key in ("result", "white", "black", "eco", "opening",
                    "time_control", "termination", "date_time", "site"):
            assert isinstance(r[key], list), f"{key} should be list"

    def test_clock_extraction(self):
        # [%clk H:MM:SS] annotations are converted to whole seconds.
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        clocks = r["clocks"][0, :4]
        assert list(clocks) == [600, 598, 590, 585]

    def test_eval_extraction(self):
        # [%eval 0.23] pawn values are stored as integer centipawns.
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        evals = r["evals"][0, :4]
        assert list(evals) == [23, 31, 25, 30]

    def test_missing_eval_sentinel(self):
        pgn = PGNS["xavier_v_alice"]  # no eval annotations
        r = chess_engine.parse_pgn_enriched(pgn)
        length = r["game_lengths"][0]
        evals = r["evals"][0, :length]
        # Rust uses i16::MIN (-32768) as the "no eval" sentinel
        assert all(e == -32768 for e in evals), (
            f"Missing evals should be -32768 (i16::MIN), got {list(evals)}"
        )

    def test_padding_is_zero(self):
        # Everything past game_lengths[i] must be zero-filled padding.
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        length = r["game_lengths"][0]
        assert np.all(r["tokens"][0, length:] == 0)
        assert np.all(r["clocks"][0, length:] == 0)
        assert np.all(r["evals"][0, length:] == 0)

    def test_headers_extracted(self):
        # Spot-check every header field against the fixture's literal values.
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        assert r["white"][0] == "alice"
        assert r["black"][0] == "bob"
        assert r["result"][0] == "1-0"
        assert r["white_elo"][0] == 1873
        assert r["black_elo"][0] == 1844
        assert r["white_rating_diff"][0] == 6
        assert r["black_rating_diff"][0] == -26
        assert r["eco"][0] == "C20"
        assert r["time_control"][0] == "600+0"
        assert r["site"][0] == "https://lichess.org/game001"

    def test_different_games_produce_different_tokens(self):
        """Each test PGN has distinct moves — tokens must differ."""
        all_tokens = {}
        for name, pgn in PGNS.items():
            r = chess_engine.parse_pgn_enriched(pgn)
            length = r["game_lengths"][0]
            all_tokens[name] = tuple(r["tokens"][0, :length])

        # Pairwise comparison over all fixture games.
        names = list(all_tokens.keys())
        for i in range(len(names)):
            for j in range(i + 1, len(names)):
                assert all_tokens[names[i]] != all_tokens[names[j]], (
                    f"{names[i]} and {names[j]} should have different token sequences"
                )
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
class TestPlayerHashing:
    """Test that player username hashing is deterministic and independent of context.

    Each PGN has different moves, metadata, Elo, time control, dates, and
    game lengths — the only thing shared is the player name strings. If
    hashing accidentally depended on row context, these would diverge.
    """

    @pytest.fixture
    def player_hashes(self):
        """Parse all 6 PGNs separately and collect per-player hash values.

        Returns:
            dict mapping player name -> list of observed uint64 hashes, one
            entry per appearance (white or black) across all games.
        """
        hashes = {}  # name -> list of observed uint64 hashes
        for name, pgn in PGNS.items():
            r = chess_engine.parse_pgn_enriched(pgn)
            df = batch_to_dataframe(r)
            w_name = r["white"][0]
            b_name = r["black"][0]
            w_hash = df["white_player"][0]
            b_hash = df["black_player"][0]
            hashes.setdefault(w_name, []).append(w_hash)
            hashes.setdefault(b_name, []).append(b_hash)
        return hashes

    def test_same_name_always_same_hash(self, player_hashes):
        """A player name must always produce the same hash regardless of
        which game it appears in, whether as white or black, and what
        the surrounding metadata looks like."""
        for name, vals in player_hashes.items():
            unique = set(vals)
            assert len(unique) == 1, (
                f"'{name}' produced {len(unique)} distinct hashes across "
                f"{len(vals)} appearances: {unique}"
            )

    def test_different_names_different_hashes(self, player_hashes):
        """alice, bob, and xavier must all have distinct hashes."""
        canonical = {name: vals[0] for name, vals in player_hashes.items()}
        hash_vals = list(canonical.values())
        assert len(set(hash_vals)) == len(hash_vals), (
            f"Hash collision among players: {canonical}"
        )

    def test_hash_dtype_is_uint64(self):
        """Hash columns must come out as Polars UInt64.

        Fix: no longer requests the ``player_hashes`` fixture — the previous
        version declared it but never used it, forcing all 6 PGNs to be
        re-parsed for nothing.
        """
        # Parse one game and check the Polars column dtype
        r = chess_engine.parse_pgn_enriched(PGNS["alice_v_bob"])
        df = batch_to_dataframe(r)
        assert df["white_player"].dtype == pl.UInt64
        assert df["black_player"].dtype == pl.UInt64

    def test_hash_appears_in_both_columns(self):
        """alice appears as both white and black — hash must match in both columns.

        Fix: dropped the unused ``player_hashes`` fixture dependency (it was
        never referenced in the body).
        """
        # alice is white in alice_v_bob and black in bob_v_alice
        r1 = chess_engine.parse_pgn_enriched(PGNS["alice_v_bob"])
        df1 = batch_to_dataframe(r1)
        r2 = chess_engine.parse_pgn_enriched(PGNS["bob_v_alice"])
        df2 = batch_to_dataframe(r2)

        alice_as_white = df1["white_player"][0]
        alice_as_black = df2["black_player"][0]
        assert alice_as_white == alice_as_black, (
            f"alice hash differs: white={alice_as_white}, black={alice_as_black}"
        )
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class TestPlayerHashRegression:
    """Snapshot test: catch if a Polars update changes the hash algorithm.

    These exact values were recorded with Polars 1.39.3 using the default
    hash() seed (xxHash64). If this test fails after a Polars upgrade, the
    dataset must be regenerated to stay consistent (or the old Polars
    version must be pinned).
    """

    # Recorded uint64 hashes for each test player name.
    EXPECTED_HASHES = {
        "alice": 573680751236103438,
        "bob": 11376496890720967193,
        "xavier": 2453512920044318708,
    }

    def test_hash_values_match_snapshot(self):
        """Verify that pl.Series.hash() produces the exact same uint64
        values that were recorded when the dataset was built."""
        for name, expected in self.EXPECTED_HASHES.items():
            actual = pl.Series([name]).hash()[0]
            assert actual == expected, (
                f"Hash regression for '{name}': expected {expected}, got {actual}. "
                f"Polars hash algorithm may have changed — dataset must be regenerated."
            )

    def test_snapshot_matches_pipeline(self):
        """The snapshot values must agree with what batch_to_dataframe produces."""
        combined = "\n".join(PGNS.values())
        r = chess_engine.parse_pgn_enriched(combined)
        df = batch_to_dataframe(r)

        for name, expected in self.EXPECTED_HASHES.items():
            # Check every row where this player appears, on both sides of
            # the board, against the recorded snapshot hash.
            for side, column in (("white", "white_player"), ("black", "black_player")):
                matching_rows = [
                    i for i, player in enumerate(r[side]) if player == name
                ]
                for i in matching_rows:
                    actual = df[column][i]
                    assert actual == expected, (
                        f"Pipeline hash for '{name}' ({side}, row {i}): "
                        f"expected {expected}, got {actual}"
                    )
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class TestBatchToDataframe:
|
| 347 |
+
"""Test the full batch_to_dataframe pipeline."""
|
| 348 |
+
|
| 349 |
+
def test_schema(self):
    """Every output column must carry its documented Polars dtype."""
    r = chess_engine.parse_pgn_enriched(PGNS["alice_v_bob"])
    df = batch_to_dataframe(r)
    expected_dtypes = {
        "tokens": pl.List(pl.Int16),
        "clock": pl.List(pl.UInt16),
        "eval": pl.List(pl.Int16),
        "game_length": pl.UInt16,
        "white_elo": pl.UInt16,
        "black_elo": pl.UInt16,
        "white_rating_diff": pl.Int16,
        "black_rating_diff": pl.Int16,
        "white_player": pl.UInt64,
        "black_player": pl.UInt64,
    }
    for column, dtype in expected_dtypes.items():
        assert df[column].dtype == dtype
| 363 |
+
def test_list_columns_trimmed_to_game_length(self):
    """List columns should contain exactly game_length elements (no padding)."""
    parsed = chess_engine.parse_pgn_enriched(PGNS["bob_v_xavier"])
    df = batch_to_dataframe(parsed)
    expected_len = df["game_length"][0]
    for list_col in ("tokens", "clock", "eval"):
        assert len(df[list_col][0]) == expected_len
| 372 |
+
def test_parquet_roundtrip(self, tmp_path):
    """Write to Parquet and read back — all values must survive."""
    parsed = chess_engine.parse_pgn_enriched(PGNS["xavier_v_bob"])
    original = batch_to_dataframe(parsed)

    out_file = tmp_path / "test.parquet"
    original.write_parquet(out_file, compression="zstd")
    restored = pl.read_parquet(out_file)

    assert original.shape == restored.shape
    # Compare the nested/list columns and the hashed-player column
    # element-by-element to make sure nothing was lost or coerced.
    for col in ("tokens", "clock", "eval", "white_player"):
        assert original[col].to_list() == restored[col].to_list()
| 385 |
+
def test_multi_game_batch(self):
    """Parse all 6 games in a single PGN string."""
    joined_pgn = "\n".join(PGNS.values())
    df = batch_to_dataframe(chess_engine.parse_pgn_enriched(joined_pgn))
    assert len(df) == 6
    # Each game should have a different game length
    distinct_lengths = set(df["game_length"].to_list())
    assert len(distinct_lengths) > 1, "Games should have different lengths"
|