Spaces:
Sleeping
Sleeping
Commit ·
0738d13
1
Parent(s): 313cc30
Phase 6: Problems pool (500 GSM8K-style problems) and sampling API per PROJECT.md Section 12
Browse files- data/problems_pool.json +0 -0
- red_button/problems.py +68 -3
- scripts/generate_problems_pool.py +224 -0
- tests/test_problems.py +141 -0
data/problems_pool.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
red_button/problems.py
CHANGED
|
@@ -1,5 +1,70 @@
|
|
| 1 |
-
"""
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Problems pool loader, sampler, and integer answer validator (Section 12).
|
| 2 |
|
| 3 |
+
The pool is a JSON file (``data/problems_pool.json``) shaped per Section 12.2.
|
| 4 |
+
Each entry is ``{"id": int, "problem": str, "answer": int, "difficulty": str}``.
|
| 5 |
+
|
| 6 |
+
This module intentionally uses plain ``dict`` for pool entries — the JSON file
|
| 7 |
+
is the contract and dicts are the natural intermediate. The rubric consumer
|
| 8 |
+
(:class:`red_button.rubrics.MathCorrectnessRubric`) reads ``dict[int, int]``
|
| 9 |
+
from ``state.ground_truth`` (populated via :func:`ground_truth_map`).
|
| 10 |
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import random
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_problems(path: str = "data/problems_pool.json") -> list[dict]:
|
| 21 |
+
"""Load and return the problems pool from ``path`` (Section 12.2).
|
| 22 |
+
|
| 23 |
+
Path is interpreted relative to the current working directory.
|
| 24 |
+
"""
|
| 25 |
+
with Path(path).open("r", encoding="utf-8") as fh:
|
| 26 |
+
data = json.load(fh)
|
| 27 |
+
if not isinstance(data, list):
|
| 28 |
+
raise ValueError(f"Expected a JSON list at {path}, got {type(data).__name__}")
|
| 29 |
+
return data
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def sample_problems(
|
| 33 |
+
n: int = 10,
|
| 34 |
+
seed: Optional[int] = None,
|
| 35 |
+
problems: Optional[list[dict]] = None,
|
| 36 |
+
) -> list[dict]:
|
| 37 |
+
"""Sample ``n`` problems without replacement.
|
| 38 |
+
|
| 39 |
+
* ``problems`` defaults to :func:`load_problems` output.
|
| 40 |
+
* If ``seed`` is provided, sampling is deterministic via ``random.Random(seed).sample``.
|
| 41 |
+
* ``n > len(problems)`` raises :class:`ValueError`.
|
| 42 |
+
"""
|
| 43 |
+
pool = problems if problems is not None else load_problems()
|
| 44 |
+
if n > len(pool):
|
| 45 |
+
raise ValueError(
|
| 46 |
+
f"Requested sample of n={n} exceeds pool size {len(pool)}"
|
| 47 |
+
)
|
| 48 |
+
rng = random.Random(seed) if seed is not None else random.Random()
|
| 49 |
+
return rng.sample(pool, n)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def validate_answer(
|
| 53 |
+
problem_id: int,
|
| 54 |
+
submitted_answer: int,
|
| 55 |
+
problems: list[dict],
|
| 56 |
+
) -> bool:
|
| 57 |
+
"""Return ``True`` iff ``problem_id`` exists and its answer matches.
|
| 58 |
+
|
| 59 |
+
Missing ``problem_id`` returns ``False`` (not an exception), by design —
|
| 60 |
+
the environment's ``submit_answer`` tool must never crash on junk input.
|
| 61 |
+
"""
|
| 62 |
+
for entry in problems:
|
| 63 |
+
if entry.get("id") == problem_id:
|
| 64 |
+
return entry.get("answer") == submitted_answer
|
| 65 |
+
return False
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def ground_truth_map(problems: list[dict]) -> dict[int, int]:
|
| 69 |
+
"""Return ``{id: answer}`` for a problems list (Section 12.5)."""
|
| 70 |
+
return {int(p["id"]): int(p["answer"]) for p in problems}
|
scripts/generate_problems_pool.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generate data/problems_pool.json from GSM8K and MATH per PROJECT.md Section 12.3.
|
| 2 |
+
|
| 3 |
+
Target: 500 problems (300 GSM8K + 200 MATH), integer-answer only, deterministic.
|
| 4 |
+
|
| 5 |
+
Usage::
|
| 6 |
+
|
| 7 |
+
python scripts/generate_problems_pool.py
|
| 8 |
+
|
| 9 |
+
Re-runs are reproducible because we seed the source-side shuffle with
|
| 10 |
+
``SOURCE_SEED`` (see below). Rerun whenever Section 12.3 changes or HF
|
| 11 |
+
dataset splits shift.
|
| 12 |
+
|
| 13 |
+
Dependencies:
|
| 14 |
+
pip install datasets
|
| 15 |
+
|
| 16 |
+
Answer extraction
|
| 17 |
+
-----------------
|
| 18 |
+
* GSM8K: answers end with ``#### N`` — extract the final integer.
|
| 19 |
+
* MATH: answers live inside ``\\boxed{...}`` — try ``int()`` on the contents.
|
| 20 |
+
|
| 21 |
+
Filter: reject anything whose extracted answer cannot be parsed as ``int``
|
| 22 |
+
(fractions, decimals, non-numeric, multi-part).
|
| 23 |
+
|
| 24 |
+
Difficulty heuristic (per the Phase 6 task spec — PROJECT.md Section 12 itself
|
| 25 |
+
does not specify one)::
|
| 26 |
+
|
| 27 |
+
easy if answer <= 100 and len(problem) < 150
|
| 28 |
+
hard if answer > 1000 or len(problem) > 300
|
| 29 |
+
medium otherwise
|
| 30 |
+
|
| 31 |
+
If HF access fails or <500 valid problems can be produced, the script stops
|
| 32 |
+
WITHOUT writing output and surfaces the shortfall — do NOT hand-roll a
|
| 33 |
+
substitute.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
import json
|
| 39 |
+
import random
|
| 40 |
+
import re
|
| 41 |
+
import sys
|
| 42 |
+
from pathlib import Path
|
| 43 |
+
|
| 44 |
+
# Reproducibility seed for source-side shuffling/sampling.
|
| 45 |
+
SOURCE_SEED = 20260425
|
| 46 |
+
|
| 47 |
+
GSM8K_TARGET = 300
|
| 48 |
+
MATH_TARGET = 200
|
| 49 |
+
|
| 50 |
+
OUTPUT_PATH = Path(__file__).resolve().parents[1] / "data" / "problems_pool.json"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _classify_difficulty(problem: str, answer: int) -> str:
|
| 54 |
+
if answer <= 100 and len(problem) < 150:
|
| 55 |
+
return "easy"
|
| 56 |
+
if answer > 1000 or len(problem) > 300:
|
| 57 |
+
return "hard"
|
| 58 |
+
return "medium"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
_GSM8K_ANSWER_RE = re.compile(r"####\s*(-?\d+)")
|
| 62 |
+
_MATH_BOXED_RE = re.compile(r"\\boxed\{([^{}]*)\}")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _extract_gsm8k_answer(raw_answer: str) -> int | None:
|
| 66 |
+
"""Pull the trailing ``#### N`` integer out of a GSM8K answer string."""
|
| 67 |
+
m = _GSM8K_ANSWER_RE.search(raw_answer)
|
| 68 |
+
if m is None:
|
| 69 |
+
return None
|
| 70 |
+
try:
|
| 71 |
+
return int(m.group(1))
|
| 72 |
+
except (TypeError, ValueError):
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _extract_math_answer(raw_solution: str) -> int | None:
|
| 77 |
+
"""Pull an integer out of the final ``\\boxed{...}`` of a MATH solution."""
|
| 78 |
+
matches = _MATH_BOXED_RE.findall(raw_solution)
|
| 79 |
+
if not matches:
|
| 80 |
+
return None
|
| 81 |
+
candidate = matches[-1].strip()
|
| 82 |
+
try:
|
| 83 |
+
return int(candidate)
|
| 84 |
+
except (TypeError, ValueError):
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _collect_gsm8k(target: int) -> list[dict]:
|
| 89 |
+
"""Pull ``target`` integer-answer problems from GSM8K (openai/gsm8k, main, train)."""
|
| 90 |
+
from datasets import load_dataset
|
| 91 |
+
|
| 92 |
+
ds = load_dataset("openai/gsm8k", "main", split="train")
|
| 93 |
+
# Deterministic shuffle over indices so re-runs match.
|
| 94 |
+
indices = list(range(len(ds)))
|
| 95 |
+
random.Random(SOURCE_SEED).shuffle(indices)
|
| 96 |
+
|
| 97 |
+
collected: list[dict] = []
|
| 98 |
+
for idx in indices:
|
| 99 |
+
if len(collected) >= target:
|
| 100 |
+
break
|
| 101 |
+
row = ds[idx]
|
| 102 |
+
question = row["question"].strip()
|
| 103 |
+
answer = _extract_gsm8k_answer(row["answer"])
|
| 104 |
+
if answer is None:
|
| 105 |
+
continue
|
| 106 |
+
collected.append(
|
| 107 |
+
{
|
| 108 |
+
"problem": question,
|
| 109 |
+
"answer": answer,
|
| 110 |
+
"source": "gsm8k",
|
| 111 |
+
}
|
| 112 |
+
)
|
| 113 |
+
return collected
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _collect_math(target: int) -> list[dict]:
|
| 117 |
+
"""Pull ``target`` integer-answer problems from the MATH algebra track.
|
| 118 |
+
|
| 119 |
+
We try ``EleutherAI/hendrycks_math`` first (the 2024+ maintained mirror
|
| 120 |
+
with per-subject configs); ``hendrycks/competition_math`` was deprecated.
|
| 121 |
+
"""
|
| 122 |
+
from datasets import load_dataset
|
| 123 |
+
|
| 124 |
+
load_errors: list[str] = []
|
| 125 |
+
ds = None
|
| 126 |
+
for source_name, loader in [
|
| 127 |
+
("EleutherAI/hendrycks_math[algebra]", lambda: load_dataset(
|
| 128 |
+
"EleutherAI/hendrycks_math", "algebra", split="train"
|
| 129 |
+
)),
|
| 130 |
+
("hendrycks/competition_math", lambda: load_dataset(
|
| 131 |
+
"hendrycks/competition_math", split="train"
|
| 132 |
+
)),
|
| 133 |
+
]:
|
| 134 |
+
try:
|
| 135 |
+
ds = loader()
|
| 136 |
+
print(f" MATH source: {source_name}")
|
| 137 |
+
break
|
| 138 |
+
except Exception as exc: # noqa: BLE001
|
| 139 |
+
load_errors.append(f"{source_name}: {exc}")
|
| 140 |
+
continue
|
| 141 |
+
if ds is None:
|
| 142 |
+
raise RuntimeError(
|
| 143 |
+
"Could not load any MATH dataset. Tried:\n "
|
| 144 |
+
+ "\n ".join(load_errors)
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
indices = list(range(len(ds)))
|
| 148 |
+
random.Random(SOURCE_SEED + 1).shuffle(indices)
|
| 149 |
+
|
| 150 |
+
collected: list[dict] = []
|
| 151 |
+
for idx in indices:
|
| 152 |
+
if len(collected) >= target:
|
| 153 |
+
break
|
| 154 |
+
row = ds[idx]
|
| 155 |
+
# Some mirrors use "problem"+"solution", others "question"+"answer".
|
| 156 |
+
question = (row.get("problem") or row.get("question") or "").strip()
|
| 157 |
+
raw_soln = row.get("solution") or row.get("answer") or ""
|
| 158 |
+
if not question or not raw_soln:
|
| 159 |
+
continue
|
| 160 |
+
# If algebra subset isn't available, fall back to filtering by "type".
|
| 161 |
+
subject = (row.get("type") or row.get("subject") or "algebra").lower()
|
| 162 |
+
if "algebra" not in subject:
|
| 163 |
+
continue
|
| 164 |
+
answer = _extract_math_answer(raw_soln)
|
| 165 |
+
if answer is None:
|
| 166 |
+
continue
|
| 167 |
+
collected.append(
|
| 168 |
+
{
|
| 169 |
+
"problem": question,
|
| 170 |
+
"answer": answer,
|
| 171 |
+
"source": "math_algebra",
|
| 172 |
+
}
|
| 173 |
+
)
|
| 174 |
+
return collected
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def main() -> int:
|
| 178 |
+
print("Collecting GSM8K problems...")
|
| 179 |
+
gsm = _collect_gsm8k(GSM8K_TARGET)
|
| 180 |
+
print(f" GSM8K: {len(gsm)} valid integer-answer problems")
|
| 181 |
+
|
| 182 |
+
print("Collecting MATH algebra problems...")
|
| 183 |
+
math = _collect_math(MATH_TARGET)
|
| 184 |
+
print(f" MATH: {len(math)} valid integer-answer problems")
|
| 185 |
+
|
| 186 |
+
total_needed = GSM8K_TARGET + MATH_TARGET
|
| 187 |
+
combined_raw = gsm + math
|
| 188 |
+
|
| 189 |
+
# Shortfall check — STOP before writing if we're under target.
|
| 190 |
+
if len(combined_raw) < total_needed:
|
| 191 |
+
print(
|
| 192 |
+
f"ERROR: shortfall. Got {len(combined_raw)} valid problems, "
|
| 193 |
+
f"need {total_needed}. Not writing output.",
|
| 194 |
+
file=sys.stderr,
|
| 195 |
+
)
|
| 196 |
+
return 1
|
| 197 |
+
|
| 198 |
+
# Assign sequential ids from 1; attach difficulty; drop source.
|
| 199 |
+
pool: list[dict] = []
|
| 200 |
+
for i, entry in enumerate(combined_raw, start=1):
|
| 201 |
+
problem = entry["problem"]
|
| 202 |
+
answer = int(entry["answer"])
|
| 203 |
+
pool.append(
|
| 204 |
+
{
|
| 205 |
+
"id": i,
|
| 206 |
+
"problem": problem,
|
| 207 |
+
"answer": answer,
|
| 208 |
+
"difficulty": _classify_difficulty(problem, answer),
|
| 209 |
+
}
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 213 |
+
with OUTPUT_PATH.open("w", encoding="utf-8") as fh:
|
| 214 |
+
json.dump(pool, fh, ensure_ascii=False, indent=2)
|
| 215 |
+
|
| 216 |
+
print(
|
| 217 |
+
f"Generated {len(pool)} problems: "
|
| 218 |
+
f"{len(gsm)} from GSM8K, {len(math)} from MATH -> {OUTPUT_PATH}"
|
| 219 |
+
)
|
| 220 |
+
return 0
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
if __name__ == "__main__":
|
| 224 |
+
raise SystemExit(main())
|
tests/test_problems.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for red_button.problems (PROJECT.md Section 12)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
from red_button.problems import (
|
| 10 |
+
ground_truth_map,
|
| 11 |
+
load_problems,
|
| 12 |
+
sample_problems,
|
| 13 |
+
validate_answer,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
POOL_PATH = str(Path(__file__).resolve().parents[1] / "data" / "problems_pool.json")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Fixtures
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture(scope="module")
|
| 25 |
+
def pool() -> list[dict]:
|
| 26 |
+
return load_problems(POOL_PATH)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
# load_problems + pool structure
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_load_problems_returns_list_of_dicts_with_required_keys(pool: list[dict]) -> None:
|
| 35 |
+
assert isinstance(pool, list)
|
| 36 |
+
required = {"id", "problem", "answer", "difficulty"}
|
| 37 |
+
for entry in pool:
|
| 38 |
+
assert isinstance(entry, dict)
|
| 39 |
+
assert required.issubset(entry.keys()), f"missing keys in {entry}"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_all_ids_are_unique(pool: list[dict]) -> None:
|
| 43 |
+
ids = [e["id"] for e in pool]
|
| 44 |
+
assert len(ids) == len(set(ids))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_all_answers_are_integers(pool: list[dict]) -> None:
|
| 48 |
+
# Explicit type check rules out bool (which is a subclass of int).
|
| 49 |
+
for entry in pool:
|
| 50 |
+
assert type(entry["answer"]) is int, (
|
| 51 |
+
f"non-int answer {entry['answer']!r} in problem id={entry['id']}"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_pool_size_meets_target(pool: list[dict]) -> None:
|
| 56 |
+
# Section 12.3 target: 300 GSM8K + 200 MATH = 500. Generated pool size
|
| 57 |
+
# must be >= 500. Lower the floor only with a documented rationale.
|
| 58 |
+
assert len(pool) >= 500
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_all_difficulty_labels_are_valid(pool: list[dict]) -> None:
|
| 62 |
+
valid = {"easy", "medium", "hard"}
|
| 63 |
+
for entry in pool:
|
| 64 |
+
assert entry["difficulty"] in valid, f"bad difficulty in {entry}"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# sample_problems
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_sample_problems_seeded_returns_correct_count(pool: list[dict]) -> None:
|
| 73 |
+
sample = sample_problems(n=10, seed=42, problems=pool)
|
| 74 |
+
assert len(sample) == 10
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def test_sample_problems_seeded_is_deterministic(pool: list[dict]) -> None:
|
| 78 |
+
a = sample_problems(n=10, seed=42, problems=pool)
|
| 79 |
+
b = sample_problems(n=10, seed=42, problems=pool)
|
| 80 |
+
assert [e["id"] for e in a] == [e["id"] for e in b]
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def test_sample_problems_different_seeds_differ(pool: list[dict]) -> None:
|
| 84 |
+
a_ids = {e["id"] for e in sample_problems(n=10, seed=42, problems=pool)}
|
| 85 |
+
b_ids = {e["id"] for e in sample_problems(n=10, seed=43, problems=pool)}
|
| 86 |
+
assert a_ids != b_ids
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def test_sample_problems_unseeded_returns_correct_count(pool: list[dict]) -> None:
|
| 90 |
+
sample = sample_problems(n=10, problems=pool)
|
| 91 |
+
assert len(sample) == 10
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def test_sample_problems_n_equals_twenty_works(pool: list[dict]) -> None:
|
| 95 |
+
sample = sample_problems(n=20, seed=7, problems=pool)
|
| 96 |
+
assert len(sample) == 20
|
| 97 |
+
# Sampling is without replacement.
|
| 98 |
+
assert len({e["id"] for e in sample}) == 20
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def test_sample_problems_n_exceeds_pool_raises(pool: list[dict]) -> None:
|
| 102 |
+
with pytest.raises(ValueError):
|
| 103 |
+
sample_problems(n=10000, seed=0, problems=pool)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
# validate_answer
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def test_validate_answer_correct_returns_true(pool: list[dict]) -> None:
|
| 112 |
+
first = pool[0]
|
| 113 |
+
assert validate_answer(first["id"], first["answer"], pool) is True
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_validate_answer_wrong_returns_false(pool: list[dict]) -> None:
|
| 117 |
+
first = pool[0]
|
| 118 |
+
wrong = first["answer"] + 99999
|
| 119 |
+
assert validate_answer(first["id"], wrong, pool) is False
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_validate_answer_missing_id_returns_false(pool: list[dict]) -> None:
|
| 123 |
+
# Missing id returns False, not an exception.
|
| 124 |
+
assert validate_answer(10_000_000, 42, pool) is False
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ---------------------------------------------------------------------------
|
| 128 |
+
# ground_truth_map
|
| 129 |
+
# ---------------------------------------------------------------------------
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def test_ground_truth_map_entries_count_matches_input(pool: list[dict]) -> None:
|
| 133 |
+
gt = ground_truth_map(pool)
|
| 134 |
+
assert len(gt) == len(pool)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def test_ground_truth_map_keys_are_ints_values_are_ints(pool: list[dict]) -> None:
|
| 138 |
+
gt = ground_truth_map(pool)
|
| 139 |
+
for k, v in gt.items():
|
| 140 |
+
assert type(k) is int
|
| 141 |
+
assert type(v) is int
|