| """CLI: dataset path, baseline, output dir, dry-run, smoke eval. |
| |
| Evaluation uses batch LLM judge: 2 calls/session + 2 calls/QA. |
| Session and QA evaluations run in parallel via ThreadPoolExecutor. |
| Pipeline results are checkpointed before eval so --eval-only can resume. |
| """ |

from __future__ import annotations

import argparse
import json
import os
import re
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict
from pathlib import Path
from typing import Any

try:
    from openai import OpenAI
except ImportError:  # openai is optional; QA answering falls back to gold echo
    OpenAI = None

from eval_framework.config import EvalConfig
from eval_framework.datasets.domain_a_v2 import (
    DomainAV2AcademicBundle,
    NormalizedCheckpointQuestion,
    load_domain_a_v2_academic,
)
from eval_framework.datasets.schemas import (
    MemoryDeltaRecord,
    MemorySnapshotRecord,
    RetrievalItem,
    RetrievalRecord,
)
from eval_framework.evaluators.aggregate import aggregate_metrics
from eval_framework.evaluators.extraction import evaluate_extraction
from eval_framework.evaluators.qa import evaluate_checkpoint_qa
from eval_framework.memory_adapters.base import MemoryAdapter
from eval_framework.openai_compat import (
    patch_openai_chat_completions,
    rewrite_chat_completion_kwargs,
)
from eval_framework.pipeline.gold_state import GoldMemoryPoint, SessionGoldState
from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord
from eval_framework.pipeline.runner import run_domain_a_v2_sample

_CHECKPOINT_SESSIONS = "pipeline_sessions.jsonl"
_CHECKPOINT_QA = "pipeline_qa.jsonl"

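# The *_from_dict helpers below invert dataclasses.asdict(): each checkpoint
# JSONL row was serialized via _record_to_json_obj(), and is rebuilt here field
# by field, re-tupling sequence fields and defaulting optional ones.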
def _gold_point_from_dict(d: dict[str, Any]) -> GoldMemoryPoint:
    return GoldMemoryPoint(
        memory_id=d["memory_id"],
        memory_content=d["memory_content"],
        memory_type=d["memory_type"],
        memory_source=d["memory_source"],
        is_update=bool(d["is_update"]),
        original_memories=tuple(d.get("original_memories") or ()),
        importance=float(d.get("importance", 0.0)),
        timestamp=d.get("timestamp"),
        update_type=d.get("update_type", ""),
    )


def _gold_state_from_dict(d: dict[str, Any]) -> SessionGoldState:
    return SessionGoldState(
        session_id=d["session_id"],
        cumulative_gold_memories=tuple(_gold_point_from_dict(g) for g in d["cumulative_gold_memories"]),
        session_new_memories=tuple(_gold_point_from_dict(g) for g in d["session_new_memories"]),
        session_update_memories=tuple(_gold_point_from_dict(g) for g in d["session_update_memories"]),
        session_interference_memories=tuple(_gold_point_from_dict(g) for g in d["session_interference_memories"]),
    )


def _snapshot_record_from_dict(d: dict[str, Any]) -> MemorySnapshotRecord:
    return MemorySnapshotRecord(
        memory_id=d["memory_id"],
        text=d["text"],
        session_id=d["session_id"],
        status=d["status"],
        source=d.get("source"),
        raw_backend_id=d.get("raw_backend_id"),
        raw_backend_type=d.get("raw_backend_type"),
        metadata=d.get("metadata") or {},
    )


def _delta_record_from_dict(d: dict[str, Any]) -> MemoryDeltaRecord:
    return MemoryDeltaRecord(
        session_id=d["session_id"],
        op=d["op"],
        text=d["text"],
        linked_previous=tuple(d.get("linked_previous") or ()),
        raw_backend_id=d.get("raw_backend_id"),
        metadata=d.get("metadata") or {},
    )


def _retrieval_item_from_dict(d: dict[str, Any]) -> RetrievalItem:
    return RetrievalItem(
        rank=int(d["rank"]),
        memory_id=d["memory_id"],
        text=d["text"],
        score=float(d["score"]),
        raw_backend_id=d.get("raw_backend_id"),
    )


def _retrieval_record_from_dict(d: dict[str, Any]) -> RetrievalRecord:
    return RetrievalRecord(
        query=d["query"],
        top_k=int(d["top_k"]),
        items=[_retrieval_item_from_dict(i) for i in d["items"]],
        raw_trace=d.get("raw_trace") or {},
    )


def _session_record_from_dict(d: dict[str, Any]) -> PipelineSessionRecord:
    return PipelineSessionRecord(
        sample_id=d["sample_id"],
        sample_uuid=d["sample_uuid"],
        session_id=d["session_id"],
        memory_snapshot=tuple(_snapshot_record_from_dict(s) for s in d["memory_snapshot"]),
        memory_delta=tuple(_delta_record_from_dict(dl) for dl in d["memory_delta"]),
        gold_state=_gold_state_from_dict(d["gold_state"]),
    )


def _qa_record_from_dict(d: dict[str, Any]) -> PipelineCheckpointQARecord:
    return PipelineCheckpointQARecord(
        sample_id=d["sample_id"],
        sample_uuid=d["sample_uuid"],
        checkpoint_id=d["checkpoint_id"],
        question=d["question"],
        gold_answer=d["gold_answer"],
        gold_evidence_memory_ids=tuple(d.get("gold_evidence_memory_ids") or ()),
        gold_evidence_contents=tuple(d.get("gold_evidence_contents") or ()),
        question_type=d["question_type"],
        question_type_abbrev=d["question_type_abbrev"],
        difficulty=d["difficulty"],
        retrieval=_retrieval_record_from_dict(d["retrieval"]),
        generated_answer=d["generated_answer"],
        cited_memories=tuple(d.get("cited_memories") or ()),
    )


def _read_jsonl(path: Path) -> list[dict[str, Any]]:
    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows


def _load_pipeline_checkpoint(
    output_dir: Path,
) -> tuple[list[PipelineSessionRecord], list[PipelineCheckpointQARecord]]:
    """Restore pipeline records from checkpoint JSONL files."""
    sess_path = output_dir / _CHECKPOINT_SESSIONS
    qa_path = output_dir / _CHECKPOINT_QA
    if not sess_path.exists() or not qa_path.exists():
        raise SystemExit(
            f"Checkpoint files not found in {output_dir}. "
            f"Run without --eval-only first to generate them."
        )
    session_records = [_session_record_from_dict(d) for d in _read_jsonl(sess_path)]
    qa_records = [_qa_record_from_dict(d) for d in _read_jsonl(qa_path)]
    return session_records, qa_records

def _default_create_adapter(baseline_name: str) -> MemoryAdapter:
    from eval_framework.memory_adapters import registry as reg

    if baseline_name in reg.MEMGALLERY_NATIVE_REGISTRY:
        return reg.MEMGALLERY_NATIVE_REGISTRY[baseline_name]()
    if baseline_name in reg.EXTERNAL_ADAPTER_REGISTRY:
        return reg.EXTERNAL_ADAPTER_REGISTRY[baseline_name]()
    # List the keys of the same two registries consulted above, so the error
    # message stays in sync with the actual lookup.
    known = sorted(set(reg.MEMGALLERY_NATIVE_REGISTRY) | set(reg.EXTERNAL_ADAPTER_REGISTRY))
    raise SystemExit(
        f"Unknown baseline {baseline_name!r}. "
        f"Expected one of: {', '.join(known)}"
    )


def _gold_echo_answer(
    q: NormalizedCheckpointQuestion, _retrieval: RetrievalRecord
) -> tuple[str, list[str]]:
    # Fallback when no OpenAI client is available: echo the gold answer with
    # no citations.
    return q.gold_answer, []

def _parse_answer_json(raw: str) -> tuple[str, list[str]]:
    """Extract answer and cited_memories from the model's JSON response."""

    def _from_obj(data: Any) -> tuple[str, list[str]] | None:
        # Guard against valid JSON that is not an object (e.g. a bare string),
        # which would otherwise crash on .get().
        if not isinstance(data, dict):
            return None
        answer = str(data.get("answer", ""))
        cited = data.get("cited_memories", [])
        if isinstance(cited, list):
            return answer, [str(c) for c in cited]
        return answer, []

    # First try the whole payload as JSON.
    try:
        parsed = _from_obj(json.loads(raw))
        if parsed is not None:
            return parsed
    except (json.JSONDecodeError, TypeError):
        pass

    # Then try the first {...} block embedded in surrounding prose.
    m = re.search(r"\{[\s\S]*\}", raw)
    if m:
        try:
            parsed = _from_obj(json.loads(m.group()))
            if parsed is not None:
                return parsed
        except (json.JSONDecodeError, TypeError):
            pass

    # Give up: treat the raw text as the answer, with no citations.
    return raw.strip(), []
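
# Sketch of the three parse paths (illustrative inputs, not from the dataset):
#   _parse_answer_json('{"answer": "Paris", "cited_memories": ["lives in Paris"]}')
#       -> ("Paris", ["lives in Paris"])
#   _parse_answer_json('Sure! {"answer": "Paris", "cited_memories": []}')
#       -> ("Paris", [])   # JSON object embedded in prose
#   _parse_answer_json("Paris")
#       -> ("Paris", [])   # raw fallback: whole text becomes the answer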

def build_default_answer_fn() -> Callable[
    [NormalizedCheckpointQuestion, RetrievalRecord], tuple[str, list[str]]
]:
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key or OpenAI is None:
        return _gold_echo_answer

    client = OpenAI(
        api_key=api_key,
        base_url=os.getenv("OPENAI_BASE_URL"),
    )
    model = os.getenv("OPENAI_MODEL") or "gpt-4o"
    temperature = float(os.getenv("OPENAI_TEMPERATURE", "0.0"))
    max_tokens = int(os.getenv("OPENAI_MAX_TOKENS", "1024"))

    def _answer(
        q: NormalizedCheckpointQuestion, retrieval: RetrievalRecord
    ) -> tuple[str, list[str]]:
        context_lines = [
            f"[{item.rank}] {item.text}" for item in retrieval.items[: retrieval.top_k]
        ]
        context = "\n".join(context_lines) if context_lines else "No retrieved memories."
        prompt = (
            "Answer the user's question using only the retrieved memories below. "
            "If the memories are insufficient, answer exactly: Not mentioned in memory.\n\n"
            "You MUST also list the specific memory passages you relied on to produce "
            "the answer. Copy the relevant text verbatim from the retrieved memories.\n\n"
            f"Question: {q.question}\n\n"
            f"Retrieved memories:\n{context}\n\n"
            "Respond in JSON:\n"
            "{\n"
            '  "answer": "your concise answer",\n'
            '  "cited_memories": ["verbatim passage 1", "verbatim passage 2"]\n'
            "}\n"
        )
        request_kwargs = rewrite_chat_completion_kwargs(
            {
                "model": model,
                "messages": [
                    {
                        "role": "system",
                        "content": (
                            "You answer benchmark questions using only supplied memory context. "
                            "Be concise and do not invent missing facts. "
                            "Always respond in the requested JSON format."
                        ),
                    },
                    {"role": "user", "content": prompt},
                ],
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
        )
        response = client.chat.completions.create(**request_kwargs)
        raw = response.choices[0].message.content or ""
        return _parse_answer_json(raw)

    return _answer
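
# Environment knobs read by build_default_answer_fn (only OPENAI_API_KEY is
# required to enable the live client): OPENAI_API_KEY, OPENAI_BASE_URL,
# OPENAI_MODEL (default "gpt-4o"), OPENAI_TEMPERATURE (default 0.0),
# OPENAI_MAX_TOKENS (default 1024). Without a key (or the openai package) the
# gold-echo fallback answers every question with its gold answer, so QA judge
# scores from such a run do not measure retrieval or generation quality.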

def config_from_namespace(ns: argparse.Namespace) -> EvalConfig:
    return EvalConfig(
        dataset_path=Path(ns.dataset).expanduser().resolve(),
        output_dir=Path(ns.output_dir).expanduser().resolve(),
        baseline=str(ns.baseline),
        smoke=bool(ns.smoke),
        dry_run=bool(ns.dry_run),
    )


def _record_to_json_obj(obj: Any) -> dict[str, Any]:
    return asdict(obj)


def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as fh:
        for row in rows:
            fh.write(json.dumps(row, ensure_ascii=False) + "\n")


def _write_json(path: Path, payload: dict[str, Any]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        json.dumps(payload, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
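
# Output layout produced by run_eval under --output-dir:
#   pipeline_sessions.jsonl / pipeline_qa.jsonl  -- raw pipeline checkpoint
#   session_records.jsonl / qa_records.jsonl     -- records with per-item "eval"
#   aggregate_metrics.json                       -- rollup from aggregate_metrics()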

def run_eval(
    config: EvalConfig,
    *,
    load_domain_bundle: Callable[[Path], DomainAV2AcademicBundle] = load_domain_a_v2_academic,
    create_adapter: Callable[[str], MemoryAdapter] | None = None,
    answer_fn: Callable[[NormalizedCheckpointQuestion, RetrievalRecord], tuple[str, list[str]]] | None = None,
    max_eval_workers: int = 5,
    eval_only: bool = False,
) -> None:
    """Load data, run the pipeline (serial), then the LLM eval (parallel)."""
    patch_openai_chat_completions()
    if config.dry_run:
        return

    out = config.output_dir
    out.mkdir(parents=True, exist_ok=True)

    if eval_only:
        print(f"[Eval-only] Loading pipeline checkpoint from {out}")
        session_records, qa_records = _load_pipeline_checkpoint(out)
        print(f"[Eval-only] Loaded {len(session_records)} sessions + {len(qa_records)} QA records")
    else:
        adapter_factory = create_adapter or _default_create_adapter
        bundle = load_domain_bundle(config.dataset_path)
        samples = bundle.samples[:1] if config.smoke else bundle.samples
        _answer = answer_fn if answer_fn is not None else build_default_answer_fn()

        session_records: list[PipelineSessionRecord] = []
        qa_records: list[PipelineCheckpointQARecord] = []

        print(f"[Pipeline] Running {len(samples)} sample(s) with baseline={config.baseline}")
        for i, sample in enumerate(samples):
            print(f"  Sample {i + 1}/{len(samples)}: {sample.sample_id}")
            adapter = adapter_factory(config.baseline)
            sess, qa = run_domain_a_v2_sample(
                adapter,
                sample,
                answer_fn=_answer,
            )
            session_records.extend(sess)
            qa_records.extend(qa)

        # Checkpoint before eval so an interrupted or re-run evaluation can
        # resume with --eval-only.
        _write_jsonl(out / _CHECKPOINT_SESSIONS,
                     [_record_to_json_obj(r) for r in session_records])
        _write_jsonl(out / _CHECKPOINT_QA,
                     [_record_to_json_obj(r) for r in qa_records])
        print(f"[Checkpoint] Saved {len(session_records)} sessions + {len(qa_records)} QA to {out}")

    print(f"[Eval] Evaluating {len(session_records)} sessions + {len(qa_records)} QA with LLM judge (workers={max_eval_workers})...")

    # Results are stored by index so output order matches record order even
    # though futures complete out of order.
    session_evals: list[dict[str, object] | None] = [None] * len(session_records)
    qa_evals: list[dict[str, object] | None] = [None] * len(qa_records)

    with ThreadPoolExecutor(max_workers=max_eval_workers) as pool:
        # Submit all session and QA evaluations up front, then drain each
        # group as its futures complete, mapping every future back to its index.
        session_futures = {
            pool.submit(evaluate_extraction, srec): idx
            for idx, srec in enumerate(session_records)
        }
        qa_futures = {
            pool.submit(evaluate_checkpoint_qa, qrec): idx
            for idx, qrec in enumerate(qa_records)
        }

        done_sessions = 0
        for fut in as_completed(session_futures):
            idx = session_futures[fut]
            try:
                session_evals[idx] = fut.result()
            except Exception as e:
                session_evals[idx] = {"error": str(e)}
            done_sessions += 1
            if done_sessions % 10 == 0 or done_sessions == len(session_records):
                print(f"  Sessions: {done_sessions}/{len(session_records)} done")

        done_qa = 0
        for fut in as_completed(qa_futures):
            idx = qa_futures[fut]
            try:
                qa_evals[idx] = fut.result()
            except Exception as e:
                qa_evals[idx] = {"error": str(e)}
            done_qa += 1
            if done_qa % 20 == 0 or done_qa == len(qa_records):
                print(f"  QA: {done_qa}/{len(qa_records)} done")

    agg = aggregate_metrics(
        config.baseline,
        session_evaluations=[e for e in session_evals if e is not None],
        qa_evaluations=[e for e in qa_evals if e is not None],
    )

    session_rows = []
    for srec, s_eval in zip(session_records, session_evals):
        row = _record_to_json_obj(srec)
        row["eval"] = s_eval
        session_rows.append(row)

    qa_rows = []
    for qrec, q_eval in zip(qa_records, qa_evals):
        row = _record_to_json_obj(qrec)
        row["eval"] = q_eval
        qa_rows.append(row)

    _write_jsonl(out / "session_records.jsonl", session_rows)
    _write_jsonl(out / "qa_records.jsonl", qa_rows)
    _write_json(out / "aggregate_metrics.json", agg)

    print(f"\n[Done] Results written to {out}")
    print(f"  Aggregate: {json.dumps(agg, indent=2)}")

def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(prog="eval_framework")
    p.add_argument("--dataset", required=True)
    p.add_argument("--baseline", required=True)
    p.add_argument("--output-dir", default="eval_framework/results")
    p.add_argument("--smoke", action="store_true")
    p.add_argument("--dry-run", action="store_true")
    p.add_argument("--eval-only", action="store_true",
                   help="Skip pipeline, load from checkpoint in output-dir.")
    p.add_argument("--max-eval-workers", type=int, default=5,
                   help="Parallel threads for eval stage (default 5).")
    return p
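
# Defaults sanity check (hypothetical argv, relying only on the flags above):
#   ns = build_parser().parse_args(["--dataset", "data", "--baseline", "x"])
#   assert ns.output_dir == "eval_framework/results"
#   assert ns.max_eval_workers == 5 and not ns.eval_only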

def main(argv: list[str] | None = None) -> None:
    parser = build_parser()
    args = parser.parse_args(argv)
    cfg = config_from_namespace(args)

    if cfg.dry_run:
        print(json.dumps(cfg.to_display_dict(), indent=2))
        return

    eval_only = bool(args.eval_only)

    if not eval_only and not cfg.dataset_path.is_dir():
        raise SystemExit(f"Dataset path is not a directory: {cfg.dataset_path}")

    run_eval(cfg, max_eval_workers=args.max_eval_workers, eval_only=eval_only)


if __name__ == "__main__":
    main()