# eval_framework/cli.py
"""CLI: dataset path, baseline, output dir, dry-run, smoke eval.
Evaluation uses batch LLM judge: 2 calls/session + 2 calls/QA.
Session and QA evaluations run in parallel via ThreadPoolExecutor.
Pipeline results are checkpointed before eval so --eval-only can resume.
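
Example invocations (module path, dataset location, and baseline name are
illustrative; run without --eval-only first so the checkpoint files exist):

    python -m eval_framework.cli --dataset /path/to/domain_a_v2 --baseline <baseline> --smoke
    python -m eval_framework.cli --dataset /path/to/domain_a_v2 --baseline <baseline> --eval-only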
"""
from __future__ import annotations
import argparse
import json
import os
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import asdict
from pathlib import Path
from typing import Any
try:
from openai import OpenAI
except ImportError:
OpenAI = None # type: ignore[assignment]
from eval_framework.config import EvalConfig
from eval_framework.datasets.domain_a_v2 import (
DomainAV2AcademicBundle,
NormalizedCheckpointQuestion,
load_domain_a_v2_academic,
)
from eval_framework.datasets.schemas import (
MemoryDeltaRecord,
MemorySnapshotRecord,
RetrievalItem,
RetrievalRecord,
)
from eval_framework.evaluators.aggregate import aggregate_metrics
from eval_framework.evaluators.extraction import evaluate_extraction
from eval_framework.evaluators.qa import evaluate_checkpoint_qa
from eval_framework.memory_adapters.base import MemoryAdapter
from eval_framework.openai_compat import (
patch_openai_chat_completions,
rewrite_chat_completion_kwargs,
)
from eval_framework.pipeline.gold_state import GoldMemoryPoint, SessionGoldState
from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord
from eval_framework.pipeline.runner import run_domain_a_v2_sample
_CHECKPOINT_SESSIONS = "pipeline_sessions.jsonl"
_CHECKPOINT_QA = "pipeline_qa.jsonl"
# ---------------------------------------------------------------------------
# Checkpoint deserialization: dict -> frozen dataclass
# ---------------------------------------------------------------------------
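# Illustrative shape of one serialized gold-memory entry as re-hydrated below
# (values are made up; keys follow _gold_point_from_dict):
#
#   {"memory_id": "m-001", "memory_content": "...", "memory_type": "fact",
#    "memory_source": "conversation", "is_update": false,
#    "original_memories": [], "importance": 0.8, "timestamp": null,
#    "update_type": ""}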
def _gold_point_from_dict(d: dict[str, Any]) -> GoldMemoryPoint:
return GoldMemoryPoint(
memory_id=d["memory_id"],
memory_content=d["memory_content"],
memory_type=d["memory_type"],
memory_source=d["memory_source"],
is_update=bool(d["is_update"]),
original_memories=tuple(d.get("original_memories") or ()),
importance=float(d.get("importance", 0.0)),
timestamp=d.get("timestamp"),
update_type=d.get("update_type", ""),
)
def _gold_state_from_dict(d: dict[str, Any]) -> SessionGoldState:
return SessionGoldState(
session_id=d["session_id"],
cumulative_gold_memories=tuple(_gold_point_from_dict(g) for g in d["cumulative_gold_memories"]),
session_new_memories=tuple(_gold_point_from_dict(g) for g in d["session_new_memories"]),
session_update_memories=tuple(_gold_point_from_dict(g) for g in d["session_update_memories"]),
session_interference_memories=tuple(_gold_point_from_dict(g) for g in d["session_interference_memories"]),
)
def _snapshot_record_from_dict(d: dict[str, Any]) -> MemorySnapshotRecord:
return MemorySnapshotRecord(
memory_id=d["memory_id"],
text=d["text"],
session_id=d["session_id"],
status=d["status"],
source=d.get("source"),
raw_backend_id=d.get("raw_backend_id"),
raw_backend_type=d.get("raw_backend_type"),
metadata=d.get("metadata") or {},
)
def _delta_record_from_dict(d: dict[str, Any]) -> MemoryDeltaRecord:
return MemoryDeltaRecord(
session_id=d["session_id"],
op=d["op"],
text=d["text"],
linked_previous=tuple(d.get("linked_previous") or ()),
raw_backend_id=d.get("raw_backend_id"),
metadata=d.get("metadata") or {},
)
def _retrieval_item_from_dict(d: dict[str, Any]) -> RetrievalItem:
return RetrievalItem(
rank=int(d["rank"]),
memory_id=d["memory_id"],
text=d["text"],
score=float(d["score"]),
raw_backend_id=d.get("raw_backend_id"),
)
def _retrieval_record_from_dict(d: dict[str, Any]) -> RetrievalRecord:
return RetrievalRecord(
query=d["query"],
top_k=int(d["top_k"]),
items=[_retrieval_item_from_dict(i) for i in d["items"]],
raw_trace=d.get("raw_trace") or {},
)
def _session_record_from_dict(d: dict[str, Any]) -> PipelineSessionRecord:
return PipelineSessionRecord(
sample_id=d["sample_id"],
sample_uuid=d["sample_uuid"],
session_id=d["session_id"],
memory_snapshot=tuple(_snapshot_record_from_dict(s) for s in d["memory_snapshot"]),
memory_delta=tuple(_delta_record_from_dict(dl) for dl in d["memory_delta"]),
gold_state=_gold_state_from_dict(d["gold_state"]),
)
def _qa_record_from_dict(d: dict[str, Any]) -> PipelineCheckpointQARecord:
return PipelineCheckpointQARecord(
sample_id=d["sample_id"],
sample_uuid=d["sample_uuid"],
checkpoint_id=d["checkpoint_id"],
question=d["question"],
gold_answer=d["gold_answer"],
gold_evidence_memory_ids=tuple(d.get("gold_evidence_memory_ids") or ()),
gold_evidence_contents=tuple(d.get("gold_evidence_contents") or ()),
question_type=d["question_type"],
question_type_abbrev=d["question_type_abbrev"],
difficulty=d["difficulty"],
retrieval=_retrieval_record_from_dict(d["retrieval"]),
generated_answer=d["generated_answer"],
cited_memories=tuple(d.get("cited_memories") or ()),
)
def _read_jsonl(path: Path) -> list[dict[str, Any]]:
rows: list[dict[str, Any]] = []
with path.open("r", encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if line:
rows.append(json.loads(line))
return rows
def _load_pipeline_checkpoint(
output_dir: Path,
) -> tuple[list[PipelineSessionRecord], list[PipelineCheckpointQARecord]]:
"""Restore pipeline records from checkpoint JSONL files."""
sess_path = output_dir / _CHECKPOINT_SESSIONS
qa_path = output_dir / _CHECKPOINT_QA
if not sess_path.exists() or not qa_path.exists():
raise SystemExit(
f"Checkpoint files not found in {output_dir}. "
f"Run without --eval-only first to generate them."
)
session_records = [_session_record_from_dict(d) for d in _read_jsonl(sess_path)]
qa_records = [_qa_record_from_dict(d) for d in _read_jsonl(qa_path)]
return session_records, qa_records
def _default_create_adapter(baseline_name: str) -> MemoryAdapter:
from eval_framework.memory_adapters import registry as reg
if baseline_name in reg.MEMGALLERY_NATIVE_REGISTRY:
return reg.MEMGALLERY_NATIVE_REGISTRY[baseline_name]()
if baseline_name in reg.EXTERNAL_ADAPTER_REGISTRY:
return reg.EXTERNAL_ADAPTER_REGISTRY[baseline_name]()
known = sorted(
reg.MEMGALLERY_NATIVE_BASELINES | reg.EXTERNAL_ADAPTER_KEYS
)
raise SystemExit(
f"Unknown baseline {baseline_name!r}. "
f"Expected one of: {', '.join(known)}"
)
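# The registry lookup above is only the default; a custom factory can be
# injected programmatically, e.g. (illustrative; MyAdapter is a hypothetical
# MemoryAdapter subclass, not part of this module):
#
#     run_eval(config, create_adapter=lambda baseline_name: MyAdapter())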
def _gold_echo_answer(
q: NormalizedCheckpointQuestion, _retrieval: RetrievalRecord
) -> tuple[str, list[str]]:
return q.gold_answer, []
def _parse_answer_json(raw: str) -> tuple[str, list[str]]:
"""Extract answer and cited_memories from the model's JSON response."""
# Try to parse as JSON first
try:
data = json.loads(raw)
answer = str(data.get("answer", ""))
cited = data.get("cited_memories", [])
if isinstance(cited, list):
return answer, [str(c) for c in cited]
return answer, []
except (json.JSONDecodeError, TypeError):
pass
# Fallback: try to find JSON block in the response
import re
m = re.search(r"\{[\s\S]*\}", raw)
if m:
try:
data = json.loads(m.group())
answer = str(data.get("answer", ""))
cited = data.get("cited_memories", [])
if isinstance(cited, list):
return answer, [str(c) for c in cited]
except (json.JSONDecodeError, TypeError):
pass
# Final fallback: treat entire response as the answer, no citations
return raw.strip(), []
def build_default_answer_fn() -> Callable[
[NormalizedCheckpointQuestion, RetrievalRecord], tuple[str, list[str]]
]:
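    """Return an OpenAI-backed answer_fn, or fall back to echoing gold answers
    when OPENAI_API_KEY is unset or the openai package is not installed.

    Environment variables consulted: OPENAI_API_KEY, OPENAI_BASE_URL,
    OPENAI_MODEL (defaults to gpt-4o), OPENAI_TEMPERATURE, OPENAI_MAX_TOKENS.
    """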
api_key = os.getenv("OPENAI_API_KEY")
if not api_key or OpenAI is None:
return _gold_echo_answer
client = OpenAI(
api_key=api_key,
base_url=os.getenv("OPENAI_BASE_URL"),
)
model = os.getenv("OPENAI_MODEL") or "gpt-4o"
temperature = float(os.getenv("OPENAI_TEMPERATURE", "0.0"))
max_tokens = int(os.getenv("OPENAI_MAX_TOKENS", "1024"))
def _answer(
q: NormalizedCheckpointQuestion, retrieval: RetrievalRecord
) -> tuple[str, list[str]]:
context_lines = [
f"[{item.rank}] {item.text}" for item in retrieval.items[: retrieval.top_k]
]
context = "\n".join(context_lines) if context_lines else "No retrieved memories."
prompt = (
"Answer the user's question using only the retrieved memories below. "
"If the memories are insufficient, answer exactly: Not mentioned in memory.\n\n"
"You MUST also list the specific memory passages you relied on to produce "
"the answer. Copy the relevant text verbatim from the retrieved memories.\n\n"
f"Question: {q.question}\n\n"
f"Retrieved memories:\n{context}\n\n"
'Respond in JSON:\n'
'{\n'
' "answer": "your concise answer",\n'
' "cited_memories": ["verbatim passage 1", "verbatim passage 2"]\n'
'}\n'
)
request_kwargs = rewrite_chat_completion_kwargs(
{
"model": model,
"messages": [
{
"role": "system",
"content": (
"You answer benchmark questions using only supplied memory context. "
"Be concise and do not invent missing facts. "
"Always respond in the requested JSON format."
),
},
{"role": "user", "content": prompt},
],
"temperature": temperature,
"max_tokens": max_tokens,
}
)
response = client.chat.completions.create(**request_kwargs)
raw = response.choices[0].message.content or ""
return _parse_answer_json(raw)
return _answer
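# run_eval() uses the answerer returned above unless an answer_fn override is
# supplied; e.g. a retrieval-only sanity run can echo gold answers
# (illustrative):
#
#     run_eval(config, answer_fn=_gold_echo_answer)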
def config_from_namespace(ns: argparse.Namespace) -> EvalConfig:
return EvalConfig(
dataset_path=Path(ns.dataset).expanduser().resolve(),
output_dir=Path(ns.output_dir).expanduser().resolve(),
baseline=str(ns.baseline),
smoke=bool(ns.smoke),
dry_run=bool(ns.dry_run),
)
def _record_to_json_obj(obj: Any) -> dict[str, Any]:
return asdict(obj)
def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as fh:
for row in rows:
fh.write(json.dumps(row, ensure_ascii=False) + "\n")
def _write_json(path: Path, payload: dict[str, Any]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(
json.dumps(payload, ensure_ascii=False, indent=2) + "\n",
encoding="utf-8",
)
def run_eval(
config: EvalConfig,
*,
load_domain_bundle: Callable[[Path], DomainAV2AcademicBundle] = load_domain_a_v2_academic,
create_adapter: Callable[[str], MemoryAdapter] | None = None,
answer_fn: Callable | None = None,
max_eval_workers: int = 5,
eval_only: bool = False,
) -> None:
"""Load data, run pipeline (serial) + LLM eval (parallel)."""
patch_openai_chat_completions()
if config.dry_run:
return
out = config.output_dir
out.mkdir(parents=True, exist_ok=True)
if eval_only:
# --- Resume from checkpoint ---
print(f"[Eval-only] Loading pipeline checkpoint from {out}")
session_records, qa_records = _load_pipeline_checkpoint(out)
print(f"[Eval-only] Loaded {len(session_records)} sessions + {len(qa_records)} QA records")
else:
# --- Stage 1: Pipeline (serial — adapter is stateful) ---
adapter_factory = create_adapter or _default_create_adapter
bundle = load_domain_bundle(config.dataset_path)
samples = bundle.samples[:1] if config.smoke else bundle.samples
_answer = answer_fn if answer_fn is not None else build_default_answer_fn()
session_records: list[PipelineSessionRecord] = []
qa_records: list[PipelineCheckpointQARecord] = []
print(f"[Pipeline] Running {len(samples)} sample(s) with baseline={config.baseline}")
for i, sample in enumerate(samples):
print(f" Sample {i + 1}/{len(samples)}: {sample.sample_id}")
adapter = adapter_factory(config.baseline)
sess, qa = run_domain_a_v2_sample(
adapter,
sample,
answer_fn=_answer,
)
session_records.extend(sess)
qa_records.extend(qa)
# --- Save checkpoint ---
_write_jsonl(out / _CHECKPOINT_SESSIONS,
[_record_to_json_obj(r) for r in session_records])
_write_jsonl(out / _CHECKPOINT_QA,
[_record_to_json_obj(r) for r in qa_records])
print(f"[Checkpoint] Saved {len(session_records)} sessions + {len(qa_records)} QA to {out}")
# --- Stage 2: Eval (parallel — each record is self-contained) ---
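    # Per the module docstring, this stage issues roughly
    # 2 * len(session_records) + 2 * len(qa_records) judge calls in total.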
print(f"[Eval] Evaluating {len(session_records)} sessions + {len(qa_records)} QA with LLM judge (workers={max_eval_workers})...")
session_evals: list[dict[str, object] | None] = [None] * len(session_records)
qa_evals: list[dict[str, object] | None] = [None] * len(qa_records)
with ThreadPoolExecutor(max_workers=max_eval_workers) as pool:
# Submit session evals
session_futures = {}
for idx, srec in enumerate(session_records):
fut = pool.submit(evaluate_extraction, srec)
session_futures[fut] = idx
# Submit QA evals
qa_futures = {}
for idx, qrec in enumerate(qa_records):
fut = pool.submit(evaluate_checkpoint_qa, qrec)
qa_futures[fut] = idx
# Collect session results
done_sessions = 0
for fut in as_completed(session_futures):
idx = session_futures[fut]
try:
session_evals[idx] = fut.result()
except Exception as e:
session_evals[idx] = {"error": str(e)}
done_sessions += 1
if done_sessions % 10 == 0 or done_sessions == len(session_records):
print(f" Sessions: {done_sessions}/{len(session_records)} done")
# Collect QA results
done_qa = 0
for fut in as_completed(qa_futures):
idx = qa_futures[fut]
try:
qa_evals[idx] = fut.result()
except Exception as e:
qa_evals[idx] = {"error": str(e)}
done_qa += 1
if done_qa % 20 == 0 or done_qa == len(qa_records):
print(f" QA: {done_qa}/{len(qa_records)} done")
# --- Stage 3: Aggregate + write ---
agg = aggregate_metrics(
config.baseline,
session_evaluations=[e for e in session_evals if e is not None],
qa_evaluations=[e for e in qa_evals if e is not None],
)
session_rows = []
for srec, s_eval in zip(session_records, session_evals):
row = _record_to_json_obj(srec)
row["eval"] = s_eval
session_rows.append(row)
qa_rows = []
for qrec, q_eval in zip(qa_records, qa_evals):
row = _record_to_json_obj(qrec)
row["eval"] = q_eval
qa_rows.append(row)
_write_jsonl(out / "session_records.jsonl", session_rows)
_write_jsonl(out / "qa_records.jsonl", qa_rows)
_write_json(out / "aggregate_metrics.json", agg)
print(f"\n[Done] Results written to {out}")
print(f" Aggregate: {json.dumps(agg, indent=2)}")
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(prog="eval_framework")
p.add_argument("--dataset", required=True)
p.add_argument("--baseline", required=True)
p.add_argument("--output-dir", default="eval_framework/results")
p.add_argument("--smoke", action="store_true")
p.add_argument("--dry-run", action="store_true")
p.add_argument("--eval-only", action="store_true",
help="Skip pipeline, load from checkpoint in output-dir.")
p.add_argument("--max-eval-workers", type=int, default=5,
help="Parallel threads for eval stage (default 5).")
return p
def main(argv: list[str] | None = None) -> None:
parser = build_parser()
args = parser.parse_args(argv)
cfg = config_from_namespace(args)
if cfg.dry_run:
print(json.dumps(cfg.to_display_dict(), indent=2))
return
eval_only = bool(args.eval_only)
if not eval_only and not cfg.dataset_path.is_dir():
raise SystemExit(f"Dataset path is not a directory: {cfg.dataset_path}")
run_eval(cfg, max_eval_workers=args.max_eval_workers, eval_only=eval_only)
if __name__ == "__main__":
main()