"""
Downstream evaluation: pass@1 on MATH-500 holdout, AIME-24, GPQA-D.
For each of:
- baseline (alpha=1, no steering — NEW SEMANTICS)
- plan_alpha_0 (full planning suppression)
- mon_alpha_0 (full monitoring suppression)
generate answers and grade.
Grading is lenient numeric match for MATH / AIME; substring for GPQA.
Output: results/downstream_accuracy.json
"""
import sys
import argparse
import re
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import (
ensure_dirs, LOGS_DIR,
TEST_MATH_PATH, TEST_AIME_PATH, TEST_GPQA_PATH,
PLAN_V_PCA_SUBSPACE, MON_V_PCA_SUBSPACE,
RESULTS_DIR, DOWNSTREAM_ACC_JSON,
)
from configs.model import GEN_CONFIG_FAST
from src.utils import setup_logger, read_jsonl, write_json, cleanup_memory
from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate
from src.steering import ResidualSteerer, is_neutral_alpha
from src.directions import load_directions
def extract_boxed(text):
    """Extract the last LaTeX \\boxed{...} answer (handles one level of nested braces)."""
    matches = re.findall(r"\\boxed\{((?:[^{}]|\{[^{}]*\})*)\}", text)
    if matches:
        return matches[-1].strip()
    return None
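# Example (illustrative sketch of the expected behavior):
#   extract_boxed(r"... therefore \boxed{42}.")  -> "42"
#   extract_boxed("no boxed answer here")        -> None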
def extract_final_answer(text):
    """Extract the final answer: \\boxed{} first, then a "Final Answer:" line, then the last number-like token."""
boxed = extract_boxed(text)
if boxed:
return boxed
# Try "Final Answer: X" or "answer is X"
m = re.search(r"(?i)final\s+answer[:\s]+(.+?)(?:\n|\Z)", text)
if m:
return m.group(1).strip().rstrip(".")
# Try last number
m = re.findall(r"-?\d+\.?\d*", text[-500:])
if m:
return m[-1]
return ""
def normalize_numeric(s):
s = s.strip().replace(",", "")
s = re.sub(r"\\frac\{(\-?\d+)\}\{(\-?\d+)\}", r"\1/\2", s)
return s
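# Example (illustrative): normalize_numeric("\\frac{-3}{4}") -> "-3/4"; normalize_numeric("1,024") -> "1024"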
def grade_numeric(pred, gold):
pred_n = normalize_numeric(str(pred))
gold_n = normalize_numeric(str(gold))
if pred_n == gold_n:
return True
try:
return abs(float(pred_n) - float(gold_n)) < 1e-4
except Exception:
return False
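# Worked examples (illustrative):
#   grade_numeric("3.0000", "3")          -> True  (|3.0 - 3.0| < 1e-4)
#   grade_numeric("\\frac{1}{2}", "1/2")  -> True  (string match after \frac normalization)
#   grade_numeric("0.5", "1/2")           -> False (float("1/2") fails; strings differ)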
def grade_substring(pred, gold):
    pred = str(pred).strip().lower()
    gold = str(gold).strip().lower()
    # Guard: an empty prediction must not count as correct ("" is a substring of everything).
    if not pred or not gold:
        return False
    return gold in pred or pred in gold
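# Example (illustrative): grade_substring("(b) 7.38 eV", "7.38 ev") -> True (case-insensitive containment)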
def _mcnemar_pvalue(b: int, c: int):
"""
Exact two-sided McNemar test p-value.
Tests H0: P(baseline correct, steered wrong) = P(baseline wrong, steered correct)
i.e. that the two configs have equal accuracy on this paired test.
b = # samples where baseline is RIGHT but steered is WRONG (regressions)
c = # samples where baseline is WRONG but steered is RIGHT (recoveries)
Returns float p-value, or None if no discordant pairs (test undefined).
Implementation: under H0, b ~ Binomial(n=b+c, p=0.5). Two-sided exact test.
"""
n = b + c
if n == 0:
return None # no discordant pairs — test is undefined
# Two-sided exact: p = 2 * P(X <= min(b, c) | n, 0.5), capped at 1.0
# Compute Binomial CDF without scipy (small n typical):
from math import comb
k = min(b, c)
cdf = sum(comb(n, i) for i in range(k + 1)) / (2 ** n)
p = min(2 * cdf, 1.0)
return float(p)
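# Worked example (illustrative): b=2 regressions, c=8 recoveries -> n=10, k=2,
#   CDF = (C(10,0) + C(10,1) + C(10,2)) / 2**10 = (1 + 10 + 45) / 1024 ≈ 0.0547,
#   p = min(2 * 0.0547, 1.0) ≈ 0.109  (not significant at 0.05).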
def run_config(model, tokenizer, test_set, config_name, directions=None, alpha=1.0,
grader="numeric", max_new_tokens=2048):
"""Run one eval config over test_set. Returns {accuracy, n, per_sample}.
NEW SEMANTICS: alpha=1.0 is baseline (no steering applied).
"""
per_sample = []
correct = 0
for prob in tqdm(test_set, desc=config_name, leave=False):
prompt = build_thinking_prompt(tokenizer, prob["problem"], enable_thinking=True)
# Apply steering only if alpha is NOT the neutral value
if directions is not None and not is_neutral_alpha(alpha):
steerer = ResidualSteerer(model, directions, alpha=alpha)
steerer.start()
try:
text = generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
finally:
steerer.stop()
else:
text = generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
pred = extract_final_answer(text)
gold = prob.get("answer", "")
if grader == "numeric":
ok = grade_numeric(pred, gold)
else:
ok = grade_substring(pred, gold)
if ok:
correct += 1
per_sample.append({
"idx": prob["idx"], "pred": pred, "gold": gold, "correct": bool(ok),
})
cleanup_memory()
return {
"accuracy": correct / max(len(test_set), 1),
"correct": correct,
"n": len(test_set),
"per_sample": per_sample,
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--configs", nargs="+",
default=["baseline", "plan_alpha_0", "mon_alpha_0"],
help="NEW SEMANTICS: alpha=1 baseline; alpha=0 max suppress; "
"alpha=2 amplify. e.g. plan_alpha_0 = max planning suppress.")
parser.add_argument("--max_new_tokens", type=int, default=2048)
parser.add_argument("--resume", action="store_true")
args = parser.parse_args()
ensure_dirs()
log = setup_logger("12_downstream", LOGS_DIR / "12_downstream.log")
if args.resume and DOWNSTREAM_ACC_JSON.exists():
log.info(f"Downstream results already exist: {DOWNSTREAM_ACC_JSON}")
return
# Load test sets
test_sets = {}
if TEST_MATH_PATH.exists():
test_sets["MATH-500-holdout"] = (read_jsonl(TEST_MATH_PATH), "numeric")
if TEST_AIME_PATH.exists():
test_sets["AIME-24"] = (read_jsonl(TEST_AIME_PATH), "numeric")
if TEST_GPQA_PATH.exists():
test_sets["GPQA-D"] = (read_jsonl(TEST_GPQA_PATH), "substring")
log.info(f"Test sets: {list(test_sets.keys())}")
for name, (ds, _) in test_sets.items():
log.info(f" {name}: {len(ds)} problems")
log.info("Loading model...")
model, tokenizer = load_model_and_tokenizer()
plan_dirs = load_directions(PLAN_V_PCA_SUBSPACE)
mon_dirs = load_directions(MON_V_PCA_SUBSPACE)
results = {}
for config_name in args.configs:
log.info(f"=== Config: {config_name} ===")
if config_name == "baseline":
directions, alpha = None, 1.0 # NEW SEMANTICS: alpha=1 means no steering
elif config_name.startswith("plan_alpha_"):
alpha = float(config_name.replace("plan_alpha_", ""))
directions = plan_dirs
elif config_name.startswith("mon_alpha_"):
alpha = float(config_name.replace("mon_alpha_", ""))
directions = mon_dirs
else:
log.warning(f"Unknown config: {config_name}, skipping")
continue
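        # Config-name resolution, illustrated (names beyond the three defaults are hypothetical):
        #   "plan_alpha_0"  -> plan_dirs, alpha=0.0 (full planning suppression)
        #   "mon_alpha_1.5" -> mon_dirs,  alpha=1.5 (mild amplification; any float suffix parses)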
results[config_name] = {}
for ts_name, (ts, grader) in test_sets.items():
r = run_config(
model, tokenizer, ts, f"{config_name}/{ts_name}",
directions=directions, alpha=alpha,
grader=grader, max_new_tokens=args.max_new_tokens,
)
log.info(f" {ts_name}: {r['correct']}/{r['n']} = {r['accuracy']:.3f}")
results[config_name][ts_name] = {
"accuracy": r["accuracy"],
"correct": r["correct"],
"n": r["n"],
"per_sample": r["per_sample"],
}
# =========================================================
# Compute accuracy-drop statistics vs baseline
# =========================================================
if "baseline" in results:
log.info("=" * 60)
log.info("Computing per-config accuracy drop vs baseline...")
for config_name in list(results.keys()):
if config_name == "baseline":
continue
for ts_name in list(results[config_name].keys()):
base = results["baseline"].get(ts_name)
cur = results[config_name].get(ts_name)
if base is None or cur is None:
continue
# Build per-sample correctness lookup keyed by problem idx
base_map = {p["idx"]: p["correct"] for p in base["per_sample"]}
cur_map = {p["idx"]: p["correct"] for p in cur["per_sample"]}
common_idx = set(base_map) & set(cur_map)
n_common = len(common_idx)
# Discordant pairs for McNemar
# b: baseline correct, steered wrong ("regressions")
# c: baseline wrong, steered correct ("recoveries")
b = sum(1 for i in common_idx
if base_map[i] and not cur_map[i])
c = sum(1 for i in common_idx
if (not base_map[i]) and cur_map[i])
# McNemar p-value (exact binomial under H0: b ~ Bin(b+c, 0.5))
mcnemar_p = _mcnemar_pvalue(b, c)
base_acc = base["accuracy"]
cur_acc = cur["accuracy"]
delta = base_acc - cur_acc # positive => accuracy DROPPED
rel_drop = (delta / base_acc) if base_acc > 0 else 0.0
cur["vs_baseline"] = {
"baseline_accuracy": base_acc,
"steered_accuracy": cur_acc,
"absolute_drop": delta,
"relative_drop": rel_drop,
"n_common": n_common,
"n_regressions": b,
"n_recoveries": c,
"mcnemar_p_value": mcnemar_p,
"significant_at_0_05": (mcnemar_p is not None and mcnemar_p < 0.05),
}
log.info(
f" {config_name}/{ts_name}: "
f"acc {base_acc:.3f} -> {cur_acc:.3f} "
f"(Δ={delta:+.3f}, rel={rel_drop:+.1%}) "
f"regressions={b} recoveries={c} "
f"McNemar p={mcnemar_p if mcnemar_p is None else f'{mcnemar_p:.3g}'}"
)
write_json(results, DOWNSTREAM_ACC_JSON)
log.info(f"Saved: {DOWNSTREAM_ACC_JSON}")
if __name__ == "__main__":
main()