| | |
| | """Self-consistency evaluation for math-conjecture model checkpoints.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import json |
| | import re |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Sequence, Tuple |
| |
|
| | import torch |
| | import yaml |
| | from datasets import load_dataset |
| | from peft import PeftModel |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed |
| |
|
# Repository root: one directory above the folder containing this script.
SCRIPT_ROOT = Path(__file__).resolve().parents[1]
# Default training config reused here for prompt-formatting defaults.
DEFAULT_CONFIG_PATH = SCRIPT_ROOT / "configs" / "deepseek_math_sota.yaml"
# Default destination for the JSON evaluation report.
DEFAULT_OUTPUT_JSON = SCRIPT_ROOT / "runs" / "latest_eval_report.json"

# Captures the contents of \boxed{...}. The [^{}]+ class rejects braces, so
# NOTE(review): nested LaTeX such as \boxed{\frac{1}{2}} is NOT matched —
# confirm this is acceptable for the answer formats in the eval split.
BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
# Collapses any whitespace run to a single space during answer normalization.
SPACE_RE = re.compile(r"\s+")
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command line for the pass@k evaluation run."""
    parser = argparse.ArgumentParser(description="Run pass@k-style evaluation on held-out split.")

    # Model / data selection.
    parser.add_argument("--config", type=Path, default=DEFAULT_CONFIG_PATH,
                        help="Training config used for prompt formatting defaults.")
    parser.add_argument("--base-model", type=str, default=None,
                        help="Override base model id from config.")
    parser.add_argument("--adapter-path", type=Path, default=None,
                        help="Optional LoRA adapter path to load on top of base model.")
    parser.add_argument("--eval-file", type=Path, default=None,
                        help="Parquet split used for evaluation (defaults to post_eval.eval_file or data.default_validation_file).")

    # Sampling / generation knobs.
    parser.add_argument("--max-samples", type=int, default=300, help="Maximum evaluation rows.")
    parser.add_argument("--k", type=int, default=4, help="Number of sampled generations per prompt.")
    parser.add_argument("--max-new-tokens", type=int, default=256, help="Generation length cap.")
    parser.add_argument("--max-input-length", type=int, default=4096, help="Prompt tokenization length cap.")
    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature.")
    parser.add_argument("--top-p", type=float, default=0.95, help="Nucleus sampling p.")
    parser.add_argument("--seed", type=int, default=17, help="Random seed.")

    # Reporting.
    parser.add_argument("--progress-every", type=int, default=25,
                        help="Print progress every N evaluated rows (0 disables).")
    parser.add_argument("--sample-records", type=int, default=30,
                        help="How many sample records to store in report.")
    parser.add_argument("--output-json", type=Path, default=DEFAULT_OUTPUT_JSON,
                        help="Where to write evaluation report.")

    return parser.parse_args()
| |
|
| |
|
def as_text(value: Any) -> str:
    """Coerce *value* to a stripped string; ``None`` becomes the empty string."""
    if value is None:
        return ""
    text = value if isinstance(value, str) else str(value)
    return text.strip()
| |
|
| |
|
def as_float(value: Any, default: float) -> float:
    """Coerce *value* to float, falling back to *default* on None or bad input."""
    try:
        return default if value is None else float(value)
    except (TypeError, ValueError):
        return default
| |
|
| |
|
def as_int(value: Any, default: int) -> int:
    """Coerce *value* to int, falling back to *default* on None or bad input."""
    try:
        return default if value is None else int(value)
    except (TypeError, ValueError):
        return default
| |
|
| |
|
def load_config(path: Path) -> Dict[str, Any]:
    """Read a YAML config file and require its top level to be a mapping."""
    parsed = yaml.safe_load(path.read_text(encoding="utf-8"))
    if isinstance(parsed, dict):
        return parsed
    raise ValueError("Invalid YAML config.")
| |
|
| |
|
def normalize_answer(text: str) -> str:
    """Canonicalize an answer string for comparison.

    Lowercases, drops ``$`` and common LaTeX spacing/delimiter tokens,
    collapses whitespace runs, and trims trailing spaces/periods.
    """
    cleaned = text.strip().lower().replace("$", "")
    # Remove LaTeX tokens in the same order the original normalization used.
    for token in ("\\left", "\\right", "\\,", "\\!", "\\;"):
        cleaned = cleaned.replace(token, "")
    cleaned = re.sub(r"\s+", " ", cleaned)
    return cleaned.strip(" .")
| |
|
| |
|
def extract_boxed_values(text: str) -> List[str]:
    """Return the normalized contents of every ``\\boxed{...}`` in *text*.

    Empty-after-normalization values are dropped.
    """
    values: List[str] = []
    for raw in BOXED_RE.findall(text or ""):
        normalized = normalize_answer(raw)
        if normalized:
            values.append(normalized)
    return values
| |
|
| |
|
def parse_numeric_value(text: str) -> Optional[float]:
    """Parse a normalized answer into a float when it looks numeric.

    Supports plain integers/decimals, scientific notation, and simple
    ``a/b`` fractions (division by zero yields None). Returns None for
    anything non-numeric.
    """
    cleaned = normalize_answer(text)
    if not cleaned:
        return None
    cleaned = cleaned.replace(",", "")  # tolerate thousands separators

    if re.fullmatch(r"[-+]?\d+\s*/\s*[-+]?\d+", cleaned):
        num_txt, den_txt = cleaned.split("/", maxsplit=1)
        try:
            numerator = float(num_txt.strip())
            denominator = float(den_txt.strip())
        except ValueError:
            return None
        return None if denominator == 0 else numerator / denominator

    if re.fullmatch(r"[-+]?(?:\d+\.\d*|\d*\.\d+|\d+)(?:[eE][-+]?\d+)?", cleaned):
        try:
            return float(cleaned)
        except ValueError:
            return None

    return None
| |
|
| |
|
def approximately_equal(left: float, right: float) -> bool:
    """Float comparison at 1e-6 relative scale (absolute below magnitude 1)."""
    scale = max(1.0, abs(left), abs(right))
    return abs(left - right) <= 1e-6 * scale
| |
|
| |
|
def match_candidate(candidate: str, expected_values: Sequence[str]) -> Dict[str, Any]:
    """Compare one generated answer against all accepted expected values.

    Returns a dict with:
        match:   True when any comparison strategy succeeded.
        exact:   normalized candidate equals a normalized expected value.
        boxed:   a \\boxed{...} value overlapped an expected value.
        numeric: both sides parsed to approximately equal floats.
        reason:  which strategy decided the outcome. Priority order is
                 exact > boxed > numeric > substring > no_match.
    """
    cand_norm = normalize_answer(candidate)
    if not cand_norm:
        # Nothing usable was generated.
        return {
            "match": False,
            "exact": False,
            "boxed": False,
            "numeric": False,
            "reason": "empty_candidate",
        }

    # Precompute candidate-side views once; they are reused for every
    # expected value in the loop below.
    cand_boxed = extract_boxed_values(candidate)
    cand_num = parse_numeric_value(cand_norm)

    substring_hit = False
    boxed_hit = False
    numeric_hit = False

    for expected in expected_values:
        exp_norm = normalize_answer(expected)
        if not exp_norm:
            continue

        # Exact normalized equality short-circuits immediately.
        if cand_norm == exp_norm:
            return {
                "match": True,
                "exact": True,
                "boxed": exp_norm in cand_boxed,
                "numeric": False,
                "reason": "exact",
            }

        # Symmetric containment: either string appearing inside the other
        # counts (lenient by design — full generations often embed the answer).
        if exp_norm in cand_norm or cand_norm in exp_norm:
            substring_hit = True

        # Boxed comparison in both directions: candidate's boxed values
        # against the expected string, and expected's boxed values against
        # the whole normalized candidate.
        expected_boxed = extract_boxed_values(expected)
        for cand_box in cand_boxed:
            if cand_box == exp_norm or exp_norm in cand_box or cand_box in exp_norm:
                boxed_hit = True
        for exp_box in expected_boxed:
            if cand_norm == exp_box or exp_box in cand_norm or cand_norm in exp_box:
                boxed_hit = True

        # Numeric comparison with relative tolerance (handles fractions etc.).
        exp_num = parse_numeric_value(exp_norm)
        if cand_num is not None and exp_num is not None and approximately_equal(cand_num, exp_num):
            numeric_hit = True

    # Resolve accumulated hits in decreasing order of confidence.
    if boxed_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": True,
            "numeric": numeric_hit,
            "reason": "boxed",
        }
    if numeric_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": False,
            "numeric": True,
            "reason": "numeric",
        }
    if substring_hit:
        return {
            "match": True,
            "exact": False,
            "boxed": False,
            "numeric": False,
            "reason": "substring",
        }

    return {
        "match": False,
        "exact": False,
        "boxed": False,
        "numeric": False,
        "reason": "no_match",
    }
| |
|
| |
|
def flatten_expected(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> List[str]:
    """Collect every acceptable ground-truth answer string from a dataset row.

    Reads the configured final-answer field plus the target field; a string
    target is JSON-decoded when possible (falling back to the raw string),
    and dict/list targets are flattened one level deep.
    """
    answers: List[str] = []

    def _add(value: Any) -> None:
        # Append the stripped text form of *value* when non-empty.
        txt = as_text(value)
        if txt:
            answers.append(txt)

    final_field = as_text(data_cfg.get("final_answer_field")) or "final_answer"
    target_field = as_text(data_cfg.get("target_field")) or "target"

    final_answer = row.get(final_field)
    if final_answer is not None:
        _add(final_answer)

    target = row.get(target_field)
    if target is None:
        return answers

    if isinstance(target, str):
        stripped = target.strip()
        if not stripped:
            return answers
        try:
            target = json.loads(stripped)
        except json.JSONDecodeError:
            # Not JSON: treat the raw string itself as the answer.
            answers.append(stripped)
            return answers

    if isinstance(target, dict):
        for value in target.values():
            if isinstance(value, list):
                for item in value:
                    _add(item)
            else:
                _add(value)
    elif isinstance(target, list):
        for item in target:
            _add(item)
    else:
        _add(target)

    return answers
| |
|
| |
|
def build_user_block(row: Dict[str, Any], data_cfg: Dict[str, Any]) -> str:
    """Render the user-turn text: prompt plus an optional metadata footer."""
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"
    prompt = as_text(row.get(prompt_field)) or "Solve the math task."

    # Optional row metadata, rendered as "Label: value" lines when present.
    metadata_lines = [
        f"{label}: {value}"
        for key, label in (
            ("task_type", "Task type"),
            ("family", "Family"),
            ("difficulty", "Difficulty"),
            ("source_dataset", "Source"),
            ("status_as_of", "Status as of"),
        )
        if (value := as_text(row.get(key)))
    ]
    if metadata_lines:
        return f"{prompt}\n\nMetadata:\n" + "\n".join(metadata_lines)
    return prompt
| |
|
| |
|
def build_prompt_text(row: Dict[str, Any], tokenizer: AutoTokenizer, data_cfg: Dict[str, Any]) -> str:
    """Format the full generation prompt, via chat template when available."""
    system_prompt = as_text(data_cfg.get("system_prompt")) or "You are a rigorous mathematical reasoning assistant."
    user_block = build_user_block(row, data_cfg)

    # Tokenizers without a chat template fall back to a plain-text layout.
    if not getattr(tokenizer, "chat_template", None):
        return f"System:\n{system_prompt}\n\nUser:\n{user_block}\n\nAssistant:\n"

    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_block},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
| |
|
| |
|
def extract_candidate_text(full_generation: str, prompt_text: str) -> str:
    """Strip the echoed prompt prefix (if present) from a decoded generation."""
    trimmed = full_generation
    if trimmed.startswith(prompt_text):
        trimmed = trimmed[len(prompt_text):]
    return trimmed.strip()
| |
|
| |
|
def load_model_and_tokenizer(
    base_model: str,
    adapter_path: Optional[Path],
    trust_remote_code: bool,
) -> Tuple[Any, AutoTokenizer]:
    """Load the tokenizer and causal LM (optionally with a LoRA adapter).

    Args:
        base_model: Hugging Face model id or local path.
        adapter_path: optional PEFT/LoRA adapter directory to layer on top.
        trust_remote_code: forwarded to both tokenizer and model loaders.

    Returns:
        (model, tokenizer) with the model set to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code, use_fast=True)
    # Ensure a pad token exists for batched generation: prefer eos/unk,
    # and only register a brand-new token as a last resort.
    added_pad_token = False
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token or tokenizer.unk_token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
        added_pad_token = True

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        trust_remote_code=trust_remote_code,
    )
    if added_pad_token:
        # Bug fix: a newly registered pad token has an id beyond the base
        # model's embedding matrix; without resizing, generate() with that
        # pad_token_id can index out of range.
        model.resize_token_embeddings(len(tokenizer))
    if adapter_path is not None:
        model = PeftModel.from_pretrained(model, str(adapter_path))
    model.eval()  # disable dropout etc. for deterministic-as-possible sampling
    return model, tokenizer
| |
|
| |
|
def make_bucket() -> Dict[str, Any]:
    """Create a fresh per-slice accumulator of raw hit counts, all zeroed."""
    counter_keys = (
        "evaluated_rows",
        "pass_at_1_hits",
        "pass_at_k_hits",
        "exact_at_1_hits",
        "exact_at_k_hits",
        "boxed_at_k_hits",
    )
    return {key: 0 for key in counter_keys}
| |
|
| |
|
def update_bucket(bucket: Dict[str, Any], hit1: bool, hitk: bool, exact1: bool, exactk: bool, boxedk: bool) -> None:
    """Increment *bucket* counters in place for one evaluated row."""
    bucket["evaluated_rows"] += 1
    # Each flag maps to exactly one counter; bump the ones that fired.
    for key, fired in (
        ("pass_at_1_hits", hit1),
        ("pass_at_k_hits", hitk),
        ("exact_at_1_hits", exact1),
        ("exact_at_k_hits", exactk),
        ("boxed_at_k_hits", boxedk),
    ):
        if fired:
            bucket[key] += 1
| |
|
| |
|
def finalize_bucket(bucket: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a raw-count bucket into rates; empty buckets yield zero rates."""
    rows = int(bucket.get("evaluated_rows", 0))
    denom = float(max(rows, 1))  # avoid division by zero for empty slices

    def _rate(key: str) -> float:
        return float(bucket.get(key, 0)) / denom

    return {
        "evaluated_rows": rows,
        "pass_at_1": _rate("pass_at_1_hits"),
        "pass_at_k": _rate("pass_at_k_hits"),
        "exact_at_1": _rate("exact_at_1_hits"),
        "exact_at_k": _rate("exact_at_k_hits"),
        "boxed_at_k": _rate("boxed_at_k_hits"),
    }
| |
|
| |
|
def resolve_eval_file(arg_eval_file: Optional[Path], cfg: Dict[str, Any]) -> Path:
    """Choose the eval parquet: CLI arg > config entries > known fallbacks.

    The CLI value is returned unconditionally; config / fallback candidates
    are only accepted when they exist on disk. The final return is a default
    path even if it does not exist (the caller validates existence).
    """
    if arg_eval_file is not None:
        return arg_eval_file

    candidates = (
        as_text(cfg.get("post_eval", {}).get("eval_file")),
        as_text(cfg.get("data", {}).get("default_validation_file")),
        "data/releases/v1/test.parquet",
        "workspace/data/releases/v1/test.parquet",
    )
    for candidate in candidates:
        if candidate and (path := Path(candidate)).exists():
            return path
    return Path("data/releases/v1/test.parquet")
| |
|
| |
|
def run_evaluation(args: argparse.Namespace) -> Dict[str, Any]:
    """Sample k generations per eval row, score them, and write a JSON report.

    Pipeline: validate CLI args -> load config, model, tokenizer -> stream
    the parquet split -> generate ``args.k`` samples per prompt -> score each
    sample against the row's expected answers -> aggregate overall,
    per-family, and per-difficulty metrics -> write the report to
    ``args.output_json`` and print a short summary.

    Returns:
        The full report dict that was written to disk.

    Raises:
        ValueError: for out-of-range CLI arguments or a missing base model id.
        FileNotFoundError: when the adapter path or eval file does not exist.
    """
    # Fail fast on bad arguments before any expensive model loading.
    if args.k < 1:
        raise ValueError("--k must be >= 1.")
    if args.max_samples < 1:
        raise ValueError("--max-samples must be >= 1.")
    if args.max_new_tokens < 1:
        raise ValueError("--max-new-tokens must be >= 1.")
    if args.max_input_length < 128:
        raise ValueError("--max-input-length must be >= 128.")
    if args.temperature <= 0:
        raise ValueError("--temperature must be > 0.")
    if not 0 < args.top_p <= 1:
        raise ValueError("--top-p must be in (0, 1].")

    cfg = load_config(args.config)
    data_cfg = cfg.get("data", {})
    model_cfg = cfg.get("model", {})
    set_seed(args.seed)  # seeds python/numpy/torch for reproducible sampling

    # CLI override takes precedence over the config's base model id.
    base_model = args.base_model or as_text(model_cfg.get("base_model"))
    if not base_model:
        raise ValueError("Base model is required via --base-model or config.model.base_model.")
    if args.adapter_path is not None and not args.adapter_path.exists():
        raise FileNotFoundError(f"Adapter path not found: {args.adapter_path}")

    eval_file = resolve_eval_file(args.eval_file, cfg)
    if not eval_file.exists():
        raise FileNotFoundError(f"Evaluation file not found: {eval_file}")

    model, tokenizer = load_model_and_tokenizer(
        base_model=base_model,
        adapter_path=args.adapter_path,
        trust_remote_code=bool(model_cfg.get("trust_remote_code", False)),
    )

    ds = load_dataset("parquet", data_files={"eval": str(eval_file)})["eval"]
    # max_samples was validated >= 1 above, so the > 0 check is a redundant guard.
    if args.max_samples > 0 and args.max_samples < len(ds):
        ds = ds.select(range(args.max_samples))

    totals = make_bucket()
    family_buckets: Dict[str, Dict[str, Any]] = {}
    difficulty_buckets: Dict[str, Dict[str, Any]] = {}

    processed_rows = 0
    skipped_no_expected = 0
    samples: List[Dict[str, Any]] = []

    # Inputs are moved to the first parameter's device. NOTE(review): with
    # device_map="auto" the model may be sharded; this assumes the input
    # embedding layer lives on that first device — confirm for multi-GPU runs.
    model_device = next(model.parameters()).device
    prompt_field = as_text(data_cfg.get("prompt_field")) or "prompt"

    for row in ds:
        expected_values = flatten_expected(row, data_cfg)
        if not expected_values:
            # No ground-truth answers available: the row cannot be scored.
            skipped_no_expected += 1
            continue

        prompt_text = build_prompt_text(row, tokenizer, data_cfg)
        inputs = tokenizer(
            prompt_text,
            return_tensors="pt",
            truncation=True,
            max_length=args.max_input_length,
        )
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            # One batched call produces all k samples for this prompt.
            output_ids = model.generate(
                **inputs,
                do_sample=True,
                temperature=args.temperature,
                top_p=args.top_p,
                num_return_sequences=args.k,
                max_new_tokens=args.max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decoded sequences include the prompt tokens; extract_candidate_text
        # strips the prefix when the decoded text round-trips the prompt exactly.
        generations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        candidates = [extract_candidate_text(text, prompt_text) for text in generations]
        details = [match_candidate(candidate, expected_values) for candidate in candidates]

        matches = [bool(item["match"]) for item in details]
        exacts = [bool(item["exact"]) for item in details]
        boxed = [bool(item["boxed"]) for item in details]

        # pass@1 is taken from the first *sampled* generation (not greedy).
        hit1 = bool(matches and matches[0])
        hitk = bool(any(matches))
        exact1 = bool(exacts and exacts[0])
        exactk = bool(any(exacts))
        boxedk = bool(any(boxed))

        update_bucket(totals, hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)

        # Per-family slice (lazy bucket creation).
        family = as_text(row.get("family")) or "__unknown__"
        if family not in family_buckets:
            family_buckets[family] = make_bucket()
        update_bucket(family_buckets[family], hit1=hit1, hitk=hitk, exact1=exact1, exactk=exactk, boxedk=boxedk)

        # Per-difficulty slice (lazy bucket creation).
        difficulty = as_text(row.get("difficulty")) or "__unknown__"
        if difficulty not in difficulty_buckets:
            difficulty_buckets[difficulty] = make_bucket()
        update_bucket(
            difficulty_buckets[difficulty],
            hit1=hit1,
            hitk=hitk,
            exact1=exact1,
            exactk=exactk,
            boxedk=boxedk,
        )

        processed_rows += 1
        if args.progress_every > 0 and processed_rows % args.progress_every == 0:
            print(f"Progress: evaluated_rows={processed_rows} latest_family={family}")

        # Keep a bounded number of qualitative examples for the report.
        if len(samples) < args.sample_records:
            samples.append(
                {
                    "uid": as_text(row.get("uid")),
                    "family": family,
                    "difficulty": difficulty,
                    "prompt": as_text(row.get(prompt_field)),
                    "expected_values": expected_values[:5],
                    "candidates": candidates,
                    "match_details": details,
                    "matches": matches,
                }
            )

    total_eval = int(totals.get("evaluated_rows", 0))
    denominator = max(total_eval, 1)  # avoid div-by-zero when every row was skipped

    pass_at_1 = float(totals.get("pass_at_1_hits", 0)) / denominator
    pass_at_k = float(totals.get("pass_at_k_hits", 0)) / denominator
    exact_at_1 = float(totals.get("exact_at_1_hits", 0)) / denominator
    exact_at_k = float(totals.get("exact_at_k_hits", 0)) / denominator
    boxed_at_k = float(totals.get("boxed_at_k_hits", 0)) / denominator

    # Fixed weighting: pass@k dominates, then pass@1, then exact@k.
    composite_score = 0.30 * pass_at_1 + 0.50 * pass_at_k + 0.20 * exact_at_k

    report: Dict[str, Any] = {
        "base_model": base_model,
        "adapter_path": str(args.adapter_path) if args.adapter_path is not None else None,
        "eval_file": str(eval_file),
        "config": str(args.config),
        "evaluated_rows": total_eval,
        "skipped_rows_without_targets": skipped_no_expected,
        # NOTE(review): len(ds) here is the row count AFTER max_samples
        # truncation, not the original split size — confirm the name is intended.
        "requested_rows": len(ds),
        "k": args.k,
        "pass_at_1": pass_at_1,
        "pass_at_k": pass_at_k,
        "exact_at_1": exact_at_1,
        "exact_at_k": exact_at_k,
        "boxed_at_k": boxed_at_k,
        "composite_score": composite_score,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "max_new_tokens": args.max_new_tokens,
        "max_input_length": args.max_input_length,
        "seed": args.seed,
        "family_metrics": {
            key: finalize_bucket(family_buckets[key])
            for key in sorted(family_buckets.keys())
        },
        "difficulty_metrics": {
            key: finalize_bucket(difficulty_buckets[key])
            for key in sorted(difficulty_buckets.keys())
        },
        "samples": samples,
    }

    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    args.output_json.write_text(json.dumps(report, ensure_ascii=True, indent=2), encoding="utf-8")

    # Short console summary alongside the full on-disk report.
    summary_view = {
        "evaluated_rows": total_eval,
        "pass_at_1": pass_at_1,
        "pass_at_k": pass_at_k,
        "exact_at_k": exact_at_k,
        "composite_score": composite_score,
        "k": args.k,
    }
    print(json.dumps(summary_view, indent=2))
    print(f"Saved report to {args.output_json}")
    return report
| |
|
| |
|
def main() -> None:
    """CLI entry point: parse arguments and run the evaluation."""
    run_evaluation(parse_args())
| |
|
| |
|
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
| |
|