thomas-schweich commited on
Commit
87cdae5
·
unverified ·
1 Parent(s): a819b4b

Add Lichess PGN -> PAWN Parquet extraction pipeline (#4)

Browse files
.dockerignore CHANGED
@@ -14,6 +14,8 @@ deploy/
14
  !deploy/entrypoint-extract.sh
15
  !deploy/entrypoint-lc0.sh
16
  !deploy/entrypoint-rosa-sweep.sh
 
 
17
  *.so
18
  CLAUDE.md
19
  docs/
 
14
  !deploy/entrypoint-extract.sh
15
  !deploy/entrypoint-lc0.sh
16
  !deploy/entrypoint-rosa-sweep.sh
17
+ !deploy/entrypoint-lc0-selfplay.sh
18
+ !deploy/entrypoint-lichess-parquet.sh
19
  *.so
20
  CLAUDE.md
21
  docs/
Dockerfile CHANGED
@@ -92,6 +92,13 @@ COPY deploy/entrypoint-rosa-sweep.sh /entrypoint-rosa-sweep.sh
92
  RUN chmod +x /entrypoint-rosa-sweep.sh
93
  ENTRYPOINT ["/entrypoint-rosa-sweep.sh"]
94
 
 
 
 
 
 
 
 
95
  # ── Interactive (default) — SSH + Jupyter, stays alive ───────────────
96
  FROM runtime-base AS interactive
97
  # Inherits /start.sh entrypoint from Runpod base image
 
92
  RUN chmod +x /entrypoint-rosa-sweep.sh
93
  ENTRYPOINT ["/entrypoint-rosa-sweep.sh"]
94
 
95
+ # ── Lichess extract — downloads PGN, writes Parquet, pushes to HF ───
96
+ FROM runtime-base AS lichess-extract
97
+ RUN pip install --no-cache-dir zstandard
98
+ COPY deploy/entrypoint-lichess-parquet.sh /entrypoint-lichess-parquet.sh
99
+ RUN chmod +x /entrypoint-lichess-parquet.sh
100
+ ENTRYPOINT ["/entrypoint-lichess-parquet.sh"]
101
+
102
  # ── Interactive (default) — SSH + Jupyter, stays alive ───────────────
103
  FROM runtime-base AS interactive
104
  # Inherits /start.sh entrypoint from Runpod base image
deploy/entrypoint-lichess-parquet.sh ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Lichess PGN -> PAWN Parquet extraction entrypoint.
# Downloads monthly database dumps, parses via Rust engine, writes sharded
# Parquet with train/val/test splits, and optionally pushes to HuggingFace.
#
# Required env vars:
#   MONTHS        — space-separated training months (e.g., "2025-01 2025-02 2025-03")
#
# Optional env vars:
#   HF_TOKEN      — HuggingFace token (for pushing dataset)
#   HF_REPO       — HuggingFace dataset repo (e.g., "thomas-schweich/pawn-lichess-full")
#   HOLDOUT_MONTH — month for val/test (e.g., "2023-12")
#   HOLDOUT_GAMES — games per split from holdout month (default: 50000)
#   BATCH_SIZE    — games per parsing batch (default: 500000)
#   SHARD_SIZE    — games per output shard (default: 1000000)
#   MAX_GAMES     — stop after this many training games (for testing)
#   OUTPUT_DIR    — output directory (default: /workspace/lichess-parquet)
#   SEED          — random seed for holdout sampling (default: 42)
set -euo pipefail

echo "=== Lichess Parquet Extraction ==="
echo "  Training months:      ${MONTHS:?MONTHS env var is required}"
echo "  Holdout month:        ${HOLDOUT_MONTH:-none}"
echo "  Holdout games/split:  ${HOLDOUT_GAMES:-50000}"
echo "  HF Repo:              ${HF_REPO:-none}"
echo "  Batch size:           ${BATCH_SIZE:-500000}"
echo "  Shard size:           ${SHARD_SIZE:-1000000}"
echo ""

# Persist HF token if provided
if [ -n "${HF_TOKEN:-}" ]; then
    mkdir -p ~/.cache/huggingface
    echo -n "$HF_TOKEN" > ~/.cache/huggingface/token
    echo "HF token persisted"
fi

# Install zstandard if not available (needed for streaming decompression)
python3 -c "import zstandard" 2>/dev/null || pip install --no-cache-dir zstandard

# Split MONTHS on whitespace explicitly. An unquoted $MONTHS expansion would
# also undergo pathname (glob) expansion, so a stray `*` in the env var could
# inject arbitrary filenames into the argument list; `read -r -a` only splits.
read -r -a MONTH_ARGS <<< "${MONTHS}"

# Build the command as an array to avoid shell injection
CMD=(python3 /opt/pawn/scripts/extract_lichess_parquet.py
     --months "${MONTH_ARGS[@]}"
     --output "${OUTPUT_DIR:-/workspace/lichess-parquet}"
     --batch-size "${BATCH_SIZE:-500000}"
     --shard-size "${SHARD_SIZE:-1000000}"
     --seed "${SEED:-42}"
)

if [ -n "${HOLDOUT_MONTH:-}" ]; then
    CMD+=(--holdout-month "$HOLDOUT_MONTH")
    CMD+=(--holdout-games "${HOLDOUT_GAMES:-50000}")
fi
if [ -n "${HF_REPO:-}" ]; then
    CMD+=(--hf-repo "$HF_REPO")
fi
if [ -n "${MAX_GAMES:-}" ]; then
    CMD+=(--max-games "$MAX_GAMES")
fi

echo "Running: ${CMD[*]}"
echo ""
exec "${CMD[@]}"
engine/python/chess_engine/__init__.py CHANGED
@@ -26,6 +26,9 @@ from chess_engine._engine import (
26
  # PGN parsing
27
  parse_pgn_file,
28
  pgn_to_tokens,
 
 
 
29
  # UCI parsing
30
  parse_uci_file,
31
  uci_to_tokens,
@@ -60,6 +63,9 @@ __all__ = [
60
  "validate_games",
61
  "parse_pgn_file",
62
  "pgn_to_tokens",
 
 
 
63
  "parse_uci_file",
64
  "uci_to_tokens",
65
  "pgn_to_uci",
 
26
  # PGN parsing
27
  parse_pgn_file,
28
  pgn_to_tokens,
29
+ parse_pgn_enriched,
30
+ count_pgn_games_in_date_range,
31
+ parse_pgn_sampled,
32
  # UCI parsing
33
  parse_uci_file,
34
  uci_to_tokens,
 
63
  "validate_games",
64
  "parse_pgn_file",
65
  "pgn_to_tokens",
66
+ "parse_pgn_enriched",
67
+ "count_pgn_games_in_date_range",
68
+ "parse_pgn_sampled",
69
  "parse_uci_file",
70
  "uci_to_tokens",
71
  "pgn_to_uci",
engine/src/lib.rs CHANGED
@@ -809,6 +809,267 @@ fn pgn_to_uci(py: Python<'_>, games: Vec<Vec<String>>) -> PyResult<Vec<Vec<Strin
809
  Ok(results.into_iter().map(|(uci, _)| uci).collect())
810
  }
811
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
812
  // ---------------------------------------------------------------------------
813
  // UCI engine self-play generation
814
  // ---------------------------------------------------------------------------
@@ -1172,6 +1433,9 @@ fn _engine(m: &Bound<'_, PyModule>) -> PyResult<()> {
1172
  m.add_function(wrap_pyfunction!(parse_uci_file, m)?)?;
1173
  m.add_function(wrap_pyfunction!(uci_to_tokens, m)?)?;
1174
  m.add_function(wrap_pyfunction!(pgn_to_uci, m)?)?;
 
 
 
1175
  m.add_function(wrap_pyfunction!(generate_engine_games_py, m)?)?;
1176
  m.add_function(wrap_pyfunction!(compute_accuracy_ceiling_py, m)?)?;
1177
  Ok(())
 
809
  Ok(results.into_iter().map(|(uci, _)| uci).collect())
810
  }
811
 
812
+ // ---------------------------------------------------------------------------
813
+ // Enriched PGN parsing (for dataset construction)
814
+ // ---------------------------------------------------------------------------
815
+
816
+ /// Parse PGN text with full annotation extraction for dataset building.
817
+ ///
818
+ /// Extracts move tokens, clock annotations, eval annotations, and all PGN
819
+ /// headers in a single pass. Designed for streaming: Python passes chunks of
820
+ /// PGN text (containing complete games), Rust returns structured columns.
821
+ ///
822
+ /// Returns a dict with:
823
+ ///   tokens: ndarray[i16, (N, max_ply)] — PAWN token IDs, 0-padded
824
+ ///   clocks: ndarray[u16, (N, max_ply)] — seconds remaining; 0-padded past game end, 32768 (0x8000) where a move has no clock annotation
825
+ ///   evals: ndarray[i16, (N, max_ply)] — centipawns; 0-padded past game end, i16::MIN where a move has no eval annotation
826
+ /// game_lengths: ndarray[u16, (N,)] — number of plies per game
827
+ /// white_elo: ndarray[u16, (N,)] — white Elo (0 if missing)
828
+ /// black_elo: ndarray[u16, (N,)] — black Elo (0 if missing)
829
+ /// white_rating_diff: ndarray[i16, (N,)] — white rating change (0 if missing)
830
+ /// black_rating_diff: ndarray[i16, (N,)] — black rating change (0 if missing)
831
+ /// result: list[str] — "1-0", "0-1", "1/2-1/2", or ""
832
+ /// white: list[str] — white player name
833
+ /// black: list[str] — black player name
834
+ /// eco: list[str] — ECO code
835
+ /// opening: list[str] — opening name
836
+ /// time_control: list[str] — time control string
837
+ /// termination: list[str] — termination reason
838
+ /// date_time: list[str] — "YYYY.MM.DD HH:MM:SS" UTC
839
+ /// site: list[str] — game URL
840
+ #[pyfunction]
841
+ #[pyo3(signature = (content, max_ply=255, max_games=1_000_000, min_ply=1))]
842
+ fn parse_pgn_enriched<'py>(
843
+ py: Python<'py>,
844
+ content: &str,
845
+ max_ply: usize,
846
+ max_games: usize,
847
+ min_ply: usize,
848
+ ) -> PyResult<PyObject> {
849
+ let games = py.allow_threads(|| {
850
+ pgn::parse_pgn_enriched(content, max_ply, max_games, min_ply)
851
+ });
852
+
853
+ let n = games.len();
854
+ let dict = PyDict::new(py);
855
+
856
+ // Flat 0-padded arrays for tokens, clocks, evals (N * max_ply)
857
+ let mut flat_tokens = vec![0i16; n * max_ply];
858
+ let mut flat_clocks = vec![0u16; n * max_ply];
859
+ let mut flat_evals = vec![0i16; n * max_ply];
860
+
861
+ // Scalar arrays
862
+ let mut lengths_out = Vec::with_capacity(n);
863
+ let mut white_elo_out = Vec::with_capacity(n);
864
+ let mut black_elo_out = Vec::with_capacity(n);
865
+ let mut white_rd_out = Vec::with_capacity(n);
866
+ let mut black_rd_out = Vec::with_capacity(n);
867
+
868
+ // String lists
869
+ let mut result_out = Vec::with_capacity(n);
870
+ let mut white_out = Vec::with_capacity(n);
871
+ let mut black_out = Vec::with_capacity(n);
872
+ let mut eco_out = Vec::with_capacity(n);
873
+ let mut opening_out = Vec::with_capacity(n);
874
+ let mut tc_out = Vec::with_capacity(n);
875
+ let mut term_out = Vec::with_capacity(n);
876
+ let mut datetime_out = Vec::with_capacity(n);
877
+ let mut site_out = Vec::with_capacity(n);
878
+
879
+ for (gi, g) in games.iter().enumerate() {
880
+ let offset = gi * max_ply;
881
+ let len = g.game_length.min(max_ply);
882
+ for t in 0..len {
883
+ flat_tokens[offset + t] = g.tokens[t] as i16;
884
+ flat_clocks[offset + t] = g.clocks[t];
885
+ flat_evals[offset + t] = g.evals[t];
886
+ }
887
+
888
+ lengths_out.push(g.game_length as u16);
889
+
890
+ let h = &g.headers;
891
+ white_elo_out.push(h.get("WhiteElo").and_then(|s| s.parse::<u16>().ok()).unwrap_or(0));
892
+ black_elo_out.push(h.get("BlackElo").and_then(|s| s.parse::<u16>().ok()).unwrap_or(0));
893
+ white_rd_out.push(h.get("WhiteRatingDiff").and_then(|s| s.parse::<i16>().ok()).unwrap_or(0));
894
+ black_rd_out.push(h.get("BlackRatingDiff").and_then(|s| s.parse::<i16>().ok()).unwrap_or(0));
895
+
896
+ result_out.push(h.get("Result").cloned().unwrap_or_default());
897
+ white_out.push(h.get("White").cloned().unwrap_or_default());
898
+ black_out.push(h.get("Black").cloned().unwrap_or_default());
899
+ eco_out.push(h.get("ECO").cloned().unwrap_or_default());
900
+ opening_out.push(h.get("Opening").cloned().unwrap_or_default());
901
+ tc_out.push(h.get("TimeControl").cloned().unwrap_or_default());
902
+ term_out.push(h.get("Termination").cloned().unwrap_or_default());
903
+ site_out.push(h.get("Site").cloned().unwrap_or_default());
904
+
905
+ let date = h.get("UTCDate").cloned().unwrap_or_default();
906
+ let time = h.get("UTCTime").cloned().unwrap_or_default();
907
+ if !date.is_empty() && !time.is_empty() {
908
+ datetime_out.push(format!("{} {}", date, time));
909
+ } else {
910
+ datetime_out.push(date);
911
+ }
912
+ }
913
+
914
+ // 2D numpy arrays: (N, max_ply)
915
+ let tokens_arr = numpy::PyArray::from_vec(py, flat_tokens).reshape([n, max_ply])?;
916
+ let clocks_arr = numpy::PyArray::from_vec(py, flat_clocks).reshape([n, max_ply])?;
917
+ let evals_arr = numpy::PyArray::from_vec(py, flat_evals).reshape([n, max_ply])?;
918
+
919
+ // 1D numpy arrays
920
+ let lengths_arr = numpy::PyArray::from_vec(py, lengths_out);
921
+ let white_elo_arr = numpy::PyArray::from_vec(py, white_elo_out);
922
+ let black_elo_arr = numpy::PyArray::from_vec(py, black_elo_out);
923
+ let white_rd_arr = numpy::PyArray::from_vec(py, white_rd_out);
924
+ let black_rd_arr = numpy::PyArray::from_vec(py, black_rd_out);
925
+
926
+ dict.set_item("tokens", tokens_arr)?;
927
+ dict.set_item("clocks", clocks_arr)?;
928
+ dict.set_item("evals", evals_arr)?;
929
+ dict.set_item("game_lengths", lengths_arr)?;
930
+ dict.set_item("white_elo", white_elo_arr)?;
931
+ dict.set_item("black_elo", black_elo_arr)?;
932
+ dict.set_item("white_rating_diff", white_rd_arr)?;
933
+ dict.set_item("black_rating_diff", black_rd_arr)?;
934
+ dict.set_item("result", result_out)?;
935
+ dict.set_item("white", white_out)?;
936
+ dict.set_item("black", black_out)?;
937
+ dict.set_item("eco", eco_out)?;
938
+ dict.set_item("opening", opening_out)?;
939
+ dict.set_item("time_control", tc_out)?;
940
+ dict.set_item("termination", term_out)?;
941
+ dict.set_item("date_time", datetime_out)?;
942
+ dict.set_item("site", site_out)?;
943
+
944
+ Ok(dict.into())
945
+ }
946
+
947
+ /// Count games in a PGN string whose UTCDate falls within [date_start, date_end].
948
+ /// Header-only scan — no tokenization. Very fast.
949
+ #[pyfunction]
950
+ fn count_pgn_games_in_date_range(
951
+ py: Python<'_>,
952
+ content: &str,
953
+ date_start: &str,
954
+ date_end: &str,
955
+ ) -> PyResult<usize> {
956
+ let count = py.allow_threads(|| {
957
+ pgn::count_games_in_date_range(content, date_start, date_end)
958
+ });
959
+ Ok(count)
960
+ }
961
+
962
+ /// Parse only specific games (by index within a date range) from a PGN string.
963
+ ///
964
+ /// Used for uniform random sampling: call count_pgn_games_in_date_range first
965
+ /// to get the total, generate random indices in Python, then call this to
966
+ /// parse only those games. `game_offset` is the cumulative count of
967
+ /// date-matching games from previous chunks.
968
+ ///
969
+ /// Returns the same dict format as parse_pgn_enriched.
970
+ #[pyfunction]
971
+ #[pyo3(signature = (content, indices, date_start, date_end, game_offset=0, max_ply=255, min_ply=1))]
972
+ fn parse_pgn_sampled<'py>(
973
+ py: Python<'py>,
974
+ content: &str,
975
+ indices: Vec<usize>,
976
+ date_start: &str,
977
+ date_end: &str,
978
+ game_offset: usize,
979
+ max_ply: usize,
980
+ min_ply: usize,
981
+ ) -> PyResult<PyObject> {
982
+ let index_set: std::collections::HashSet<usize> = indices.into_iter().collect();
983
+
984
+ let games = py.allow_threads(|| {
985
+ pgn::parse_pgn_enriched_sampled(
986
+ content, max_ply, min_ply, date_start, date_end, &index_set, game_offset,
987
+ )
988
+ });
989
+
990
+ // Reuse the same dict-building logic as parse_pgn_enriched
991
+ let n = games.len();
992
+ let dict = PyDict::new(py);
993
+
994
+ let mut flat_tokens = vec![0i16; n * max_ply];
995
+ let mut flat_clocks = vec![0u16; n * max_ply];
996
+ let mut flat_evals = vec![0i16; n * max_ply];
997
+ let mut lengths_out = Vec::with_capacity(n);
998
+ let mut white_elo_out = Vec::with_capacity(n);
999
+ let mut black_elo_out = Vec::with_capacity(n);
1000
+ let mut white_rd_out = Vec::with_capacity(n);
1001
+ let mut black_rd_out = Vec::with_capacity(n);
1002
+ let mut result_out = Vec::with_capacity(n);
1003
+ let mut white_out = Vec::with_capacity(n);
1004
+ let mut black_out = Vec::with_capacity(n);
1005
+ let mut eco_out = Vec::with_capacity(n);
1006
+ let mut opening_out = Vec::with_capacity(n);
1007
+ let mut tc_out = Vec::with_capacity(n);
1008
+ let mut term_out = Vec::with_capacity(n);
1009
+ let mut datetime_out = Vec::with_capacity(n);
1010
+ let mut site_out = Vec::with_capacity(n);
1011
+
1012
+ for (gi, g) in games.iter().enumerate() {
1013
+ let offset = gi * max_ply;
1014
+ let len = g.game_length.min(max_ply);
1015
+ for t in 0..len {
1016
+ flat_tokens[offset + t] = g.tokens[t] as i16;
1017
+ flat_clocks[offset + t] = g.clocks[t];
1018
+ flat_evals[offset + t] = g.evals[t];
1019
+ }
1020
+ lengths_out.push(g.game_length as u16);
1021
+ let h = &g.headers;
1022
+ white_elo_out.push(h.get("WhiteElo").and_then(|s| s.parse::<u16>().ok()).unwrap_or(0));
1023
+ black_elo_out.push(h.get("BlackElo").and_then(|s| s.parse::<u16>().ok()).unwrap_or(0));
1024
+ white_rd_out.push(h.get("WhiteRatingDiff").and_then(|s| s.parse::<i16>().ok()).unwrap_or(0));
1025
+ black_rd_out.push(h.get("BlackRatingDiff").and_then(|s| s.parse::<i16>().ok()).unwrap_or(0));
1026
+ result_out.push(h.get("Result").cloned().unwrap_or_default());
1027
+ white_out.push(h.get("White").cloned().unwrap_or_default());
1028
+ black_out.push(h.get("Black").cloned().unwrap_or_default());
1029
+ eco_out.push(h.get("ECO").cloned().unwrap_or_default());
1030
+ opening_out.push(h.get("Opening").cloned().unwrap_or_default());
1031
+ tc_out.push(h.get("TimeControl").cloned().unwrap_or_default());
1032
+ term_out.push(h.get("Termination").cloned().unwrap_or_default());
1033
+ site_out.push(h.get("Site").cloned().unwrap_or_default());
1034
+ let date = h.get("UTCDate").cloned().unwrap_or_default();
1035
+ let time = h.get("UTCTime").cloned().unwrap_or_default();
1036
+ if !date.is_empty() && !time.is_empty() {
1037
+ datetime_out.push(format!("{} {}", date, time));
1038
+ } else {
1039
+ datetime_out.push(date);
1040
+ }
1041
+ }
1042
+
1043
+ let tokens_arr = numpy::PyArray::from_vec(py, flat_tokens).reshape([n, max_ply])?;
1044
+ let clocks_arr = numpy::PyArray::from_vec(py, flat_clocks).reshape([n, max_ply])?;
1045
+ let evals_arr = numpy::PyArray::from_vec(py, flat_evals).reshape([n, max_ply])?;
1046
+ let lengths_arr = numpy::PyArray::from_vec(py, lengths_out);
1047
+ let white_elo_arr = numpy::PyArray::from_vec(py, white_elo_out);
1048
+ let black_elo_arr = numpy::PyArray::from_vec(py, black_elo_out);
1049
+ let white_rd_arr = numpy::PyArray::from_vec(py, white_rd_out);
1050
+ let black_rd_arr = numpy::PyArray::from_vec(py, black_rd_out);
1051
+
1052
+ dict.set_item("tokens", tokens_arr)?;
1053
+ dict.set_item("clocks", clocks_arr)?;
1054
+ dict.set_item("evals", evals_arr)?;
1055
+ dict.set_item("game_lengths", lengths_arr)?;
1056
+ dict.set_item("white_elo", white_elo_arr)?;
1057
+ dict.set_item("black_elo", black_elo_arr)?;
1058
+ dict.set_item("white_rating_diff", white_rd_arr)?;
1059
+ dict.set_item("black_rating_diff", black_rd_arr)?;
1060
+ dict.set_item("result", result_out)?;
1061
+ dict.set_item("white", white_out)?;
1062
+ dict.set_item("black", black_out)?;
1063
+ dict.set_item("eco", eco_out)?;
1064
+ dict.set_item("opening", opening_out)?;
1065
+ dict.set_item("time_control", tc_out)?;
1066
+ dict.set_item("termination", term_out)?;
1067
+ dict.set_item("date_time", datetime_out)?;
1068
+ dict.set_item("site", site_out)?;
1069
+
1070
+ Ok(dict.into())
1071
+ }
1072
+
1073
  // ---------------------------------------------------------------------------
1074
  // UCI engine self-play generation
1075
  // ---------------------------------------------------------------------------
 
1433
  m.add_function(wrap_pyfunction!(parse_uci_file, m)?)?;
1434
  m.add_function(wrap_pyfunction!(uci_to_tokens, m)?)?;
1435
  m.add_function(wrap_pyfunction!(pgn_to_uci, m)?)?;
1436
+ m.add_function(wrap_pyfunction!(parse_pgn_enriched, m)?)?;
1437
+ m.add_function(wrap_pyfunction!(count_pgn_games_in_date_range, m)?)?;
1438
+ m.add_function(wrap_pyfunction!(parse_pgn_sampled, m)?)?;
1439
  m.add_function(wrap_pyfunction!(generate_engine_games_py, m)?)?;
1440
  m.add_function(wrap_pyfunction!(compute_accuracy_ceiling_py, m)?)?;
1441
  Ok(())
engine/src/pgn.rs CHANGED
@@ -3,7 +3,11 @@
3
  //! Full pipeline in Rust: reads PGN files, extracts SAN move strings,
4
  //! converts to PAWN tokens via shakmaty. Uses rayon for parallel
5
  //! token conversion.
 
 
 
6
 
 
7
  use std::fs;
8
  use rayon::prelude::*;
9
  use shakmaty::{Chess, Position};
@@ -11,6 +15,420 @@ use shakmaty::san::San;
11
 
12
  use crate::board::move_to_token;
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  /// Convert a sequence of SAN move strings to PAWN token indices.
15
  ///
16
  /// Returns (tokens, n_valid) where tokens has length up to max_ply,
@@ -277,4 +695,214 @@ mod tests {
277
 
278
  fs::remove_file(path).ok();
279
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  }
 
3
  //! Full pipeline in Rust: reads PGN files, extracts SAN move strings,
4
  //! converts to PAWN tokens via shakmaty. Uses rayon for parallel
5
  //! token conversion.
6
+ //!
7
+ //! Also provides enriched parsing that extracts clock annotations,
8
+ //! eval annotations, and PGN headers for dataset construction.
9
 
10
+ use std::collections::{HashMap, HashSet};
11
  use std::fs;
12
  use rayon::prelude::*;
13
  use shakmaty::{Chess, Position};
 
15
 
16
  use crate::board::move_to_token;
17
 
18
+ // ---------------------------------------------------------------------------
19
+ // Enriched PGN parsing — extracts moves, clocks, evals, and headers
20
+ // ---------------------------------------------------------------------------
21
+
22
+ /// A fully parsed game with move tokens, annotations, and metadata.
23
+ pub struct EnrichedGame {
24
+ /// PAWN token indices for each ply (not padded).
25
+ pub tokens: Vec<u16>,
26
+ /// Seconds remaining on clock after each ply (0 = no annotation).
27
+ pub clocks: Vec<u16>,
28
+ /// Centipawns from white's perspective after each ply.
29
+ /// Mate scores: ±(32767-N). No annotation: 0x8000 (-32768 as i16).
30
+ pub evals: Vec<i16>,
31
+ /// Number of valid plies.
32
+ pub game_length: usize,
33
+ /// PGN header fields (e.g., "White" -> "alice", "WhiteElo" -> "1873").
34
+ pub headers: HashMap<String, String>,
35
+ }
36
+
37
+ /// Parse a PGN string into enriched games.
38
+ ///
39
+ /// Extracts SAN moves (tokenized), `[%clk h:mm:ss]` annotations,
40
+ /// `[%eval ±N.NN]` / `[%eval #±N]` annotations, and all PGN headers.
41
+ /// Tokenization uses shakmaty and is parallelized with rayon.
42
+ pub fn parse_pgn_enriched(
43
+ content: &str,
44
+ max_ply: usize,
45
+ max_games: usize,
46
+ min_ply: usize,
47
+ ) -> Vec<EnrichedGame> {
48
+ let raw_games = parse_raw_games(content, max_games, None, None);
49
+
50
+ // Phase 2: parallel tokenization + annotation extraction
51
+ raw_games
52
+ .into_par_iter()
53
+ .filter_map(|raw| {
54
+ let (san_moves, clocks_raw, evals_raw) = extract_moves_and_annotations(&raw.movetext);
55
+ if san_moves.len() < min_ply {
56
+ return None;
57
+ }
58
+
59
+ // Tokenize SAN moves via shakmaty
60
+ let refs: Vec<&str> = san_moves.iter().map(|s| s.as_str()).collect();
61
+ let (tokens, n_valid) = san_moves_to_tokens(&refs, max_ply);
62
+ if n_valid < min_ply {
63
+ return None;
64
+ }
65
+
66
+ // Trim annotations to match token count (moves may have failed to parse).
67
+ let clocks = clocks_raw.into_iter().take(n_valid).collect();
68
+ let evals = evals_raw.into_iter().take(n_valid).collect();
69
+
70
+ Some(EnrichedGame {
71
+ tokens,
72
+ clocks,
73
+ evals,
74
+ game_length: n_valid,
75
+ headers: raw.headers,
76
+ })
77
+ })
78
+ .collect()
79
+ }
80
+
81
+ /// Count games in a PGN string whose UTCDate falls within [start, end].
82
+ ///
83
+ /// Header-only scan — no movetext parsing, no tokenization.
84
+ /// Returns (count_in_range, offset) where offset is the running game index
85
+ /// that should be passed to the next chunk for correct global indexing.
86
+ pub fn count_games_in_date_range(
87
+ content: &str,
88
+ date_start: &str,
89
+ date_end: &str,
90
+ ) -> usize {
91
+ let mut count = 0;
92
+ let mut current_date: Option<String> = None;
93
+ let mut in_movetext = false;
94
+
95
+ for line in content.lines() {
96
+ let line = line.trim();
97
+ if line.is_empty() {
98
+ if in_movetext {
99
+ // End of game — check if the date was in range
100
+ if let Some(ref d) = current_date {
101
+ if d.as_str() >= date_start && d.as_str() <= date_end {
102
+ count += 1;
103
+ }
104
+ }
105
+ current_date = None;
106
+ in_movetext = false;
107
+ }
108
+ continue;
109
+ }
110
+
111
+ if line.starts_with('[') && line.ends_with(']') {
112
+ if let Some((key, value)) = parse_header_line(line) {
113
+ if key == "UTCDate" {
114
+ current_date = Some(value);
115
+ }
116
+ }
117
+ in_movetext = false;
118
+ } else {
119
+ in_movetext = true;
120
+ }
121
+ }
122
+
123
+ // Handle last game
124
+ if in_movetext {
125
+ if let Some(ref d) = current_date {
126
+ if d.as_str() >= date_start && d.as_str() <= date_end {
127
+ count += 1;
128
+ }
129
+ }
130
+ }
131
+
132
+ count
133
+ }
134
+
135
+ /// Parse a PGN string, but only tokenize games at specific indices within a
136
+ /// date range. Used for uniform random sampling: Python counts games in the
137
+ /// date range (via `count_games_in_date_range`), generates a random index
138
+ /// set, then calls this to parse only those games.
139
+ ///
140
+ /// `indices` are 0-based within the date-range-matching games of this chunk.
141
+ /// `game_offset` is the number of date-matching games seen in previous chunks,
142
+ /// so global index = game_offset + local_index.
143
+ pub fn parse_pgn_enriched_sampled(
144
+ content: &str,
145
+ max_ply: usize,
146
+ min_ply: usize,
147
+ date_start: &str,
148
+ date_end: &str,
149
+ indices: &HashSet<usize>,
150
+ game_offset: usize,
151
+ ) -> Vec<EnrichedGame> {
152
+ let raw_games = parse_raw_games(content, usize::MAX, Some((date_start, date_end)), Some((indices, game_offset)));
153
+
154
+ raw_games
155
+ .into_par_iter()
156
+ .filter_map(|raw| {
157
+ let (san_moves, clocks_raw, evals_raw) = extract_moves_and_annotations(&raw.movetext);
158
+ if san_moves.len() < min_ply {
159
+ return None;
160
+ }
161
+
162
+ let refs: Vec<&str> = san_moves.iter().map(|s| s.as_str()).collect();
163
+ let (tokens, n_valid) = san_moves_to_tokens(&refs, max_ply);
164
+ if n_valid < min_ply {
165
+ return None;
166
+ }
167
+
168
+ let clocks = clocks_raw.into_iter().take(n_valid).collect();
169
+ let evals = evals_raw.into_iter().take(n_valid).collect();
170
+
171
+ Some(EnrichedGame {
172
+ tokens,
173
+ clocks,
174
+ evals,
175
+ game_length: n_valid,
176
+ headers: raw.headers,
177
+ })
178
+ })
179
+ .collect()
180
+ }
181
+
/// Raw game data before tokenization.
struct RawGame {
    headers: HashMap<String, String>,
    movetext: String,
}

/// Single-threaded PGN line scanner. Extracts headers and raw movetext.
///
/// If `date_range` is Some((start, end)), only games whose UTCDate falls
/// within [start, end] are included. If `sample` is Some((indices, offset)),
/// only games whose (offset + local_index) is in the index set are kept.
///
/// Bug fix: `in_movetext` is now set for every movetext line, including lines
/// of date-excluded games. Previously an excluded game never set the flag, so
/// the per-game reset at the following blank line never ran — `date_excluded`
/// (and stale headers) leaked into every later game in the chunk, silently
/// dropping all of them. `count_games_in_date_range` already set the flag
/// unconditionally; this makes the two scanners consistent.
fn parse_raw_games(
    content: &str,
    max_games: usize,
    date_range: Option<(&str, &str)>,
    sample: Option<(&HashSet<usize>, usize)>,
) -> Vec<RawGame> {
    let mut games = Vec::new();
    let mut headers: HashMap<String, String> = HashMap::new();
    let mut movetext_lines: Vec<&str> = Vec::new();
    let mut in_movetext = false;       // currently inside a game's movetext
    let mut date_excluded = false;     // UTCDate outside date_range
    let mut has_utc_date = false;      // saw a UTCDate header for this game
    let mut date_matched_idx = 0usize; // count of date-matching games seen

    for line in content.lines() {
        let line = line.trim();
        if line.is_empty() {
            if in_movetext {
                // End of movetext — game boundary.
                // Exclude if date is out of range OR if date_range is active
                // but no UTCDate header was found (consistent with
                // count_games_in_date_range).
                let excluded = date_excluded || (date_range.is_some() && !has_utc_date);
                if !excluded && !movetext_lines.is_empty() {
                    // Game passed the date filter. Check the sample if present.
                    let keep = match sample {
                        Some((indices, offset)) => indices.contains(&(offset + date_matched_idx)),
                        None => true,
                    };
                    date_matched_idx += 1;

                    if keep {
                        games.push(RawGame {
                            headers: std::mem::take(&mut headers),
                            movetext: movetext_lines.join(" "),
                        });
                        if games.len() >= max_games {
                            break;
                        }
                    }
                }
                // Reset per-game state for ALL games, excluded or not.
                movetext_lines.clear();
                headers.clear();
                in_movetext = false;
                date_excluded = false;
                has_utc_date = false;
            }
            // Blank line between headers and movetext: don't reset state
            continue;
        }

        // Header line: [Key "Value"]
        if line.starts_with('[') && line.ends_with(']') {
            if let Some((key, value)) = parse_header_line(line) {
                if key == "UTCDate" {
                    has_utc_date = true;
                    if let Some((start, end)) = date_range {
                        if value.as_str() < start || value.as_str() > end {
                            date_excluded = true;
                        }
                    }
                }
                if !date_excluded {
                    headers.insert(key, value);
                }
            }
            in_movetext = false;
            continue;
        }

        // Movetext line. Mark the game as "in movetext" even when it is
        // date-excluded so the blank-line boundary above still resets
        // per-game state; only buffer text for games we may keep.
        in_movetext = true;
        if !date_excluded {
            movetext_lines.push(line);
        }
    }

    // Handle a final game that is not followed by a trailing blank line.
    let last_excluded = date_excluded || (date_range.is_some() && !has_utc_date);
    if in_movetext && !last_excluded && !movetext_lines.is_empty() && games.len() < max_games {
        let keep = match sample {
            Some((indices, offset)) => indices.contains(&(offset + date_matched_idx)),
            None => true,
        };
        if keep {
            games.push(RawGame {
                headers: std::mem::take(&mut headers),
                movetext: movetext_lines.join(" "),
            });
        }
    }

    games
}

/// Parse a PGN header line like `[White "alice"]` into ("White", "alice").
fn parse_header_line(line: &str) -> Option<(String, String)> {
    // Strip surrounding brackets
    let inner = line.strip_prefix('[')?.strip_suffix(']')?.trim();
    let space = inner.find(' ')?;
    let key = inner[..space].to_string();
    let value_part = inner[space..].trim();
    // Strip surrounding quotes (left in place if unpaired)
    let value = value_part
        .strip_prefix('"')
        .and_then(|v| v.strip_suffix('"'))
        .unwrap_or(value_part)
        .to_string();
    Some((key, value))
}
+
302
+ /// Sentinel for "no clock annotation" (0x8000 as u16 = 32768).
303
+ const CLOCK_NONE: u16 = 0x8000;
304
+ /// Sentinel for "no eval annotation" (0x8000 as i16 = -32768).
305
+ const EVAL_NONE: i16 = -0x8000; // i16::MIN
306
+
307
/// Extract SAN moves, clock annotations, and eval annotations from movetext.
///
/// Returns (san_moves, clocks, evals) where clocks[i] is the clock after
/// move i (CLOCK_NONE if no annotation) and evals[i] is centipawns after
/// move i (EVAL_NONE if no annotation). All three vectors always have the
/// same length: sentinel slots are pushed alongside every move and only
/// overwritten when a following comment carries an annotation.
fn extract_moves_and_annotations(text: &str) -> (Vec<String>, Vec<u16>, Vec<i16>) {
    let mut moves = Vec::new();
    let mut clocks = Vec::new();
    let mut evals = Vec::new();

    // Byte-level scan: Lichess movetext is ASCII-delimited, so byte indices
    // are safe slice boundaries for the token/comment substrings taken below.
    let bytes = text.as_bytes();
    let len = bytes.len();

    // Lichess format: move { comment } move { comment } ...
    // The comment annotates the move immediately before it.
    let mut i = 0;
    while i < len {
        if bytes[i].is_ascii_whitespace() {
            i += 1;
            continue;
        }

        // Comment: { ... } — applies to the last pushed move
        if bytes[i] == b'{' {
            i += 1;
            let start = i;
            // Scan to the closing brace (or end of input if unterminated).
            while i < len && bytes[i] != b'}' {
                i += 1;
            }
            let comment = &text[start..i];
            // Skip the '}' itself unless the comment ran off the end.
            if i < len { i += 1; }

            // Apply to last move. A comment before any move (clocks empty)
            // is silently dropped.
            if let Some(last_clk) = clocks.last_mut() {
                let mut clk = CLOCK_NONE;
                let mut ev = EVAL_NONE;
                parse_comment(comment, &mut clk, &mut ev);
                if clk != CLOCK_NONE { *last_clk = clk; }
                if ev != EVAL_NONE {
                    if let Some(last_ev) = evals.last_mut() {
                        *last_ev = ev;
                    }
                }
            }
            continue;
        }

        // Plain token: runs until whitespace or the start of a comment.
        let start = i;
        while i < len && !bytes[i].is_ascii_whitespace() && bytes[i] != b'{' {
            i += 1;
        }
        let token = &text[start..i];

        // NAG annotations ($1, $2, ...) carry no move — skip.
        if token.starts_with('$') { continue; }
        // A result token terminates the movetext.
        if token == "1-0" || token == "0-1" || token == "1/2-1/2" || token == "*" { break; }

        // Move numbers ("1.", "1...", "12") are digits plus trailing dots — skip.
        let stripped = token.trim_end_matches('.');
        if !stripped.is_empty() && stripped.bytes().all(|b| b.is_ascii_digit()) { continue; }

        // A real SAN move: record it with sentinel annotation slots, to be
        // filled in by the comment (if any) that follows it.
        moves.push(token.to_string());
        clocks.push(CLOCK_NONE);
        evals.push(EVAL_NONE);
    }

    (moves, clocks, evals)
}
373
+
374
+ /// Parse a PGN comment body for clock and eval annotations.
375
+ ///
376
+ /// Lichess format: `[%clk 0:03:00]` and `[%eval 1.23]` or `[%eval #-3]`.
377
+ fn parse_comment(comment: &str, clock: &mut u16, eval: &mut i16) {
378
+ // Clock: [%clk H:MM:SS]
379
+ if let Some(pos) = comment.find("[%clk ") {
380
+ let rest = &comment[pos + 6..];
381
+ if let Some(end) = rest.find(']') {
382
+ let clk_str = rest[..end].trim();
383
+ if let Some(secs) = parse_clock(clk_str) {
384
+ *clock = secs;
385
+ }
386
+ }
387
+ }
388
+
389
+ // Eval: [%eval 1.23] or [%eval #-3]
390
+ if let Some(pos) = comment.find("[%eval ") {
391
+ let rest = &comment[pos + 7..];
392
+ if let Some(end) = rest.find(']') {
393
+ let eval_str = rest[..end].trim();
394
+ if let Some(cp) = parse_eval(eval_str) {
395
+ *eval = cp;
396
+ }
397
+ }
398
+ }
399
+ }
400
+
401
/// Parse a PGN clock string into total seconds as u16.
///
/// Accepts the Lichess `H:MM:SS` form and, as a backward-compatible
/// generalization, the bare `MM:SS` form used by some other PGN sources.
/// Arithmetic saturates so an absurd hours field cannot overflow u32
/// (the original `h * 3600` could panic in debug builds). The result is
/// capped at 0x7FFF (32767) to avoid collision with CLOCK_NONE (0x8000).
/// Returns `None` on any malformed input.
fn parse_clock(s: &str) -> Option<u16> {
    let parts: Vec<&str> = s.split(':').collect();
    let (h, m, sec): (u32, u32, u32) = match parts.as_slice() {
        [h, m, s] => (h.parse().ok()?, m.parse().ok()?, s.parse().ok()?),
        // MM:SS with no hours field.
        [m, s] => (0, m.parse().ok()?, s.parse().ok()?),
        _ => return None,
    };
    let total = h
        .saturating_mul(3600)
        .saturating_add(m.saturating_mul(60))
        .saturating_add(sec);
    // Cap at 0x7FFF (32767) to avoid collision with CLOCK_NONE (0x8000).
    Some(total.min(0x7FFF) as u16)
}
412
+
413
/// Convert a Lichess eval string to centipawns (i16).
///
/// Numeric evals: "1.23" → 123, "-0.50" → -50, clamped to ±16383.
/// Mate-in-N ("#N" / "#-N") maps to ±(32767 - N). Every mate value has
/// bit 14 (0x4000) set while clamped centipawns never do, so mates are
/// detectable with a single bitmask. Returns `None` on unparseable input.
fn parse_eval(s: &str) -> Option<i16> {
    match s.strip_prefix('#') {
        Some(mate_str) => {
            let n: i32 = mate_str.parse().ok()?;
            // "#0" is treated as mate-in-1 via max(1).
            let dist = n.unsigned_abs().max(1) as i16;
            let magnitude = 32767 - dist;
            Some(if n > 0 { magnitude } else { -magnitude })
        }
        None => {
            let pawns: f64 = s.parse().ok()?;
            let centipawns = (pawns * 100.0).round() as i32;
            // Clamp below the mate range so the two encodings never overlap.
            Some(centipawns.clamp(-16383, 16383) as i16)
        }
    }
}
431
+
432
  /// Convert a sequence of SAN move strings to PAWN token indices.
433
  ///
434
  /// Returns (tokens, n_valid) where tokens has length up to max_ply,
 
695
 
696
  fs::remove_file(path).ok();
697
  }
698
+
699
+ // --- Enriched parsing tests ---
700
+
701
    #[test]
    fn test_parse_clock() {
        // H:MM:SS → total seconds; malformed strings yield None.
        assert_eq!(parse_clock("0:10:00"), Some(600));
        assert_eq!(parse_clock("1:30:00"), Some(5400));
        assert_eq!(parse_clock("0:00:05"), Some(5));
        assert_eq!(parse_clock("0:03:00"), Some(180));
        assert_eq!(parse_clock("bad"), None);
    }
709
+
710
    #[test]
    fn test_parse_eval() {
        // Plain centipawn conversion.
        assert_eq!(parse_eval("0.23"), Some(23));
        assert_eq!(parse_eval("-1.50"), Some(-150));
        assert_eq!(parse_eval("0.00"), Some(0));
        // Mate scores: 32767 - N
        assert_eq!(parse_eval("#1"), Some(32766));
        assert_eq!(parse_eval("#-1"), Some(-32766));
        assert_eq!(parse_eval("#3"), Some(32764));
        assert_eq!(parse_eval("#-3"), Some(-32764));
        assert_eq!(parse_eval("#10"), Some(32757));
        // Bit 14 (0x4000 = 16384) is set for all mate values
        assert!(parse_eval("#1").unwrap() & 0x4000 != 0);
        assert!(parse_eval("#100").unwrap() & 0x4000 != 0);
        // Centipawns clamped to ±16383 to avoid mate range
        assert_eq!(parse_eval("200.00"), Some(16383));
        assert_eq!(parse_eval("-200.00"), Some(-16383));
    }
728
+
729
    #[test]
    fn test_parse_header_line() {
        // Tag pairs split at the first space; quotes are stripped from values.
        assert_eq!(
            parse_header_line(r#"[White "alice"]"#),
            Some(("White".to_string(), "alice".to_string()))
        );
        assert_eq!(
            parse_header_line(r#"[WhiteElo "1873"]"#),
            Some(("WhiteElo".to_string(), "1873".to_string()))
        );
        // Values may themselves contain spaces and colons.
        assert_eq!(
            parse_header_line(r#"[Opening "Bird Opening: Dutch Variation"]"#),
            Some(("Opening".to_string(), "Bird Opening: Dutch Variation".to_string()))
        );
    }
744
+
745
    #[test]
    fn test_extract_moves_and_annotations() {
        // Comments annotate the preceding move; a missing [%eval] leaves the
        // EVAL_NONE sentinel in place while the clock is still recorded.
        let text = r#"1. e4 { [%clk 0:10:00] [%eval 0.23] } 1... e5 { [%clk 0:09:58] [%eval 0.31] } 2. Nf3 { [%clk 0:09:55] } 1-0"#;
        let (moves, clocks, evals) = extract_moves_and_annotations(text);
        assert_eq!(moves, vec!["e4", "e5", "Nf3"]);
        assert_eq!(clocks, vec![600, 598, 595]);
        assert_eq!(evals, vec![23, 31, EVAL_NONE]);
    }
753
+
754
    #[test]
    fn test_extract_moves_no_annotations() {
        // Movetext without comments still yields full-length sentinel vectors.
        let text = "1. e4 e5 2. Nf3 Nc6 1-0";
        let (moves, clocks, evals) = extract_moves_and_annotations(text);
        assert_eq!(moves, vec!["e4", "e5", "Nf3", "Nc6"]);
        assert_eq!(clocks, vec![CLOCK_NONE, CLOCK_NONE, CLOCK_NONE, CLOCK_NONE]);
        assert_eq!(evals, vec![EVAL_NONE; 4]);
    }
762
+
763
    #[test]
    fn test_extract_moves_mate_eval() {
        // "#-3" encodes mate-against in 3: -(32767 - 3) = -32764.
        let text = r#"1. e4 { [%eval 0.23] } 1... e5 { [%eval #-3] } 1-0"#;
        let (moves, _clocks, evals) = extract_moves_and_annotations(text);
        assert_eq!(moves, vec!["e4", "e5"]);
        assert_eq!(evals, vec![23, -32764]);
    }
770
+
771
    #[test]
    fn test_enriched_full_game() {
        // End-to-end enriched parse: headers, per-move clocks, and evals all
        // survive into the returned game struct.
        let pgn = r#"[Event "Rated Rapid game"]
[Site "https://lichess.org/abc123"]
[White "alice"]
[Black "bob"]
[Result "1-0"]
[WhiteElo "1873"]
[BlackElo "1844"]
[WhiteRatingDiff "+6"]
[BlackRatingDiff "-26"]
[ECO "C20"]
[Opening "King's Pawn Game"]
[TimeControl "600+0"]
[Termination "Normal"]
[UTCDate "2025.01.15"]
[UTCTime "12:30:00"]

1. e4 { [%clk 0:10:00] [%eval 0.23] } 1... e5 { [%clk 0:09:58] [%eval 0.31] } 2. Nf3 { [%clk 0:09:50] [%eval 0.25] } 2... Nc6 { [%clk 0:09:45] [%eval 0.30] } 1-0
"#;
        let games = parse_pgn_enriched(pgn, 256, 100, 2);
        assert_eq!(games.len(), 1);
        let g = &games[0];
        assert_eq!(g.game_length, 4);
        assert_eq!(g.clocks, vec![600, 598, 590, 585]);
        assert_eq!(g.evals, vec![23, 31, 25, 30]);
        assert_eq!(g.headers.get("White").unwrap(), "alice");
        assert_eq!(g.headers.get("WhiteElo").unwrap(), "1873");
        assert_eq!(g.headers.get("Site").unwrap(), "https://lichess.org/abc123");
        assert_eq!(g.headers.get("ECO").unwrap(), "C20");
        assert_eq!(g.headers.get("TimeControl").unwrap(), "600+0");
    }
803
+
804
    #[test]
    fn test_enriched_tokens_match_legacy() {
        // Enriched parsing should produce the same tokens as the legacy pipeline
        let pgn = r#"[Event "Test"]

1. e4 { [%clk 0:10:00] } 1... e5 { [%clk 0:09:58] } 2. Nf3 { [%clk 0:09:50] } 2... Nc6 { [%clk 0:09:45] } 1-0
"#;
        let enriched = parse_pgn_enriched(pgn, 256, 100, 2);
        let legacy = parse_pgn_to_san(pgn, 100);

        assert_eq!(enriched.len(), 1);
        assert_eq!(legacy.len(), 1);

        // Convert legacy SAN to tokens for comparison
        let refs: Vec<&str> = legacy[0].iter().map(|s| s.as_str()).collect();
        let (legacy_tokens, legacy_n) = san_moves_to_tokens(&refs, 256);

        // Token arrays and valid-ply counts must agree exactly.
        assert_eq!(enriched[0].tokens, legacy_tokens);
        assert_eq!(enriched[0].game_length, legacy_n);
    }
824
+
825
    #[test]
    fn test_count_games_in_date_range() {
        // Three games across two months; the count is driven purely by the
        // inclusive [start, end] UTCDate range.
        let pgn = r#"[Event "Game 1"]
[UTCDate "2023.12.05"]

1. e4 e5 1-0

[Event "Game 2"]
[UTCDate "2023.12.20"]

1. d4 d5 0-1

[Event "Game 3"]
[UTCDate "2025.01.15"]

1. e4 c5 1-0
"#;
        assert_eq!(count_games_in_date_range(pgn, "2023.12.01", "2023.12.31"), 2);
        assert_eq!(count_games_in_date_range(pgn, "2023.12.01", "2023.12.14"), 1);
        assert_eq!(count_games_in_date_range(pgn, "2023.12.15", "2023.12.31"), 1);
        assert_eq!(count_games_in_date_range(pgn, "2025.01.01", "2025.01.31"), 1);
        assert_eq!(count_games_in_date_range(pgn, "2024.01.01", "2024.12.31"), 0);
    }
848
+
849
    #[test]
    fn test_sampled_parsing() {
        // Sampled parsing selects by GLOBAL index = chunk offset + local
        // index among date-matching games only.
        let pgn = r#"[Event "Game 1"]
[UTCDate "2023.12.05"]

1. e4 e5 1-0

[Event "Game 2"]
[UTCDate "2023.12.10"]

1. d4 d5 0-1

[Event "Game 3"]
[UTCDate "2023.12.20"]

1. e4 c5 1-0

[Event "Game 4"]
[UTCDate "2025.01.15"]

1. Nf3 d5 1-0
"#;
        // 3 games match Dec 2023 (indices 0, 1, 2 within the date range)
        assert_eq!(count_games_in_date_range(pgn, "2023.12.01", "2023.12.31"), 3);

        // Sample only index 1 (Game 2)
        let indices: HashSet<usize> = HashSet::from([1]);
        let sampled = parse_pgn_enriched_sampled(
            pgn, 256, 2, "2023.12.01", "2023.12.31", &indices, 0,
        );
        assert_eq!(sampled.len(), 1);
        assert_eq!(sampled[0].headers.get("UTCDate").unwrap(), "2023.12.10");

        // Sample indices 0 and 2 (Game 1 and Game 3)
        let indices: HashSet<usize> = HashSet::from([0, 2]);
        let sampled = parse_pgn_enriched_sampled(
            pgn, 256, 2, "2023.12.01", "2023.12.31", &indices, 0,
        );
        assert_eq!(sampled.len(), 2);
        assert_eq!(sampled[0].headers.get("UTCDate").unwrap(), "2023.12.05");
        assert_eq!(sampled[1].headers.get("UTCDate").unwrap(), "2023.12.20");

        // Sample with offset: simulating a second chunk where previous chunk had 1 match.
        // Global index 2 = offset 1 + local index 1 => selects Game 2 (local idx 1).
        let indices: HashSet<usize> = HashSet::from([2]);
        let sampled = parse_pgn_enriched_sampled(
            pgn, 256, 2, "2023.12.01", "2023.12.31", &indices, 1,
        );
        assert_eq!(sampled.len(), 1);
        assert_eq!(sampled[0].headers.get("UTCDate").unwrap(), "2023.12.10");

        // Offset that skips all local games: offset=3 means local indices are 3,4,5
        // but we only ask for global index 0, which isn't in this chunk.
        let indices: HashSet<usize> = HashSet::from([0]);
        let sampled = parse_pgn_enriched_sampled(
            pgn, 256, 2, "2023.12.01", "2023.12.31", &indices, 3,
        );
        assert_eq!(sampled.len(), 0);
    }
908
  }
scripts/extract_lichess_parquet.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Extract Lichess monthly PGN database dumps into PAWN-compatible Parquet.
3
+
4
+ Downloads a zstd-compressed PGN from database.lichess.org, parses games via
5
+ the Rust chess engine (tokens, clocks, evals, headers), builds a Polars
6
+ DataFrame, and writes sharded Parquet to disk with train/val/test splits.
7
+
8
+ The output schema stores pre-tokenized move sequences as list[int16],
9
+ clock annotations as list[uint16] (seconds remaining, 0x8000=missing),
10
+ eval annotations as list[int16] (centipawns, mate=±(32767-N),
11
+ 0x8000=missing), and metadata columns. Player usernames are hashed to
12
+ uint64 via Polars xxHash64 (deterministic within a Polars version).
13
+
14
+ Training months are written as chronologically-ordered shards. Holdout
15
+ val/test data is uniformly sampled from a separate month via two-pass
16
+ Rust-side date filtering.
17
+
18
+ Designed to run on a CPU pod with the pawn Docker image.
19
+
20
+ Usage:
21
+ python scripts/extract_lichess_parquet.py \\
22
+ --months 2025-01 2025-02 2025-03 \\
23
+ --output /workspace/lichess-parquet \\
24
+ --hf-repo thomas-schweich/lichess-pawn \\
25
+ --batch-size 500000
26
+ """
27
+
28
+ import argparse
29
+ import io
30
+ import os
31
+ import sys
32
+ import time
33
+ import urllib.request
34
+ from datetime import datetime
35
+ from pathlib import Path
36
+
37
+ import numpy as np
38
+ import chess_engine
39
+ import polars as pl
40
+
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Constants
44
+ # ---------------------------------------------------------------------------
45
+
46
+ LICHESS_URL_TEMPLATE = (
47
+ "https://database.lichess.org/standard/"
48
+ "lichess_db_standard_rated_{year_month}.pgn.zst"
49
+ )
50
+ MAX_PLY = 255 # Max plies per game (token sequence = outcome + 255 plies)
51
+ EVAL_MISSING = -32768 # i16::MIN sentinel for missing eval
52
+ SHARD_TARGET_GAMES = 1_000_000 # Target games per shard
53
+
54
+
55
def log(msg: str) -> None:
    """Print ``msg`` prefixed with the current HH:MM:SS wall-clock time."""
    stamp = datetime.now().strftime("%H:%M:%S")
    print(f"[{stamp}] {msg}", flush=True)
58
+
59
+
60
+ # ---------------------------------------------------------------------------
61
+ # PGN streaming
62
+ # ---------------------------------------------------------------------------
63
+
64
def stream_pgn_games(fileobj, batch_size: int):
    """Yield batches of complete PGN game strings from a zstd-compressed stream.

    Args:
        fileobj: Binary file-like object positioned at the start of a
            zstd-compressed PGN dump.
        batch_size: Number of complete games per yielded batch.

    Yields:
        ``(text, n_games)`` tuples where ``text`` is raw PGN for up to
        ``batch_size`` complete games and ``n_games`` is the count actually
        contained. The final batch may hold fewer games (and, if the dump is
        truncated mid-game, trailing partial text with the last game
        uncounted).
    """
    import zstandard as zstd

    dctx = zstd.ZstdDecompressor()
    reader = dctx.stream_reader(fileobj)
    # latin-1 decodes every byte sequence, so odd player names can't abort the run.
    text_reader = io.TextIOWrapper(reader, encoding="latin-1", errors="replace")

    buf = []
    game_count = 0
    in_movetext = False  # True while inside a game's move section

    for line in text_reader:
        stripped = line.strip()

        if not stripped:
            if in_movetext:
                # Blank line after movetext — this game is complete.
                game_count += 1
                in_movetext = False
                buf.append(line)
                if game_count >= batch_size:
                    yield "".join(buf), game_count
                    buf.clear()
                    game_count = 0
                continue
            # Blank line between headers and movetext: keep, not a boundary.
            buf.append(line)
            continue

        # Header lines start with '['; anything else is movetext.
        if stripped.startswith("["):
            in_movetext = False
        else:
            in_movetext = True

        buf.append(line)

    # Final batch
    if buf:
        yield "".join(buf), game_count
107
+
108
+
109
def download_zst(year_month: str, output_dir: Path) -> Path:
    """Download a Lichess zstd PGN dump to disk. Returns path to .zst file.

    A previously downloaded file is reused as-is. The download streams into
    a ``.tmp`` file that is renamed only on completion, so an interrupted
    run never leaves a partial file under the final name.

    Fix over the original: the HTTP response is managed with ``with`` so it
    is closed on every path — previously it leaked when the read/write loop
    raised, since ``response.close()`` only ran after a successful download.
    """
    url = LICHESS_URL_TEMPLATE.format(year_month=year_month)
    zst_path = output_dir / f"lichess_{year_month}.pgn.zst"
    if zst_path.exists():
        log(f"  Using cached {zst_path} ({zst_path.stat().st_size / 1e9:.1f} GB)")
        return zst_path

    log(f"  Downloading {url}")
    req = urllib.request.Request(url)
    req.add_header("User-Agent", "pawn-lichess-extract/1.0")

    # Write to a temp file and rename on completion to avoid partial downloads
    tmp_path = zst_path.with_suffix(".zst.tmp")
    t0 = time.monotonic()
    downloaded = 0
    try:
        # http.client responses are context managers: closed even on error.
        with urllib.request.urlopen(req) as response, open(tmp_path, "wb") as f:
            while True:
                chunk = response.read(8 * 1024 * 1024)  # 8 MB
                if not chunk:
                    break
                f.write(chunk)
                downloaded += len(chunk)
                elapsed = time.monotonic() - t0
                rate_mb = (downloaded / 1e6) / elapsed if elapsed > 0 else 0
                print(f"\r  Downloaded {downloaded / 1e9:.2f} GB ({rate_mb:.0f} MB/s)", end="", flush=True)
        print(flush=True)
        tmp_path.rename(zst_path)
    except BaseException:
        # Includes KeyboardInterrupt: never leave a half-written temp file.
        tmp_path.unlink(missing_ok=True)
        raise

    log(f"  Saved {zst_path} ({downloaded / 1e9:.2f} GB in {time.monotonic() - t0:.0f}s)")
    return zst_path
146
+
147
+
148
def stream_pgn_from_zst(zst_path: Path, batch_size: int):
    """Yield PGN text batches decoded from a local .zst dump file."""
    with zst_path.open("rb") as handle:
        for batch in stream_pgn_games(handle, batch_size):
            yield batch
152
+
153
+
154
def download_month(year_month: str, output_dir: Path, batch_size: int):
    """Download one month's PGN dump and yield parsed batches.

    Each yielded item is the dict of numpy arrays / lists produced by the
    Rust ``parse_pgn_enriched`` call for one text batch.
    """
    zst_path = download_zst(year_month, output_dir)
    total_games, batch_num = 0, 0

    for pgn_text, _ in stream_pgn_from_zst(zst_path, batch_size):
        if not pgn_text.strip():
            continue

        t0 = time.monotonic()
        parsed = chess_engine.parse_pgn_enriched(
            pgn_text, max_ply=MAX_PLY, max_games=batch_size * 2, min_ply=1
        )
        dt = time.monotonic() - t0

        n = parsed["tokens"].shape[0]
        total_games += n
        batch_num += 1
        rate = n / dt if dt > 0 else 0
        log(f"  [{year_month}] batch {batch_num}: {n:,} games parsed in {dt:.1f}s ({rate:,.0f} games/s) | total: {total_games:,}")

        yield parsed

    log(f"  [{year_month}] Done — {total_games:,} games total")
179
+
180
+
181
def sample_holdout(
    zst_path: Path,
    date_start: str,
    date_end: str,
    n_games: int,
    batch_size: int,
    seed: int,
) -> pl.DataFrame:
    """Uniformly sample n_games from a date range within a .zst PGN dump.

    Two-pass approach:
      1. Count games in the date range (header-only scan, no tokenization)
      2. Generate random indices, parse only those games

    Args:
        zst_path: Local zstd-compressed PGN dump.
        date_start: Inclusive UTCDate lower bound ("YYYY.MM.DD").
        date_end: Inclusive UTCDate upper bound ("YYYY.MM.DD").
        n_games: Desired sample size (reduced if fewer games match).
        batch_size: Games per streamed text chunk.
        seed: RNG seed for index selection (deterministic resampling).

    Returns:
        DataFrame of sampled games; empty DataFrame when nothing matches.
    """
    # Pass 1: count
    log(f"  Pass 1: counting games in [{date_start}, {date_end}]")
    total_in_range = 0
    t0 = time.monotonic()
    for pgn_text, _ in stream_pgn_from_zst(zst_path, batch_size):
        if not pgn_text.strip():
            continue
        total_in_range += chess_engine.count_pgn_games_in_date_range(
            pgn_text, date_start, date_end
        )
    dt = time.monotonic() - t0
    log(f"  Found {total_in_range:,} games in range ({dt:.1f}s)")

    if total_in_range == 0:
        return pl.DataFrame()

    # Generate random sample indices
    actual_n = min(n_games, total_in_range)
    rng = np.random.default_rng(seed)
    indices = set(rng.choice(total_in_range, size=actual_n, replace=False).tolist())
    log(f"  Sampling {actual_n:,} of {total_in_range:,} games (seed={seed})")

    # Pass 2: parse only sampled games. Both passes must stream chunks
    # identically so that game_offset reproduces the same global indexing.
    log(f"  Pass 2: parsing sampled games")
    frames = []
    game_offset = 0
    t0 = time.monotonic()
    for pgn_text, _ in stream_pgn_from_zst(zst_path, batch_size):
        if not pgn_text.strip():
            continue

        # Count how many date-matching games are in this chunk
        chunk_count = chess_engine.count_pgn_games_in_date_range(
            pgn_text, date_start, date_end
        )

        # Check if any of our target indices fall in this chunk's range
        chunk_indices = [
            i for i in range(game_offset, game_offset + chunk_count)
            if i in indices
        ]

        if chunk_indices:
            # Rust converts global indices back to chunk-local via game_offset.
            parsed = chess_engine.parse_pgn_sampled(
                pgn_text,
                chunk_indices,
                date_start,
                date_end,
                game_offset=game_offset,
                max_ply=MAX_PLY,
                min_ply=1,
            )
            n = parsed["tokens"].shape[0]
            if n > 0:
                frames.append(batch_to_dataframe(parsed))

        game_offset += chunk_count

    dt = time.monotonic() - t0
    if not frames:
        log(f"  No games parsed")
        return pl.DataFrame()

    result = pl.concat(frames)
    log(f"  Parsed {len(result):,} games ({dt:.1f}s)")
    return result
261
+
262
+
263
+ # ---------------------------------------------------------------------------
264
+ # DataFrame construction
265
+ # ---------------------------------------------------------------------------
266
+
267
def numpy_rows_to_list_series(
    arr: np.ndarray, lengths: np.ndarray, name: str, inner_dtype: pl.DataType
) -> pl.Series:
    """Convert a 0-padded (N, max_ply) numpy array to a Polars List series,
    trimming each row to its actual game length."""
    trimmed = [row[:n].tolist() for row, n in zip(arr, lengths)]
    return pl.Series(name, trimmed, dtype=pl.List(inner_dtype))
274
+
275
+
276
def batch_to_dataframe(parsed: dict) -> pl.DataFrame:
    """Convert a parsed batch dict from Rust into a Polars DataFrame.

    Rust returns numpy arrays: tokens/clocks/evals as (N, max_ply),
    scalar fields as (N,) arrays, and strings as Python lists.
    Per-ply list columns are trimmed to each game's actual length, and
    player usernames are replaced by xxHash64 hashes before returning.

    Returns an empty DataFrame when the batch contains no games.
    """
    tokens: np.ndarray = parsed["tokens"]  # (N, max_ply) i16
    n = tokens.shape[0]
    if n == 0:
        return pl.DataFrame()

    lengths: np.ndarray = parsed["game_lengths"]  # (N,) u16

    # Parse datetime strings -> proper datetime
    # Format: "YYYY.MM.DD HH:MM:SS"; falls back to date-only, then None.
    datetimes = []
    for dt_str in parsed["date_time"]:
        if dt_str and len(dt_str) >= 10:
            try:
                datetimes.append(datetime.strptime(dt_str, "%Y.%m.%d %H:%M:%S"))
            except ValueError:
                try:
                    datetimes.append(datetime.strptime(dt_str[:10], "%Y.%m.%d"))
                except ValueError:
                    datetimes.append(None)
        else:
            datetimes.append(None)

    # NOTE(review): column dtypes mirror the Rust-side encoding (e.g. eval
    # sentinels fit i16, elo fits u16) — confirm against the Rust extractor
    # if that encoding ever changes.
    df = pl.DataFrame({
        "tokens": numpy_rows_to_list_series(tokens, lengths, "tokens", pl.Int16),
        "clock": numpy_rows_to_list_series(parsed["clocks"], lengths, "clock", pl.UInt16),
        "eval": numpy_rows_to_list_series(parsed["evals"], lengths, "eval", pl.Int16),
        "game_length": pl.Series("game_length", parsed["game_lengths"], dtype=pl.UInt16),
        "result": pl.Series("result", parsed["result"], dtype=pl.Utf8),
        "white_player": pl.Series("white_player", parsed["white"], dtype=pl.Utf8),
        "black_player": pl.Series("black_player", parsed["black"], dtype=pl.Utf8),
        "white_elo": pl.Series("white_elo", parsed["white_elo"], dtype=pl.UInt16),
        "black_elo": pl.Series("black_elo", parsed["black_elo"], dtype=pl.UInt16),
        "white_rating_diff": pl.Series("white_rating_diff", parsed["white_rating_diff"], dtype=pl.Int16),
        "black_rating_diff": pl.Series("black_rating_diff", parsed["black_rating_diff"], dtype=pl.Int16),
        "eco": pl.Series("eco", parsed["eco"], dtype=pl.Utf8),
        "opening": pl.Series("opening", parsed["opening"], dtype=pl.Utf8),
        "time_control": pl.Series("time_control", parsed["time_control"], dtype=pl.Utf8),
        "termination": pl.Series("termination", parsed["termination"], dtype=pl.Utf8),
        "date": pl.Series("date", datetimes, dtype=pl.Datetime("ms")),
        "site": pl.Series("site", parsed["site"], dtype=pl.Utf8),
    })

    # Hash usernames: vectorized xxHash64 via Polars.
    # NOTE: hash() output is deterministic within a Polars version but the
    # algorithm is not guaranteed stable across major versions. Originally
    # recorded with Polars 1.39.3. Pin Polars version (via uv.lock) and
    # tag the repo to ensure reproducibility. See test_enriched_pgn.py
    # TestPlayerHashRegression for the snapshot test.
    df = df.with_columns(
        pl.col("white_player").hash().alias("white_player"),
        pl.col("black_player").hash().alias("black_player"),
    )

    return df
336
+
337
+
338
+ # ---------------------------------------------------------------------------
339
+ # Shard writing
340
+ # ---------------------------------------------------------------------------
341
+
342
def write_shard(
    df: pl.DataFrame,
    output_dir: Path,
    split: str,
    shard_idx: int,
    total_shards: int,
) -> Path:
    """Write a single Parquet shard with HF-compatible naming."""
    shard_name = f"{split}-{shard_idx:05d}-of-{total_shards:05d}.parquet"
    dest = output_dir / "data" / shard_name
    dest.parent.mkdir(parents=True, exist_ok=True)
    df.write_parquet(dest, compression="zstd", compression_level=3)
    size_mb = dest.stat().st_size / 1024 / 1024
    log(f"  Wrote {shard_name}: {len(df):,} games, {size_mb:.1f} MB")
    return dest
357
+
358
+
359
+ # ---------------------------------------------------------------------------
360
+ # HuggingFace upload
361
+ # ---------------------------------------------------------------------------
362
+
363
def upload_to_hf(output_dir: Path, hf_repo: str) -> None:
    """Push the extracted dataset directory to a HuggingFace dataset repo."""
    from huggingface_hub import HfApi

    client = HfApi()
    log(f"Uploading to HuggingFace: {hf_repo}")

    # Idempotent: exist_ok makes re-runs safe against an existing repo.
    client.create_repo(hf_repo, repo_type="dataset", exist_ok=True)

    # Push everything under output_dir (data shards plus any metadata).
    client.upload_folder(
        repo_id=hf_repo,
        folder_path=str(output_dir),
        repo_type="dataset",
    )
    log(f"Upload complete: https://huggingface.co/datasets/{hf_repo}")
380
+
381
+
382
+ # ---------------------------------------------------------------------------
383
+ # Main
384
+ # ---------------------------------------------------------------------------
385
+
386
class SplitBuffer:
    """Accumulates DataFrames for a single split and flushes to Parquet shards.

    Frames are buffered until at least ``shard_size`` rows are available,
    then written as fixed-size shards; any remainder is written by
    ``flush_remaining``. Shards get temporary names until ``rename_shards``
    assigns final ``{split}-{i:05d}-of-{n:05d}.parquet`` names once the
    total shard count is known.
    """

    def __init__(self, split: str, shard_size: int, output_dir: Path):
        self.split = split
        self.shard_size = shard_size
        self.output_dir = output_dir / "data"
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.frames: list[pl.DataFrame] = []   # pending frames, not yet sharded
        self.buffered = 0                       # rows currently in self.frames
        self.total_games = 0                    # rows ever added (for summary)
        self.shard_paths: list[Path] = []       # written shard files, in order
        self.shard_idx = 0                      # next temp shard number

    def add(self, df: pl.DataFrame) -> None:
        """Buffer a batch; write out full shards as soon as enough rows exist."""
        if df.is_empty():
            return
        self.frames.append(df)
        self.buffered += len(df)
        self.total_games += len(df)
        self._flush_full()

    def _flush_full(self) -> None:
        # Emit exactly shard_size-row shards while the buffer can fill one;
        # the tail rows roll over into the next shard.
        while self.buffered >= self.shard_size:
            combined = pl.concat(self.frames)
            shard_df = combined.head(self.shard_size)
            leftover = combined.slice(self.shard_size)
            self._write_shard(shard_df)
            self.frames = [leftover] if len(leftover) > 0 else []
            self.buffered = len(leftover) if len(leftover) > 0 else 0

    def flush_remaining(self) -> None:
        """Write any buffered rows as a final (possibly undersized) shard."""
        if self.frames:
            combined = pl.concat(self.frames)
            if len(combined) > 0:
                self._write_shard(combined)
            self.frames.clear()
            self.buffered = 0

    def _write_shard(self, df: pl.DataFrame) -> None:
        # Write with placeholder name; rename after all shards are counted
        path = self.output_dir / f"{self.split}-temp-{self.shard_idx:05d}.parquet"
        df.write_parquet(path, compression="zstd", compression_level=3)
        size_mb = path.stat().st_size / 1024 / 1024
        log(f"  [{self.split}] shard {self.shard_idx}: {len(df):,} games, {size_mb:.1f} MB")
        self.shard_paths.append(path)
        self.shard_idx += 1

    def rename_shards(self) -> list[Path]:
        """Rename temp shards to HF-compatible names with correct total count."""
        n = len(self.shard_paths)
        final = []
        for i, path in enumerate(self.shard_paths):
            new_name = f"{self.split}-{i:05d}-of-{n:05d}.parquet"
            new_path = path.parent / new_name
            path.rename(new_path)
            final.append(new_path)
            log(f"  {path.name} -> {new_name}")
        self.shard_paths = final
        return final
446
+
447
+
448
def main():
    """CLI entry point: download months, extract to Parquet, optionally push.

    Phase 1 samples val/test holdout games from a dedicated month (first
    half of the month -> validation, second half -> test); phase 2 streams
    the training months chronologically into shards.
    """
    parser = argparse.ArgumentParser(
        description="Extract Lichess PGN dumps to PAWN-compatible Parquet"
    )
    parser.add_argument(
        "--months", nargs="+", required=True,
        help="Training month(s) to download, e.g. 2025-01 2025-02 2025-03"
    )
    parser.add_argument(
        "--output", type=Path, default=Path("/workspace/lichess-parquet"),
        help="Output directory for Parquet shards"
    )
    parser.add_argument(
        "--hf-repo", type=str, default=None,
        help="HuggingFace dataset repo to push to (e.g. thomas-schweich/pawn-lichess-full)"
    )
    parser.add_argument(
        "--batch-size", type=int, default=500_000,
        help="Games per batch during parsing (controls memory usage)"
    )
    parser.add_argument(
        "--shard-size", type=int, default=SHARD_TARGET_GAMES,
        help="Target games per output shard"
    )
    parser.add_argument(
        "--holdout-month", type=str, default=None,
        help="Month to use for val/test (e.g. 2023-12). First half of month "
             "-> val, second half -> test. Randomly samples --holdout-games "
             "from each half."
    )
    parser.add_argument(
        "--holdout-games", type=int, default=50_000,
        help="Number of games to sample for each of val and test (default: 50000)"
    )
    parser.add_argument(
        "--max-games", type=int, default=None,
        help="Stop after this many training games (for testing)"
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for holdout sampling (default: 42)"
    )
    args = parser.parse_args()

    log("=== Lichess Parquet Extraction ===")
    log(f"Training months: {args.months}")
    log(f"Output: {args.output}")
    log(f"Batch size: {args.batch_size:,}")
    log(f"Shard size: {args.shard_size:,}")
    if args.holdout_month:
        log(f"Holdout month: {args.holdout_month}")
        log(f"Holdout games per split: {args.holdout_games:,}")
        log(f"Holdout seed: {args.seed}")
    if args.max_games:
        log(f"Max training games: {args.max_games:,}")
    log("")

    args.output.mkdir(parents=True, exist_ok=True)

    # ── Phase 1: Process holdout month (val/test) ──────────────────────
    buffers = {
        "train": SplitBuffer("train", args.shard_size, args.output),
        "validation": SplitBuffer("validation", args.shard_size, args.output),
        "test": SplitBuffer("test", args.shard_size, args.output),
    }

    if args.holdout_month:
        log(f"\n=== Processing holdout month {args.holdout_month} ===")

        # Download the zstd file (reused for both count and parse passes)
        zst_path = download_zst(args.holdout_month, args.output)

        # Date ranges: first half of month -> val, second half -> test
        # UTCDate format is "YYYY.MM.DD"
        year, mon = args.holdout_month.split("-")
        val_start = f"{year}.{mon}.01"
        val_end = f"{year}.{mon}.14"
        test_start = f"{year}.{mon}.15"
        test_end = f"{year}.{mon}.31"

        log(f"  Val date range: [{val_start}, {val_end}]")
        val_df = sample_holdout(
            zst_path, val_start, val_end,
            args.holdout_games, args.batch_size, args.seed,
        )
        if not val_df.is_empty():
            buffers["validation"].add(val_df)

        log(f"  Test date range: [{test_start}, {test_end}]")
        # seed + 1 keeps the test sample independent of the val sample.
        test_df = sample_holdout(
            zst_path, test_start, test_end,
            args.holdout_games, args.batch_size, args.seed + 1,
        )
        if not test_df.is_empty():
            buffers["test"].add(test_df)

    # ── Phase 2: Process training months ───────────────────────────────
    total_train = 0
    stop = False

    for month in args.months:
        if stop:
            break
        log(f"\n=== Processing {month} (train) ===")

        for parsed in download_month(month, args.output, args.batch_size):
            df = batch_to_dataframe(parsed)
            if df.is_empty():
                continue

            # Enforce the --max-games cap, truncating the final batch.
            if args.max_games:
                remaining = args.max_games - total_train
                if remaining <= 0:
                    stop = True
                    break
                if len(df) > remaining:
                    df = df.head(remaining)

            total_train += len(df)
            buffers["train"].add(df)

            if args.max_games and total_train >= args.max_games:
                stop = True
                break

    # Flush remaining data in each buffer
    for buf in buffers.values():
        buf.flush_remaining()

    log(f"\n=== Renaming shards ===")
    final_paths = []
    for buf in buffers.values():
        if buf.shard_paths:
            final_paths.extend(buf.rename_shards())

    # Summary
    log(f"\n=== Summary ===")
    total_games = sum(buf.total_games for buf in buffers.values())
    log(f"Total games: {total_games:,}")
    for name, buf in buffers.items():
        if buf.total_games > 0:
            log(f"  {name}: {buf.total_games:,} games, {len(buf.shard_paths)} shards")

    if final_paths:
        total_size = sum(p.stat().st_size for p in final_paths)
        log(f"Total size: {total_size / 1024 / 1024 / 1024:.2f} GB")

    # Upload to HuggingFace
    if args.hf_repo:
        upload_to_hf(args.output, args.hf_repo)

    log("\nDone!")
600
+
601
+
602
+ if __name__ == "__main__":
603
+ main()
tests/test_enriched_pgn.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for enriched PGN parsing and dataset extraction pipeline."""
2
+
3
+ import numpy as np
4
+ import polars as pl
5
+ import pytest
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ import chess_engine
10
+
11
+ # Import extraction helper
12
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "scripts"))
13
+ from extract_lichess_parquet import batch_to_dataframe
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Test PGN data — each game has distinct moves, metadata, and annotations
18
+ # ---------------------------------------------------------------------------
19
+
20
# PGN fixtures keyed by "<white>_v_<black>". The games deliberately differ in
# moves, Elo, rating diffs, time control, dates, and annotation coverage so
# the tests can detect any accidental coupling between player-name hashing
# and the surrounding row metadata. Headers mirror real Lichess dump fields.
PGNS = {
    # Rapid, 4 plies, both [%clk] and [%eval] comments on every move.
    "alice_v_bob": """\
[Event "Rated Rapid game"]
[Site "https://lichess.org/game001"]
[White "alice"]
[Black "bob"]
[Result "1-0"]
[WhiteElo "1873"]
[BlackElo "1844"]
[WhiteRatingDiff "+6"]
[BlackRatingDiff "-26"]
[ECO "C20"]
[Opening "King's Pawn Game"]
[TimeControl "600+0"]
[Termination "Normal"]
[UTCDate "2025.01.10"]
[UTCTime "10:00:00"]

1. e4 { [%clk 0:10:00] [%eval 0.23] } 1... e5 { [%clk 0:09:58] [%eval 0.31] } 2. Nf3 { [%clk 0:09:50] [%eval 0.25] } 2... Nc6 { [%clk 0:09:45] [%eval 0.30] } 1-0
""",
    # Blitz, 3 plies, game lost on time ("Time forfeit" termination).
    "bob_v_alice": """\
[Event "Rated Blitz game"]
[Site "https://lichess.org/game002"]
[White "bob"]
[Black "alice"]
[Result "0-1"]
[WhiteElo "1850"]
[BlackElo "1880"]
[WhiteRatingDiff "-5"]
[BlackRatingDiff "+5"]
[ECO "B20"]
[Opening "Sicilian Defense"]
[TimeControl "300+3"]
[Termination "Time forfeit"]
[UTCDate "2025.02.14"]
[UTCTime "20:00:00"]

1. e4 { [%clk 0:05:00] [%eval 0.20] } 1... c5 { [%clk 0:04:55] [%eval 0.25] } 2. d4 { [%clk 0:04:48] [%eval 0.40] } 0-1
""",
    # Classical, 5 plies, drawn result.
    "alice_v_xavier": """\
[Event "Rated Classical game"]
[Site "https://lichess.org/game003"]
[White "alice"]
[Black "xavier"]
[Result "1/2-1/2"]
[WhiteElo "1900"]
[BlackElo "2100"]
[WhiteRatingDiff "+3"]
[BlackRatingDiff "-1"]
[ECO "D30"]
[Opening "Queen's Gambit Declined"]
[TimeControl "1800+30"]
[Termination "Normal"]
[UTCDate "2025.03.01"]
[UTCTime "15:30:00"]

1. d4 { [%clk 0:30:00] [%eval 0.10] } 1... d5 { [%clk 0:29:50] [%eval 0.15] } 2. c4 { [%clk 0:29:40] [%eval 0.20] } 2... e6 { [%clk 0:29:30] [%eval 0.18] } 3. Nf3 { [%clk 0:29:20] [%eval 0.22] } 1/2-1/2
""",
    # Rapid, 3 plies, clock comments only — NO [%eval] annotations
    # (exercises the missing-eval sentinel path).
    "xavier_v_alice": """\
[Event "Rated Rapid game"]
[Site "https://lichess.org/game004"]
[White "xavier"]
[Black "alice"]
[Result "1-0"]
[WhiteElo "2105"]
[BlackElo "1895"]
[WhiteRatingDiff "+2"]
[BlackRatingDiff "-4"]
[ECO "A45"]
[Opening "Trompowsky Attack"]
[TimeControl "900+10"]
[Termination "Normal"]
[UTCDate "2025.01.20"]
[UTCTime "18:00:00"]

1. d4 { [%clk 0:15:00] } 1... Nf6 { [%clk 0:14:55] } 2. Bg5 { [%clk 0:14:48] } 1-0
""",
    # Bullet, 6 plies, full clock + eval annotations.
    "bob_v_xavier": """\
[Event "Rated Bullet game"]
[Site "https://lichess.org/game005"]
[White "bob"]
[Black "xavier"]
[Result "0-1"]
[WhiteElo "1840"]
[BlackElo "2110"]
[WhiteRatingDiff "-8"]
[BlackRatingDiff "+3"]
[ECO "C50"]
[Opening "Italian Game"]
[TimeControl "60+0"]
[Termination "Normal"]
[UTCDate "2025.02.28"]
[UTCTime "23:59:00"]

1. e4 { [%clk 0:01:00] [%eval 0.20] } 1... e5 { [%clk 0:00:59] [%eval 0.25] } 2. Nf3 { [%clk 0:00:55] [%eval 0.30] } 2... Nc6 { [%clk 0:00:53] [%eval 0.28] } 3. Bc4 { [%clk 0:00:50] [%eval 0.35] } 3... Bc5 { [%clk 0:00:48] [%eval 0.30] } 0-1
""",
    # Rapid, 7 plies (longest fixture), full annotations.
    "xavier_v_bob": """\
[Event "Rated Rapid game"]
[Site "https://lichess.org/game006"]
[White "xavier"]
[Black "bob"]
[Result "1-0"]
[WhiteElo "2115"]
[BlackElo "1835"]
[WhiteRatingDiff "+1"]
[BlackRatingDiff "-7"]
[ECO "E00"]
[Opening "Queen's Pawn Game"]
[TimeControl "600+5"]
[Termination "Normal"]
[UTCDate "2025.03.15"]
[UTCTime "09:00:00"]

1. d4 { [%clk 0:10:00] [%eval 0.15] } 1... Nf6 { [%clk 0:09:55] [%eval 0.20] } 2. c4 { [%clk 0:09:48] [%eval 0.25] } 2... e6 { [%clk 0:09:40] [%eval 0.22] } 3. Nc3 { [%clk 0:09:35] [%eval 0.28] } 3... Bb4 { [%clk 0:09:28] [%eval 0.30] } 4. Qc2 { [%clk 0:09:20] [%eval 0.32] } 1-0
""",
}
136
+
137
+
138
+ # ---------------------------------------------------------------------------
139
+ # Tests
140
+ # ---------------------------------------------------------------------------
141
+
142
+
143
class TestEnrichedParsing:
    """Test the Rust parse_pgn_enriched function.

    Covers array shapes and dtypes, [%clk]/[%eval] annotation extraction,
    the missing-eval sentinel, zero padding past the game length, header
    extraction, and token uniqueness across distinct games.
    """

    def test_basic_parsing(self):
        """Arrays are padded to 255 plies; game_lengths holds the true ply count."""
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        assert r["tokens"].shape == (1, 255)
        assert r["clocks"].shape == (1, 255)
        assert r["evals"].shape == (1, 255)
        assert r["game_lengths"].shape == (1,)
        assert r["game_lengths"][0] == 4

    def test_return_types(self):
        """Numeric fields come back as numpy arrays; string fields as lists."""
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        for key in ("tokens", "clocks", "evals"):
            assert isinstance(r[key], np.ndarray), f"{key} should be ndarray"
        for key in ("game_lengths", "white_elo", "black_elo",
                    "white_rating_diff", "black_rating_diff"):
            assert isinstance(r[key], np.ndarray), f"{key} should be ndarray"
        for key in ("result", "white", "black", "eco", "opening",
                    "time_control", "termination", "date_time", "site"):
            assert isinstance(r[key], list), f"{key} should be list"

    def test_clock_extraction(self):
        """[%clk H:MM:SS] annotations are converted to whole seconds per ply."""
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        clocks = r["clocks"][0, :4]
        # 0:10:00, 0:09:58, 0:09:50, 0:09:45 from the fixture movetext
        assert list(clocks) == [600, 598, 590, 585]

    def test_eval_extraction(self):
        """[%eval] pawn values are converted to integer centipawns."""
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        evals = r["evals"][0, :4]
        # 0.23, 0.31, 0.25, 0.30 pawns -> centipawns
        assert list(evals) == [23, 31, 25, 30]

    def test_missing_eval_sentinel(self):
        """Moves with no [%eval] annotation get the i16::MIN sentinel."""
        pgn = PGNS["xavier_v_alice"]  # no eval annotations
        r = chess_engine.parse_pgn_enriched(pgn)
        length = r["game_lengths"][0]
        evals = r["evals"][0, :length]
        # Rust uses i16::MIN (-32768) as the "no eval" sentinel
        assert all(e == -32768 for e in evals), (
            f"Missing evals should be -32768 (i16::MIN), got {list(evals)}"
        )

    def test_padding_is_zero(self):
        """Everything past game_length in each array row must be zero."""
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        length = r["game_lengths"][0]
        assert np.all(r["tokens"][0, length:] == 0)
        assert np.all(r["clocks"][0, length:] == 0)
        assert np.all(r["evals"][0, length:] == 0)

    def test_headers_extracted(self):
        """PGN tag-pair headers are surfaced in the parse result."""
        pgn = PGNS["alice_v_bob"]
        r = chess_engine.parse_pgn_enriched(pgn)
        assert r["white"][0] == "alice"
        assert r["black"][0] == "bob"
        assert r["result"][0] == "1-0"
        assert r["white_elo"][0] == 1873
        assert r["black_elo"][0] == 1844
        assert r["white_rating_diff"][0] == 6
        assert r["black_rating_diff"][0] == -26
        assert r["eco"][0] == "C20"
        assert r["time_control"][0] == "600+0"
        assert r["site"][0] == "https://lichess.org/game001"

    def test_different_games_produce_different_tokens(self):
        """Each test PGN has distinct moves — tokens must differ."""
        from itertools import combinations

        all_tokens = {}
        for name, pgn in PGNS.items():
            r = chess_engine.parse_pgn_enriched(pgn)
            length = r["game_lengths"][0]
            all_tokens[name] = tuple(r["tokens"][0, :length])

        # Compare every unordered pair of games exactly once.
        for a, b in combinations(all_tokens, 2):
            assert all_tokens[a] != all_tokens[b], (
                f"{a} and {b} should have different token sequences"
            )
225
+
226
+
227
class TestPlayerHashing:
    """Test that player username hashing is deterministic and independent of context.

    Each PGN has different moves, metadata, Elo, time control, dates, and
    game lengths — the only thing shared is the player name strings. If
    hashing accidentally depended on row context, these would diverge.
    """

    @pytest.fixture
    def player_hashes(self):
        """Parse all 6 PGNs separately and collect per-player hash values."""
        hashes = {}  # name -> list of observed uint64 hashes
        for name, pgn in PGNS.items():
            r = chess_engine.parse_pgn_enriched(pgn)
            df = batch_to_dataframe(r)
            w_name = r["white"][0]
            b_name = r["black"][0]
            w_hash = df["white_player"][0]
            b_hash = df["black_player"][0]
            hashes.setdefault(w_name, []).append(w_hash)
            hashes.setdefault(b_name, []).append(b_hash)
        return hashes

    def test_same_name_always_same_hash(self, player_hashes):
        """A player name must always produce the same hash regardless of
        which game it appears in, whether as white or black, and what
        the surrounding metadata looks like."""
        for name, vals in player_hashes.items():
            unique = set(vals)
            assert len(unique) == 1, (
                f"'{name}' produced {len(unique)} distinct hashes across "
                f"{len(vals)} appearances: {unique}"
            )

    def test_different_names_different_hashes(self, player_hashes):
        """alice, bob, and xavier must all have distinct hashes."""
        canonical = {name: vals[0] for name, vals in player_hashes.items()}
        hash_vals = list(canonical.values())
        assert len(set(hash_vals)) == len(hash_vals), (
            f"Hash collision among players: {canonical}"
        )

    def test_hash_dtype_is_uint64(self):
        """Both player-hash columns must be UInt64 in the DataFrame.

        Fix: previously this test requested the player_hashes fixture but
        never used it, re-parsing all six PGNs for nothing; the unused
        fixture argument has been removed.
        """
        # Parse one game and check the Polars column dtype
        r = chess_engine.parse_pgn_enriched(PGNS["alice_v_bob"])
        df = batch_to_dataframe(r)
        assert df["white_player"].dtype == pl.UInt64
        assert df["black_player"].dtype == pl.UInt64

    def test_hash_appears_in_both_columns(self):
        """alice appears as both white and black — hash must match in both columns.

        Fix: the unused player_hashes fixture argument was removed here
        too; the test constructs its own data from two specific games.
        """
        # alice is white in alice_v_bob and black in bob_v_alice
        r1 = chess_engine.parse_pgn_enriched(PGNS["alice_v_bob"])
        df1 = batch_to_dataframe(r1)
        r2 = chess_engine.parse_pgn_enriched(PGNS["bob_v_alice"])
        df2 = batch_to_dataframe(r2)

        alice_as_white = df1["white_player"][0]
        alice_as_black = df2["black_player"][0]
        assert alice_as_white == alice_as_black, (
            f"alice hash differs: white={alice_as_white}, black={alice_as_black}"
        )
289
+
290
+
291
class TestPlayerHashRegression:
    """Snapshot test: catch if a Polars update changes the hash algorithm.

    These exact values were recorded with Polars 1.39.3 using the default
    hash() seed (xxHash64). If this test fails after a Polars upgrade, the
    dataset must be regenerated to stay consistent (or the old Polars
    version must be pinned).
    """

    # Recorded uint64 outputs of pl.Series([name]).hash()[0] at dataset-build time.
    EXPECTED_HASHES = {
        "alice": 573680751236103438,
        "bob": 11376496890720967193,
        "xavier": 2453512920044318708,
    }

    def test_hash_values_match_snapshot(self):
        """Verify that pl.Series.hash() produces the exact same uint64
        values that were recorded when the dataset was built."""
        for name, expected in self.EXPECTED_HASHES.items():
            actual = pl.Series([name]).hash()[0]
            assert actual == expected, (
                f"Hash regression for '{name}': expected {expected}, got {actual}. "
                f"Polars hash algorithm may have changed — dataset must be regenerated."
            )

    def test_snapshot_matches_pipeline(self):
        """The snapshot values must agree with what batch_to_dataframe produces.

        Fix: the original duplicated the row-scanning loop verbatim for the
        white and black sides; the two copies are merged into one loop over
        (side, column) pairs with identical assertion messages.
        """
        combined = "\n".join(PGNS.values())
        r = chess_engine.parse_pgn_enriched(combined)
        df = batch_to_dataframe(r)

        for name, expected in self.EXPECTED_HASHES.items():
            for side, column in (("white", "white_player"), ("black", "black_player")):
                # Find every row where this player appears on this side
                for i, player in enumerate(r[side]):
                    if player != name:
                        continue
                    actual = df[column][i]
                    assert actual == expected, (
                        f"Pipeline hash for '{name}' ({side}, row {i}): "
                        f"expected {expected}, got {actual}"
                    )
344
+
345
+
346
class TestBatchToDataframe:
    """Test the full batch_to_dataframe pipeline."""

    def test_schema(self):
        """Every output column must carry its expected Polars dtype."""
        df = batch_to_dataframe(chess_engine.parse_pgn_enriched(PGNS["alice_v_bob"]))
        expected_dtypes = {
            "tokens": pl.List(pl.Int16),
            "clock": pl.List(pl.UInt16),
            "eval": pl.List(pl.Int16),
            "game_length": pl.UInt16,
            "white_elo": pl.UInt16,
            "black_elo": pl.UInt16,
            "white_rating_diff": pl.Int16,
            "black_rating_diff": pl.Int16,
            "white_player": pl.UInt64,
            "black_player": pl.UInt64,
        }
        for column, dtype in expected_dtypes.items():
            assert df[column].dtype == dtype

    def test_list_columns_trimmed_to_game_length(self):
        """List columns should contain exactly game_length elements (no padding)."""
        df = batch_to_dataframe(chess_engine.parse_pgn_enriched(PGNS["bob_v_xavier"]))
        expected_len = df["game_length"][0]
        for column in ("tokens", "clock", "eval"):
            assert len(df[column][0]) == expected_len

    def test_parquet_roundtrip(self, tmp_path):
        """Write to Parquet and read back — all values must survive."""
        df = batch_to_dataframe(chess_engine.parse_pgn_enriched(PGNS["xavier_v_bob"]))
        target = tmp_path / "test.parquet"
        df.write_parquet(target, compression="zstd")
        restored = pl.read_parquet(target)
        assert df.shape == restored.shape
        for column in ("tokens", "clock", "eval", "white_player"):
            assert df[column].to_list() == restored[column].to_list()

    def test_multi_game_batch(self):
        """Parse all 6 games in a single PGN string."""
        combined = "\n".join(PGNS.values())
        df = batch_to_dataframe(chess_engine.parse_pgn_enriched(combined))
        assert len(df) == 6
        # Each game should have a different game length
        lengths = df["game_length"].to_list()
        assert len(set(lengths)) > 1, "Games should have different lengths"