Arun-Sanjay commited on
Commit
0738d13
·
1 Parent(s): 313cc30

Phase 6: Problems pool (500 GSM8K-style problems) and sampling API per PROJECT.md Section 12

Browse files
data/problems_pool.json ADDED
The diff for this file is too large to render. See raw diff
 
red_button/problems.py CHANGED
@@ -1,5 +1,70 @@
1
- """Problem pool loader and integer answer validator.
2
 
3
- TODO (Phase 6): load data/problems_pool.json, sample 10 problems per episode,
4
- and verify integer answers per PROJECT.md Section 12.
 
 
 
 
 
5
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Problems pool loader, sampler, and integer answer validator (Section 12).
2
 
3
+ The pool is a JSON file (``data/problems_pool.json``) shaped per Section 12.2.
4
+ Each entry is ``{"id": int, "problem": str, "answer": int, "difficulty": str}``.
5
+
6
+ This module intentionally uses plain ``dict`` for pool entries — the JSON file
7
+ is the contract and dicts are the natural intermediate. The rubric consumer
8
+ (:class:`red_button.rubrics.MathCorrectnessRubric`) reads ``dict[int, int]``
9
+ from ``state.ground_truth`` (populated via :func:`ground_truth_map`).
10
  """
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import random
16
+ from pathlib import Path
17
+ from typing import Optional
18
+
19
+
20
+ def load_problems(path: str = "data/problems_pool.json") -> list[dict]:
21
+ """Load and return the problems pool from ``path`` (Section 12.2).
22
+
23
+ Path is interpreted relative to the current working directory.
24
+ """
25
+ with Path(path).open("r", encoding="utf-8") as fh:
26
+ data = json.load(fh)
27
+ if not isinstance(data, list):
28
+ raise ValueError(f"Expected a JSON list at {path}, got {type(data).__name__}")
29
+ return data
30
+
31
+
32
+ def sample_problems(
33
+ n: int = 10,
34
+ seed: Optional[int] = None,
35
+ problems: Optional[list[dict]] = None,
36
+ ) -> list[dict]:
37
+ """Sample ``n`` problems without replacement.
38
+
39
+ * ``problems`` defaults to :func:`load_problems` output.
40
+ * If ``seed`` is provided, sampling is deterministic via ``random.Random(seed).sample``.
41
+ * ``n > len(problems)`` raises :class:`ValueError`.
42
+ """
43
+ pool = problems if problems is not None else load_problems()
44
+ if n > len(pool):
45
+ raise ValueError(
46
+ f"Requested sample of n={n} exceeds pool size {len(pool)}"
47
+ )
48
+ rng = random.Random(seed) if seed is not None else random.Random()
49
+ return rng.sample(pool, n)
50
+
51
+
52
+ def validate_answer(
53
+ problem_id: int,
54
+ submitted_answer: int,
55
+ problems: list[dict],
56
+ ) -> bool:
57
+ """Return ``True`` iff ``problem_id`` exists and its answer matches.
58
+
59
+ Missing ``problem_id`` returns ``False`` (not an exception), by design —
60
+ the environment's ``submit_answer`` tool must never crash on junk input.
61
+ """
62
+ for entry in problems:
63
+ if entry.get("id") == problem_id:
64
+ return entry.get("answer") == submitted_answer
65
+ return False
66
+
67
+
68
+ def ground_truth_map(problems: list[dict]) -> dict[int, int]:
69
+ """Return ``{id: answer}`` for a problems list (Section 12.5)."""
70
+ return {int(p["id"]): int(p["answer"]) for p in problems}
scripts/generate_problems_pool.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate data/problems_pool.json from GSM8K and MATH per PROJECT.md Section 12.3.
2
+
3
+ Target: 500 problems (300 GSM8K + 200 MATH), integer-answer only, deterministic.
4
+
5
+ Usage::
6
+
7
+ python scripts/generate_problems_pool.py
8
+
9
+ Re-runs are reproducible because we seed the source-side shuffle with
10
+ ``SOURCE_SEED`` (see below). Rerun whenever Section 12.3 changes or HF
11
+ dataset splits shift.
12
+
13
+ Dependencies:
14
+ pip install datasets
15
+
16
+ Answer extraction
17
+ -----------------
18
+ * GSM8K: answers end with ``#### N`` — extract the final integer.
19
+ * MATH: answers live inside ``\\boxed{...}`` — try ``int()`` on the contents.
20
+
21
+ Filter: reject anything whose extracted answer cannot be parsed as ``int``
22
+ (fractions, decimals, non-numeric, multi-part).
23
+
24
+ Difficulty heuristic (per the Phase 6 task spec — PROJECT.md Section 12 itself
25
+ does not specify one)::
26
+
27
+ easy if answer <= 100 and len(problem) < 150
28
+ hard if answer > 1000 or len(problem) > 300
29
+ medium otherwise
30
+
31
+ If HF access fails or <500 valid problems can be produced, the script stops
32
+ WITHOUT writing output and surfaces the shortfall — do NOT hand-roll a
33
+ substitute.
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import json
39
+ import random
40
+ import re
41
+ import sys
42
+ from pathlib import Path
43
+
44
+ # Reproducibility seed for source-side shuffling/sampling.
45
+ SOURCE_SEED = 20260425
46
+
47
+ GSM8K_TARGET = 300
48
+ MATH_TARGET = 200
49
+
50
+ OUTPUT_PATH = Path(__file__).resolve().parents[1] / "data" / "problems_pool.json"
51
+
52
+
53
+ def _classify_difficulty(problem: str, answer: int) -> str:
54
+ if answer <= 100 and len(problem) < 150:
55
+ return "easy"
56
+ if answer > 1000 or len(problem) > 300:
57
+ return "hard"
58
+ return "medium"
59
+
60
+
61
+ _GSM8K_ANSWER_RE = re.compile(r"####\s*(-?\d+)")
62
+ _MATH_BOXED_RE = re.compile(r"\\boxed\{([^{}]*)\}")
63
+
64
+
65
+ def _extract_gsm8k_answer(raw_answer: str) -> int | None:
66
+ """Pull the trailing ``#### N`` integer out of a GSM8K answer string."""
67
+ m = _GSM8K_ANSWER_RE.search(raw_answer)
68
+ if m is None:
69
+ return None
70
+ try:
71
+ return int(m.group(1))
72
+ except (TypeError, ValueError):
73
+ return None
74
+
75
+
76
+ def _extract_math_answer(raw_solution: str) -> int | None:
77
+ """Pull an integer out of the final ``\\boxed{...}`` of a MATH solution."""
78
+ matches = _MATH_BOXED_RE.findall(raw_solution)
79
+ if not matches:
80
+ return None
81
+ candidate = matches[-1].strip()
82
+ try:
83
+ return int(candidate)
84
+ except (TypeError, ValueError):
85
+ return None
86
+
87
+
88
+ def _collect_gsm8k(target: int) -> list[dict]:
89
+ """Pull ``target`` integer-answer problems from GSM8K (openai/gsm8k, main, train)."""
90
+ from datasets import load_dataset
91
+
92
+ ds = load_dataset("openai/gsm8k", "main", split="train")
93
+ # Deterministic shuffle over indices so re-runs match.
94
+ indices = list(range(len(ds)))
95
+ random.Random(SOURCE_SEED).shuffle(indices)
96
+
97
+ collected: list[dict] = []
98
+ for idx in indices:
99
+ if len(collected) >= target:
100
+ break
101
+ row = ds[idx]
102
+ question = row["question"].strip()
103
+ answer = _extract_gsm8k_answer(row["answer"])
104
+ if answer is None:
105
+ continue
106
+ collected.append(
107
+ {
108
+ "problem": question,
109
+ "answer": answer,
110
+ "source": "gsm8k",
111
+ }
112
+ )
113
+ return collected
114
+
115
+
116
+ def _collect_math(target: int) -> list[dict]:
117
+ """Pull ``target`` integer-answer problems from the MATH algebra track.
118
+
119
+ We try ``EleutherAI/hendrycks_math`` first (the 2024+ maintained mirror
120
+ with per-subject configs); ``hendrycks/competition_math`` was deprecated.
121
+ """
122
+ from datasets import load_dataset
123
+
124
+ load_errors: list[str] = []
125
+ ds = None
126
+ for source_name, loader in [
127
+ ("EleutherAI/hendrycks_math[algebra]", lambda: load_dataset(
128
+ "EleutherAI/hendrycks_math", "algebra", split="train"
129
+ )),
130
+ ("hendrycks/competition_math", lambda: load_dataset(
131
+ "hendrycks/competition_math", split="train"
132
+ )),
133
+ ]:
134
+ try:
135
+ ds = loader()
136
+ print(f" MATH source: {source_name}")
137
+ break
138
+ except Exception as exc: # noqa: BLE001
139
+ load_errors.append(f"{source_name}: {exc}")
140
+ continue
141
+ if ds is None:
142
+ raise RuntimeError(
143
+ "Could not load any MATH dataset. Tried:\n "
144
+ + "\n ".join(load_errors)
145
+ )
146
+
147
+ indices = list(range(len(ds)))
148
+ random.Random(SOURCE_SEED + 1).shuffle(indices)
149
+
150
+ collected: list[dict] = []
151
+ for idx in indices:
152
+ if len(collected) >= target:
153
+ break
154
+ row = ds[idx]
155
+ # Some mirrors use "problem"+"solution", others "question"+"answer".
156
+ question = (row.get("problem") or row.get("question") or "").strip()
157
+ raw_soln = row.get("solution") or row.get("answer") or ""
158
+ if not question or not raw_soln:
159
+ continue
160
+ # If algebra subset isn't available, fall back to filtering by "type".
161
+ subject = (row.get("type") or row.get("subject") or "algebra").lower()
162
+ if "algebra" not in subject:
163
+ continue
164
+ answer = _extract_math_answer(raw_soln)
165
+ if answer is None:
166
+ continue
167
+ collected.append(
168
+ {
169
+ "problem": question,
170
+ "answer": answer,
171
+ "source": "math_algebra",
172
+ }
173
+ )
174
+ return collected
175
+
176
+
177
+ def main() -> int:
178
+ print("Collecting GSM8K problems...")
179
+ gsm = _collect_gsm8k(GSM8K_TARGET)
180
+ print(f" GSM8K: {len(gsm)} valid integer-answer problems")
181
+
182
+ print("Collecting MATH algebra problems...")
183
+ math = _collect_math(MATH_TARGET)
184
+ print(f" MATH: {len(math)} valid integer-answer problems")
185
+
186
+ total_needed = GSM8K_TARGET + MATH_TARGET
187
+ combined_raw = gsm + math
188
+
189
+ # Shortfall check — STOP before writing if we're under target.
190
+ if len(combined_raw) < total_needed:
191
+ print(
192
+ f"ERROR: shortfall. Got {len(combined_raw)} valid problems, "
193
+ f"need {total_needed}. Not writing output.",
194
+ file=sys.stderr,
195
+ )
196
+ return 1
197
+
198
+ # Assign sequential ids from 1; attach difficulty; drop source.
199
+ pool: list[dict] = []
200
+ for i, entry in enumerate(combined_raw, start=1):
201
+ problem = entry["problem"]
202
+ answer = int(entry["answer"])
203
+ pool.append(
204
+ {
205
+ "id": i,
206
+ "problem": problem,
207
+ "answer": answer,
208
+ "difficulty": _classify_difficulty(problem, answer),
209
+ }
210
+ )
211
+
212
+ OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
213
+ with OUTPUT_PATH.open("w", encoding="utf-8") as fh:
214
+ json.dump(pool, fh, ensure_ascii=False, indent=2)
215
+
216
+ print(
217
+ f"Generated {len(pool)} problems: "
218
+ f"{len(gsm)} from GSM8K, {len(math)} from MATH -> {OUTPUT_PATH}"
219
+ )
220
+ return 0
221
+
222
+
223
+ if __name__ == "__main__":
224
+ raise SystemExit(main())
tests/test_problems.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for red_button.problems (PROJECT.md Section 12)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+
9
+ from red_button.problems import (
10
+ ground_truth_map,
11
+ load_problems,
12
+ sample_problems,
13
+ validate_answer,
14
+ )
15
+
16
+ POOL_PATH = str(Path(__file__).resolve().parents[1] / "data" / "problems_pool.json")
17
+
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Fixtures
21
+ # ---------------------------------------------------------------------------
22
+
23
+
24
+ @pytest.fixture(scope="module")
25
+ def pool() -> list[dict]:
26
+ return load_problems(POOL_PATH)
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # load_problems + pool structure
31
+ # ---------------------------------------------------------------------------
32
+
33
+
34
+ def test_load_problems_returns_list_of_dicts_with_required_keys(pool: list[dict]) -> None:
35
+ assert isinstance(pool, list)
36
+ required = {"id", "problem", "answer", "difficulty"}
37
+ for entry in pool:
38
+ assert isinstance(entry, dict)
39
+ assert required.issubset(entry.keys()), f"missing keys in {entry}"
40
+
41
+
42
+ def test_all_ids_are_unique(pool: list[dict]) -> None:
43
+ ids = [e["id"] for e in pool]
44
+ assert len(ids) == len(set(ids))
45
+
46
+
47
+ def test_all_answers_are_integers(pool: list[dict]) -> None:
48
+ # Explicit type check rules out bool (which is a subclass of int).
49
+ for entry in pool:
50
+ assert type(entry["answer"]) is int, (
51
+ f"non-int answer {entry['answer']!r} in problem id={entry['id']}"
52
+ )
53
+
54
+
55
+ def test_pool_size_meets_target(pool: list[dict]) -> None:
56
+ # Section 12.3 target: 300 GSM8K + 200 MATH = 500. Generated pool size
57
+ # must be >= 500. Lower the floor only with a documented rationale.
58
+ assert len(pool) >= 500
59
+
60
+
61
+ def test_all_difficulty_labels_are_valid(pool: list[dict]) -> None:
62
+ valid = {"easy", "medium", "hard"}
63
+ for entry in pool:
64
+ assert entry["difficulty"] in valid, f"bad difficulty in {entry}"
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # sample_problems
69
+ # ---------------------------------------------------------------------------
70
+
71
+
72
+ def test_sample_problems_seeded_returns_correct_count(pool: list[dict]) -> None:
73
+ sample = sample_problems(n=10, seed=42, problems=pool)
74
+ assert len(sample) == 10
75
+
76
+
77
+ def test_sample_problems_seeded_is_deterministic(pool: list[dict]) -> None:
78
+ a = sample_problems(n=10, seed=42, problems=pool)
79
+ b = sample_problems(n=10, seed=42, problems=pool)
80
+ assert [e["id"] for e in a] == [e["id"] for e in b]
81
+
82
+
83
+ def test_sample_problems_different_seeds_differ(pool: list[dict]) -> None:
84
+ a_ids = {e["id"] for e in sample_problems(n=10, seed=42, problems=pool)}
85
+ b_ids = {e["id"] for e in sample_problems(n=10, seed=43, problems=pool)}
86
+ assert a_ids != b_ids
87
+
88
+
89
+ def test_sample_problems_unseeded_returns_correct_count(pool: list[dict]) -> None:
90
+ sample = sample_problems(n=10, problems=pool)
91
+ assert len(sample) == 10
92
+
93
+
94
+ def test_sample_problems_n_equals_twenty_works(pool: list[dict]) -> None:
95
+ sample = sample_problems(n=20, seed=7, problems=pool)
96
+ assert len(sample) == 20
97
+ # Sampling is without replacement.
98
+ assert len({e["id"] for e in sample}) == 20
99
+
100
+
101
+ def test_sample_problems_n_exceeds_pool_raises(pool: list[dict]) -> None:
102
+ with pytest.raises(ValueError):
103
+ sample_problems(n=10000, seed=0, problems=pool)
104
+
105
+
106
+ # ---------------------------------------------------------------------------
107
+ # validate_answer
108
+ # ---------------------------------------------------------------------------
109
+
110
+
111
+ def test_validate_answer_correct_returns_true(pool: list[dict]) -> None:
112
+ first = pool[0]
113
+ assert validate_answer(first["id"], first["answer"], pool) is True
114
+
115
+
116
+ def test_validate_answer_wrong_returns_false(pool: list[dict]) -> None:
117
+ first = pool[0]
118
+ wrong = first["answer"] + 99999
119
+ assert validate_answer(first["id"], wrong, pool) is False
120
+
121
+
122
+ def test_validate_answer_missing_id_returns_false(pool: list[dict]) -> None:
123
+ # Missing id returns False, not an exception.
124
+ assert validate_answer(10_000_000, 42, pool) is False
125
+
126
+
127
+ # ---------------------------------------------------------------------------
128
+ # ground_truth_map
129
+ # ---------------------------------------------------------------------------
130
+
131
+
132
+ def test_ground_truth_map_entries_count_matches_input(pool: list[dict]) -> None:
133
+ gt = ground_truth_map(pool)
134
+ assert len(gt) == len(pool)
135
+
136
+
137
+ def test_ground_truth_map_keys_are_ints_values_are_ints(pool: list[dict]) -> None:
138
+ gt = ground_truth_map(pool)
139
+ for k, v in gt.items():
140
+ assert type(k) is int
141
+ assert type(v) is int