Upload eval_framework source code
Browse files- .gitignore +5 -0
- __init__.py +5 -0
- cli.py +471 -0
- config.py +30 -0
- datasets/__init__.py +1 -0
- datasets/convert_vistrajqa.py +452 -0
- datasets/domain_a_v2.py +286 -0
- datasets/schemas.py +106 -0
- docs/DATA_CONVERSION.md +142 -0
- docs/EXPERIMENTS.md +87 -0
- docs/GUIDE.md +423 -0
- docs/OUTPUT_FORMAT.md +261 -0
- evaluators/__init__.py +11 -0
- evaluators/aggregate.py +175 -0
- evaluators/extraction.py +193 -0
- evaluators/qa.py +70 -0
- judges/__init__.py +215 -0
- judges/llm_client.py +156 -0
- judges/prompts.py +223 -0
- memory_adapters/__init__.py +27 -0
- memory_adapters/amem.py +258 -0
- memory_adapters/amem_v2.py +142 -0
- memory_adapters/base.py +45 -0
- memory_adapters/dummy.py +118 -0
- memory_adapters/export_utils.py +123 -0
- memory_adapters/mem0_adapter.py +185 -0
- memory_adapters/memgallery_native.py +395 -0
- memory_adapters/memoryos.py +357 -0
- memory_adapters/memverse_adapter.py +203 -0
- memory_adapters/registry.py +410 -0
- memory_adapters/simplemem_adapter.py +156 -0
- memory_adapters/zep_adapter.py +122 -0
- openai_compat.py +49 -0
- pipeline/__init__.py +1 -0
- pipeline/gold_state.py +130 -0
- pipeline/qa_runner.py +59 -0
- pipeline/records.py +60 -0
- pipeline/runner.py +104 -0
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
results/
|
| 2 |
+
converted/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
| 5 |
+
*.jsonl
|
__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unified memory evaluation framework (package scaffold)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
__all__: list[str] = []
|
cli.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CLI: dataset path, baseline, output dir, dry-run, smoke eval.
|
| 2 |
+
|
| 3 |
+
Evaluation uses batch LLM judge: 2 calls/session + 2 calls/QA.
|
| 4 |
+
Session and QA evaluations run in parallel via ThreadPoolExecutor.
|
| 5 |
+
Pipeline results are checkpointed before eval so --eval-only can resume.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
from collections.abc import Callable
|
| 14 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 15 |
+
from dataclasses import asdict
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from openai import OpenAI
|
| 21 |
+
except ImportError:
|
| 22 |
+
OpenAI = None # type: ignore[assignment]
|
| 23 |
+
|
| 24 |
+
from eval_framework.config import EvalConfig
|
| 25 |
+
from eval_framework.datasets.domain_a_v2 import (
|
| 26 |
+
DomainAV2AcademicBundle,
|
| 27 |
+
NormalizedCheckpointQuestion,
|
| 28 |
+
load_domain_a_v2_academic,
|
| 29 |
+
)
|
| 30 |
+
from eval_framework.datasets.schemas import (
|
| 31 |
+
MemoryDeltaRecord,
|
| 32 |
+
MemorySnapshotRecord,
|
| 33 |
+
RetrievalItem,
|
| 34 |
+
RetrievalRecord,
|
| 35 |
+
)
|
| 36 |
+
from eval_framework.evaluators.aggregate import aggregate_metrics
|
| 37 |
+
from eval_framework.evaluators.extraction import evaluate_extraction
|
| 38 |
+
from eval_framework.evaluators.qa import evaluate_checkpoint_qa
|
| 39 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 40 |
+
from eval_framework.openai_compat import (
|
| 41 |
+
patch_openai_chat_completions,
|
| 42 |
+
rewrite_chat_completion_kwargs,
|
| 43 |
+
)
|
| 44 |
+
from eval_framework.pipeline.gold_state import GoldMemoryPoint, SessionGoldState
|
| 45 |
+
from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord
|
| 46 |
+
from eval_framework.pipeline.runner import run_domain_a_v2_sample
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
_CHECKPOINT_SESSIONS = "pipeline_sessions.jsonl"
|
| 50 |
+
_CHECKPOINT_QA = "pipeline_qa.jsonl"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# Checkpoint deserialization: dict -> frozen dataclass
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
+
def _gold_point_from_dict(d: dict[str, Any]) -> GoldMemoryPoint:
    """Rehydrate one gold memory point from its checkpoint-JSON dict."""
    originals = tuple(d.get("original_memories") or ())
    weight = float(d.get("importance", 0.0))
    return GoldMemoryPoint(
        memory_id=d["memory_id"],
        memory_content=d["memory_content"],
        memory_type=d["memory_type"],
        memory_source=d["memory_source"],
        is_update=bool(d["is_update"]),
        original_memories=originals,
        importance=weight,
        timestamp=d.get("timestamp"),
        update_type=d.get("update_type", ""),
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _gold_state_from_dict(d: dict[str, Any]) -> SessionGoldState:
    """Rehydrate a SessionGoldState, converting each nested gold-point dict."""

    def _points(key: str) -> tuple[GoldMemoryPoint, ...]:
        # All four fields share the same list-of-gold-point-dicts shape.
        return tuple(_gold_point_from_dict(g) for g in d[key])

    return SessionGoldState(
        session_id=d["session_id"],
        cumulative_gold_memories=_points("cumulative_gold_memories"),
        session_new_memories=_points("session_new_memories"),
        session_update_memories=_points("session_update_memories"),
        session_interference_memories=_points("session_interference_memories"),
    )
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def _snapshot_record_from_dict(d: dict[str, Any]) -> MemorySnapshotRecord:
    """Rehydrate a MemorySnapshotRecord; optional fields default to None/{}."""
    optional: dict[str, Any] = {
        "source": d.get("source"),
        "raw_backend_id": d.get("raw_backend_id"),
        "raw_backend_type": d.get("raw_backend_type"),
        "metadata": d.get("metadata") or {},
    }
    return MemorySnapshotRecord(
        memory_id=d["memory_id"],
        text=d["text"],
        session_id=d["session_id"],
        status=d["status"],
        **optional,
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _delta_record_from_dict(d: dict[str, Any]) -> MemoryDeltaRecord:
    """Rehydrate a MemoryDeltaRecord from its checkpoint dict."""
    linked = tuple(d.get("linked_previous") or ())
    return MemoryDeltaRecord(
        session_id=d["session_id"],
        op=d["op"],
        text=d["text"],
        linked_previous=linked,
        raw_backend_id=d.get("raw_backend_id"),
        metadata=d.get("metadata") or {},
    )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _retrieval_item_from_dict(d: dict[str, Any]) -> RetrievalItem:
    """Rehydrate one ranked retrieval hit from its checkpoint dict."""
    rank = int(d["rank"])
    score = float(d["score"])
    return RetrievalItem(
        rank=rank,
        memory_id=d["memory_id"],
        text=d["text"],
        score=score,
        raw_backend_id=d.get("raw_backend_id"),
    )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def _retrieval_record_from_dict(d: dict[str, Any]) -> RetrievalRecord:
    """Rehydrate a RetrievalRecord, converting each nested item dict."""
    hits = [_retrieval_item_from_dict(item) for item in d["items"]]
    return RetrievalRecord(
        query=d["query"],
        top_k=int(d["top_k"]),
        items=hits,
        raw_trace=d.get("raw_trace") or {},
    )
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _session_record_from_dict(d: dict[str, Any]) -> PipelineSessionRecord:
    """Rehydrate a PipelineSessionRecord including nested snapshot/delta/gold state."""
    snapshots = tuple(_snapshot_record_from_dict(s) for s in d["memory_snapshot"])
    deltas = tuple(_delta_record_from_dict(x) for x in d["memory_delta"])
    return PipelineSessionRecord(
        sample_id=d["sample_id"],
        sample_uuid=d["sample_uuid"],
        session_id=d["session_id"],
        memory_snapshot=snapshots,
        memory_delta=deltas,
        gold_state=_gold_state_from_dict(d["gold_state"]),
    )
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _qa_record_from_dict(d: dict[str, Any]) -> PipelineCheckpointQARecord:
    """Rehydrate a PipelineCheckpointQARecord from its checkpoint dict."""
    # These three optional list fields all coerce to tuples the same way.
    tuples = {
        key: tuple(d.get(key) or ())
        for key in ("gold_evidence_memory_ids", "gold_evidence_contents", "cited_memories")
    }
    return PipelineCheckpointQARecord(
        sample_id=d["sample_id"],
        sample_uuid=d["sample_uuid"],
        checkpoint_id=d["checkpoint_id"],
        question=d["question"],
        gold_answer=d["gold_answer"],
        gold_evidence_memory_ids=tuples["gold_evidence_memory_ids"],
        gold_evidence_contents=tuples["gold_evidence_contents"],
        question_type=d["question_type"],
        question_type_abbrev=d["question_type_abbrev"],
        difficulty=d["difficulty"],
        retrieval=_retrieval_record_from_dict(d["retrieval"]),
        generated_answer=d["generated_answer"],
        cited_memories=tuples["cited_memories"],
    )
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _read_jsonl(path: Path) -> list[dict[str, Any]]:
|
| 154 |
+
rows: list[dict[str, Any]] = []
|
| 155 |
+
with path.open("r", encoding="utf-8") as fh:
|
| 156 |
+
for line in fh:
|
| 157 |
+
line = line.strip()
|
| 158 |
+
if line:
|
| 159 |
+
rows.append(json.loads(line))
|
| 160 |
+
return rows
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _load_pipeline_checkpoint(
    output_dir: Path,
) -> tuple[list[PipelineSessionRecord], list[PipelineCheckpointQARecord]]:
    """Restore pipeline records from checkpoint JSONL files.

    Raises SystemExit when either checkpoint file is missing, since
    --eval-only cannot proceed without a prior pipeline run.
    """
    sessions_file = output_dir / _CHECKPOINT_SESSIONS
    qa_file = output_dir / _CHECKPOINT_QA
    if not (sessions_file.exists() and qa_file.exists()):
        raise SystemExit(
            f"Checkpoint files not found in {output_dir}. "
            f"Run without --eval-only first to generate them."
        )
    sessions = [_session_record_from_dict(row) for row in _read_jsonl(sessions_file)]
    qa = [_qa_record_from_dict(row) for row in _read_jsonl(qa_file)]
    return sessions, qa
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def _default_create_adapter(baseline_name: str) -> MemoryAdapter:
    """Instantiate the registered adapter for *baseline_name*.

    Native (memgallery) baselines take precedence over external adapters;
    an unknown name aborts with the list of valid baselines.
    """
    # Imported lazily so merely importing the CLI does not pull in every adapter.
    from eval_framework.memory_adapters import registry as reg

    for registry in (reg.MEMGALLERY_NATIVE_REGISTRY, reg.EXTERNAL_ADAPTER_REGISTRY):
        factory = registry.get(baseline_name)
        if factory is not None:
            return factory()

    known = sorted(reg.MEMGALLERY_NATIVE_BASELINES | reg.EXTERNAL_ADAPTER_KEYS)
    raise SystemExit(
        f"Unknown baseline {baseline_name!r}. "
        f"Expected one of: {', '.join(known)}"
    )
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _gold_echo_answer(
|
| 196 |
+
q: NormalizedCheckpointQuestion, _retrieval: RetrievalRecord
|
| 197 |
+
) -> tuple[str, list[str]]:
|
| 198 |
+
return q.gold_answer, []
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _parse_answer_json(raw: str) -> tuple[str, list[str]]:
|
| 202 |
+
"""Extract answer and cited_memories from the model's JSON response."""
|
| 203 |
+
# Try to parse as JSON first
|
| 204 |
+
try:
|
| 205 |
+
data = json.loads(raw)
|
| 206 |
+
answer = str(data.get("answer", ""))
|
| 207 |
+
cited = data.get("cited_memories", [])
|
| 208 |
+
if isinstance(cited, list):
|
| 209 |
+
return answer, [str(c) for c in cited]
|
| 210 |
+
return answer, []
|
| 211 |
+
except (json.JSONDecodeError, TypeError):
|
| 212 |
+
pass
|
| 213 |
+
# Fallback: try to find JSON block in the response
|
| 214 |
+
import re
|
| 215 |
+
m = re.search(r"\{[\s\S]*\}", raw)
|
| 216 |
+
if m:
|
| 217 |
+
try:
|
| 218 |
+
data = json.loads(m.group())
|
| 219 |
+
answer = str(data.get("answer", ""))
|
| 220 |
+
cited = data.get("cited_memories", [])
|
| 221 |
+
if isinstance(cited, list):
|
| 222 |
+
return answer, [str(c) for c in cited]
|
| 223 |
+
except (json.JSONDecodeError, TypeError):
|
| 224 |
+
pass
|
| 225 |
+
# Final fallback: treat entire response as the answer, no citations
|
| 226 |
+
return raw.strip(), []
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def build_default_answer_fn() -> Callable[
    [NormalizedCheckpointQuestion, RetrievalRecord], tuple[str, list[str]]
]:
    """Build the QA answer function used by the pipeline.

    Returns an LLM-backed answerer when OPENAI_API_KEY is set and the
    ``openai`` package imported successfully; otherwise returns
    ``_gold_echo_answer`` (echoes the gold answer, no citations).
    Model, temperature and token budget come from the OPENAI_MODEL,
    OPENAI_TEMPERATURE and OPENAI_MAX_TOKENS environment variables.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key or OpenAI is None:
        # No credentials or no client library: degrade to the gold-echo stub.
        return _gold_echo_answer

    # Client and generation settings are resolved once, then captured by
    # the closure below so every QA call reuses them.
    client = OpenAI(
        api_key=api_key,
        base_url=os.getenv("OPENAI_BASE_URL"),
    )
    model = os.getenv("OPENAI_MODEL") or "gpt-4o"
    temperature = float(os.getenv("OPENAI_TEMPERATURE", "0.0"))
    max_tokens = int(os.getenv("OPENAI_MAX_TOKENS", "1024"))

    def _answer(
        q: NormalizedCheckpointQuestion, retrieval: RetrievalRecord
    ) -> tuple[str, list[str]]:
        # Only the top_k retrieved memories are shown to the model.
        context_lines = [
            f"[{item.rank}] {item.text}" for item in retrieval.items[: retrieval.top_k]
        ]
        context = "\n".join(context_lines) if context_lines else "No retrieved memories."
        prompt = (
            "Answer the user's question using only the retrieved memories below. "
            "If the memories are insufficient, answer exactly: Not mentioned in memory.\n\n"
            "You MUST also list the specific memory passages you relied on to produce "
            "the answer. Copy the relevant text verbatim from the retrieved memories.\n\n"
            f"Question: {q.question}\n\n"
            f"Retrieved memories:\n{context}\n\n"
            'Respond in JSON:\n'
            '{\n'
            '  "answer": "your concise answer",\n'
            '  "cited_memories": ["verbatim passage 1", "verbatim passage 2"]\n'
            '}\n'
        )
        # rewrite_chat_completion_kwargs adapts the request for
        # OpenAI-compatible backends (see eval_framework.openai_compat).
        request_kwargs = rewrite_chat_completion_kwargs(
            {
                "model": model,
                "messages": [
                    {
                        "role": "system",
                        "content": (
                            "You answer benchmark questions using only supplied memory context. "
                            "Be concise and do not invent missing facts. "
                            "Always respond in the requested JSON format."
                        ),
                    },
                    {"role": "user", "content": prompt},
                ],
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
        )
        response = client.chat.completions.create(**request_kwargs)
        # content can be None for some backends; normalize to "".
        raw = response.choices[0].message.content or ""
        return _parse_answer_json(raw)

    return _answer
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def config_from_namespace(ns: argparse.Namespace) -> EvalConfig:
    """Build an EvalConfig from parsed CLI args, resolving both paths."""
    dataset = Path(ns.dataset).expanduser().resolve()
    out_dir = Path(ns.output_dir).expanduser().resolve()
    return EvalConfig(
        dataset_path=dataset,
        output_dir=out_dir,
        baseline=str(ns.baseline),
        smoke=bool(ns.smoke),
        dry_run=bool(ns.dry_run),
    )
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def _record_to_json_obj(obj: Any) -> dict[str, Any]:
|
| 300 |
+
return asdict(obj)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
|
| 304 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 305 |
+
with path.open("w", encoding="utf-8") as fh:
|
| 306 |
+
for row in rows:
|
| 307 |
+
fh.write(json.dumps(row, ensure_ascii=False) + "\n")
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _write_json(path: Path, payload: dict[str, Any]) -> None:
|
| 311 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 312 |
+
path.write_text(
|
| 313 |
+
json.dumps(payload, ensure_ascii=False, indent=2) + "\n",
|
| 314 |
+
encoding="utf-8",
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def run_eval(
    config: EvalConfig,
    *,
    load_domain_bundle: Callable[[Path], DomainAV2AcademicBundle] = load_domain_a_v2_academic,
    create_adapter: Callable[[str], MemoryAdapter] | None = None,
    answer_fn: Callable | None = None,
    max_eval_workers: int = 5,
    eval_only: bool = False,
) -> None:
    """Load data, run pipeline (serial) + LLM eval (parallel).

    Args:
        config: Resolved paths/flags for this invocation.
        load_domain_bundle: Dataset loader (injectable for tests).
        create_adapter: Factory mapping a baseline name to a fresh adapter;
            defaults to the registry-based ``_default_create_adapter``.
        answer_fn: QA answerer ``(question, retrieval) -> (answer, citations)``;
            defaults to ``build_default_answer_fn()``.
        max_eval_workers: Thread count for the eval stage.
        eval_only: Skip the pipeline and resume from checkpoint files
            previously written to ``config.output_dir``.
    """
    patch_openai_chat_completions()
    # Dry-run is handled by main() for display; bail out before any I/O here.
    if config.dry_run:
        return

    out = config.output_dir
    out.mkdir(parents=True, exist_ok=True)

    if eval_only:
        # --- Resume from checkpoint ---
        print(f"[Eval-only] Loading pipeline checkpoint from {out}")
        session_records, qa_records = _load_pipeline_checkpoint(out)
        print(f"[Eval-only] Loaded {len(session_records)} sessions + {len(qa_records)} QA records")
    else:
        # --- Stage 1: Pipeline (serial — adapter is stateful) ---
        adapter_factory = create_adapter or _default_create_adapter
        bundle = load_domain_bundle(config.dataset_path)
        # Smoke mode processes only the first sample for a quick end-to-end check.
        samples = bundle.samples[:1] if config.smoke else bundle.samples
        _answer = answer_fn if answer_fn is not None else build_default_answer_fn()

        session_records: list[PipelineSessionRecord] = []
        qa_records: list[PipelineCheckpointQARecord] = []

        print(f"[Pipeline] Running {len(samples)} sample(s) with baseline={config.baseline}")
        for i, sample in enumerate(samples):
            print(f"  Sample {i + 1}/{len(samples)}: {sample.sample_id}")
            # A fresh adapter per sample so memory state never leaks across samples.
            adapter = adapter_factory(config.baseline)
            sess, qa = run_domain_a_v2_sample(
                adapter,
                sample,
                answer_fn=_answer,
            )
            session_records.extend(sess)
            qa_records.extend(qa)

        # --- Save checkpoint ---
        # Written before eval so a failed/interrupted eval can resume with --eval-only.
        _write_jsonl(out / _CHECKPOINT_SESSIONS,
                     [_record_to_json_obj(r) for r in session_records])
        _write_jsonl(out / _CHECKPOINT_QA,
                     [_record_to_json_obj(r) for r in qa_records])
        print(f"[Checkpoint] Saved {len(session_records)} sessions + {len(qa_records)} QA to {out}")

    # --- Stage 2: Eval (parallel — each record is self-contained) ---
    print(f"[Eval] Evaluating {len(session_records)} sessions + {len(qa_records)} QA with LLM judge (workers={max_eval_workers})...")

    # Results are placed by index so output order matches record order
    # regardless of completion order.
    session_evals: list[dict[str, object] | None] = [None] * len(session_records)
    qa_evals: list[dict[str, object] | None] = [None] * len(qa_records)

    with ThreadPoolExecutor(max_workers=max_eval_workers) as pool:
        # Submit session evals
        session_futures = {}
        for idx, srec in enumerate(session_records):
            fut = pool.submit(evaluate_extraction, srec)
            session_futures[fut] = idx

        # Submit QA evals
        qa_futures = {}
        for idx, qrec in enumerate(qa_records):
            fut = pool.submit(evaluate_checkpoint_qa, qrec)
            qa_futures[fut] = idx

        # Collect session results
        done_sessions = 0
        for fut in as_completed(session_futures):
            idx = session_futures[fut]
            try:
                session_evals[idx] = fut.result()
            except Exception as e:
                # A single failed judge call must not abort the whole run;
                # the error is recorded in place of the evaluation.
                session_evals[idx] = {"error": str(e)}
            done_sessions += 1
            if done_sessions % 10 == 0 or done_sessions == len(session_records):
                print(f"  Sessions: {done_sessions}/{len(session_records)} done")

        # Collect QA results
        done_qa = 0
        for fut in as_completed(qa_futures):
            idx = qa_futures[fut]
            try:
                qa_evals[idx] = fut.result()
            except Exception as e:
                qa_evals[idx] = {"error": str(e)}
            done_qa += 1
            if done_qa % 20 == 0 or done_qa == len(qa_records):
                print(f"  QA: {done_qa}/{len(qa_records)} done")

    # --- Stage 3: Aggregate + write ---
    agg = aggregate_metrics(
        config.baseline,
        session_evaluations=[e for e in session_evals if e is not None],
        qa_evaluations=[e for e in qa_evals if e is not None],
    )

    # Attach each per-record evaluation to its record row for the output files.
    session_rows = []
    for srec, s_eval in zip(session_records, session_evals):
        row = _record_to_json_obj(srec)
        row["eval"] = s_eval
        session_rows.append(row)

    qa_rows = []
    for qrec, q_eval in zip(qa_records, qa_evals):
        row = _record_to_json_obj(qrec)
        row["eval"] = q_eval
        qa_rows.append(row)

    _write_jsonl(out / "session_records.jsonl", session_rows)
    _write_jsonl(out / "qa_records.jsonl", qa_rows)
    _write_json(out / "aggregate_metrics.json", agg)

    print(f"\n[Done] Results written to {out}")
    print(f"  Aggregate: {json.dumps(agg, indent=2)}")
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def build_parser() -> argparse.ArgumentParser:
    """Construct the eval_framework CLI argument parser."""
    parser = argparse.ArgumentParser(prog="eval_framework")
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--baseline", required=True)
    parser.add_argument("--output-dir", default="eval_framework/results")
    # Boolean switches with no extra options.
    for flag in ("--smoke", "--dry-run"):
        parser.add_argument(flag, action="store_true")
    parser.add_argument(
        "--eval-only",
        action="store_true",
        help="Skip pipeline, load from checkpoint in output-dir.",
    )
    parser.add_argument(
        "--max-eval-workers",
        type=int,
        default=5,
        help="Parallel threads for eval stage (default 5).",
    )
    return parser
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse args, validate inputs, and run the evaluation."""
    args = build_parser().parse_args(argv)
    cfg = config_from_namespace(args)

    if cfg.dry_run:
        # Dry-run only prints the resolved configuration and exits.
        print(json.dumps(cfg.to_display_dict(), indent=2))
        return

    eval_only = bool(args.eval_only)
    # Dataset is only read by the pipeline stage; --eval-only skips it.
    if not eval_only and not cfg.dataset_path.is_dir():
        raise SystemExit(f"Dataset path is not a directory: {cfg.dataset_path}")

    run_eval(cfg, max_eval_workers=args.max_eval_workers, eval_only=eval_only)
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
# Module executed directly (e.g. `python -m eval_framework.cli ...`): run the CLI.
if __name__ == "__main__":
    main()
|
config.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration types for eval runs (CLI, dry-run, and smoke execution)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class EvalConfig:
    """Resolved paths and flags for one eval invocation."""

    dataset_path: Path   # resolved dataset directory
    output_dir: Path     # resolved results directory
    baseline: str        # adapter/baseline name
    smoke: bool = False  # process only the first sample
    dry_run: bool = False  # print config and exit without running

    def to_display_dict(self) -> dict[str, Any]:
        """JSON-friendly snapshot for dry-run and logging."""
        snapshot: dict[str, Any] = {
            "dataset_path": str(self.dataset_path),
            "output_dir": str(self.output_dir),
        }
        snapshot.update(
            baseline=self.baseline,
            smoke=self.smoke,
            dry_run=self.dry_run,
            dataset_profile="domain_a_v2_academic",
            judge="llm (OpenAI API)",
        )
        return snapshot
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Dataset loaders and schema definitions."""
|
datasets/convert_vistrajqa.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Convert VisTrajQA sessions-*.jsonl → eval_framework domain_a_v2 format.
|
| 2 |
+
|
| 3 |
+
Reads one or more sessions-*.jsonl files and produces the three files
|
| 4 |
+
expected by ``load_domain_a_v2_academic``:
|
| 5 |
+
|
| 6 |
+
<output_dir>/
|
| 7 |
+
├── domain_a_v2.json
|
| 8 |
+
├── stage4_memory_points.jsonl
|
| 9 |
+
└── stage4b_qa_checkpoints.jsonl
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
python -m eval_framework.datasets.convert_vistrajqa \
|
| 13 |
+
--input data/generated/sessions-vab.jsonl \
|
| 14 |
+
data/generated/sessions-eb-nav.jsonl \
|
| 15 |
+
data/generated/sessions-arena.jsonl \
|
| 16 |
+
--output eval_framework/converted/all \
|
| 17 |
+
--text-only
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import json
|
| 24 |
+
import uuid as _uuid
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
from typing import Any
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# QA type abbreviation → full name
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
# Maps the question-type abbreviations used in VisTrajQA probes to the
# canonical full names used by the converted eval_framework dataset.
QA_TYPE_FULL = {
    "FR": "factual_recall",
    "DU": "dynamic_update",
    "MB": "memory_boundary",
    "TR": "temporal_reasoning",
    "KR": "knowledge_reasoning",
    # V* / CMR types involve the visual channel (image captions).
    "VFR": "visual_factual_recall",
    "VS": "visual_search",
    "VU": "visual_update",
    "CMR": "cross_modal_reasoning",
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# Turn construction
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
def _build_user_text(step: dict[str, Any], *, text_only: bool) -> str:
|
| 50 |
+
"""Build user turn text from a CanonicalStep.
|
| 51 |
+
|
| 52 |
+
User turn = what the agent perceives: observation + feedback + (caption if text_only).
|
| 53 |
+
"""
|
| 54 |
+
parts: list[str] = []
|
| 55 |
+
obs = step.get("observation") or ""
|
| 56 |
+
if obs:
|
| 57 |
+
parts.append(f"OBSERVATION: {obs}")
|
| 58 |
+
fb = step.get("feedback") or ""
|
| 59 |
+
if fb:
|
| 60 |
+
parts.append(f"FEEDBACK: {fb}")
|
| 61 |
+
if text_only:
|
| 62 |
+
cap = step.get("image_caption") or ""
|
| 63 |
+
if cap:
|
| 64 |
+
parts.append(f"IMAGE: {cap}")
|
| 65 |
+
if not parts:
|
| 66 |
+
parts.append("(no textual observation)")
|
| 67 |
+
return "\n".join(parts)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _build_assistant_text(step: dict[str, Any]) -> str:
|
| 71 |
+
"""Build assistant turn text: thought + action."""
|
| 72 |
+
parts: list[str] = []
|
| 73 |
+
thought = step.get("thought") or ""
|
| 74 |
+
if thought:
|
| 75 |
+
parts.append(f"THOUGHT: {thought}")
|
| 76 |
+
action = step.get("action") or ""
|
| 77 |
+
if action:
|
| 78 |
+
parts.append(f"ACTION: {action}")
|
| 79 |
+
return "\n".join(parts) or "(no action)"
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _build_attachment(step: dict[str, Any], *, text_only: bool) -> list[dict[str, Any]]:
|
| 83 |
+
"""Build attachment list for a step (caption-only for text_only mode)."""
|
| 84 |
+
cap = step.get("image_caption") or ""
|
| 85 |
+
if not cap:
|
| 86 |
+
return []
|
| 87 |
+
if text_only:
|
| 88 |
+
# Caption already inlined in user text, no separate attachment needed
|
| 89 |
+
return []
|
| 90 |
+
image_id = step.get("image_id") or step.get("image_path") or ""
|
| 91 |
+
return [{"caption": cap, "type": "image_caption", "image_id": image_id}]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# ---------------------------------------------------------------------------
|
| 95 |
+
# Session segmentation
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
def _segment_steps_by_probes(
|
| 99 |
+
steps: list[dict[str, Any]],
|
| 100 |
+
probes: list[dict[str, Any]],
|
| 101 |
+
total_steps: int,
|
| 102 |
+
) -> list[tuple[str, list[dict[str, Any]]]]:
|
| 103 |
+
"""Split steps into sessions at probe boundaries.
|
| 104 |
+
|
| 105 |
+
Returns list of (session_id, steps_in_session).
|
| 106 |
+
Session after probe i covers steps (prev_boundary+1 .. probe_i.after_step_num].
|
| 107 |
+
The remainder after the last probe is the final session.
|
| 108 |
+
"""
|
| 109 |
+
probe_bounds = sorted(set(p["after_step_num"] for p in probes))
|
| 110 |
+
boundaries = [0] + probe_bounds + [total_steps]
|
| 111 |
+
|
| 112 |
+
sessions: list[tuple[str, list[dict[str, Any]]]] = []
|
| 113 |
+
for i in range(len(boundaries) - 1):
|
| 114 |
+
lo = boundaries[i] # exclusive lower bound (step_num > lo)
|
| 115 |
+
hi = boundaries[i + 1] # inclusive upper bound (step_num <= hi)
|
| 116 |
+
sid = f"S{i:02d}"
|
| 117 |
+
seg = [s for s in steps if lo < s["step_num"] <= hi]
|
| 118 |
+
if seg:
|
| 119 |
+
sessions.append((sid, seg))
|
| 120 |
+
|
| 121 |
+
return sessions
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _assign_mps_to_sessions(
|
| 125 |
+
memory_points: list[dict[str, Any]],
|
| 126 |
+
sessions: list[tuple[str, list[dict[str, Any]]]],
|
| 127 |
+
) -> dict[str, list[dict[str, Any]]]:
|
| 128 |
+
"""Map memory points to sessions by step_num range."""
|
| 129 |
+
# Build session_id → step_num range
|
| 130 |
+
ranges: list[tuple[str, int, int]] = []
|
| 131 |
+
for sid, seg in sessions:
|
| 132 |
+
lo = min(s["step_num"] for s in seg)
|
| 133 |
+
hi = max(s["step_num"] for s in seg)
|
| 134 |
+
ranges.append((sid, lo, hi))
|
| 135 |
+
|
| 136 |
+
result: dict[str, list[dict[str, Any]]] = {sid: [] for sid, _ in sessions}
|
| 137 |
+
|
| 138 |
+
for mp in memory_points:
|
| 139 |
+
sn = mp.get("step_num") or mp.get("probe_step_num") or 0
|
| 140 |
+
assigned = False
|
| 141 |
+
for sid, lo, hi in ranges:
|
| 142 |
+
if lo <= sn <= hi:
|
| 143 |
+
result[sid].append(mp)
|
| 144 |
+
assigned = True
|
| 145 |
+
break
|
| 146 |
+
if not assigned:
|
| 147 |
+
# Fallback: assign to last session
|
| 148 |
+
result[ranges[-1][0]].append(mp)
|
| 149 |
+
|
| 150 |
+
return result
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ---------------------------------------------------------------------------
|
| 154 |
+
# Memory point conversion
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
|
| 157 |
+
def _convert_mp(mp: dict[str, Any]) -> dict[str, Any]:
|
| 158 |
+
"""VisTrajQA memory point → eval_framework gold memory point dict."""
|
| 159 |
+
return {
|
| 160 |
+
"memory_id": mp.get("mp_id", ""),
|
| 161 |
+
"memory_content": mp.get("content", ""),
|
| 162 |
+
"memory_type": mp.get("type", ""),
|
| 163 |
+
"memory_source": mp.get("source", "primary"),
|
| 164 |
+
"is_update": bool(mp.get("is_update", False)),
|
| 165 |
+
"original_memories": mp.get("original_memories") or [],
|
| 166 |
+
"importance": float(mp.get("importance", 0.0)),
|
| 167 |
+
"timestamp": None,
|
| 168 |
+
"update_type": mp.get("update_type") or "",
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# ---------------------------------------------------------------------------
|
| 173 |
+
# Question conversion
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
|
| 176 |
+
def _convert_question(
    q: dict[str, Any],
    mp_content_map: dict[str, str],
) -> dict[str, Any]:
    """Translate a VisTrajQA question into an eval_framework checkpoint question.

    ``mp_content_map`` is accepted but currently unread: evidence is stored by
    memory_id only, and contents are resolved later by the loader.
    """
    abbrev = q.get("qa_type", "FR")
    return {
        "question": q.get("question", ""),
        "answer": q.get("answer", ""),
        "question_type": QA_TYPE_FULL.get(abbrev, abbrev),
        "question_type_abbrev": abbrev,
        "difficulty": q.get("difficulty", "medium"),
        "evidence": [{"memory_id": mid} for mid in (q.get("evidence") or [])],
    }
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ---------------------------------------------------------------------------
|
| 194 |
+
# Main conversion
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
|
| 197 |
+
def convert_one_session(
    rec: dict[str, Any],
    *,
    text_only: bool = True,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Convert one VisTrajQA session record → (sample_json, stage4_row, qa_row).

    ``sample_json`` feeds domain_a_v2.json, ``stage4_row`` feeds
    stage4_memory_points.jsonl, and ``qa_row`` feeds
    stage4b_qa_checkpoints.jsonl.  Returns dicts ready for serialization into
    the three target files.

    Args:
        rec: one parsed line of a sessions-*.jsonl file; must contain
            ``session_id`` and ``step_plan`` (KeyError otherwise).
        text_only: when True, image captions are inlined into the user-turn
            text; when False they become attachments.
    """
    sample_id = rec["session_id"]
    # Deterministic per-sample uuid derived from the session id.
    sample_uuid = str(_uuid.uuid5(_uuid.NAMESPACE_DNS, sample_id))
    steps = rec["step_plan"]
    total_steps = rec.get("total_steps") or len(steps)
    probes = rec.get("probes") or []
    post_qa = rec.get("post_trajectory_qa") or []
    memory_points = rec.get("memory_points") or []

    # mp_id → content map for evidence resolution
    mp_content_map: dict[str, str] = {
        mp["mp_id"]: mp.get("content", "")
        for mp in memory_points
        if mp.get("mp_id")
    }

    # --- Session segmentation ---
    sessions = _segment_steps_by_probes(steps, probes, total_steps)
    mp_by_session = _assign_mps_to_sessions(memory_points, sessions)

    # --- Build domain_a_v2.json sample ---
    # Each step yields a user turn (perception) and an assistant turn
    # (thought + action); both carry the same step-derived timestamp.
    session_objects: list[dict[str, Any]] = []
    for sid, seg_steps in sessions:
        dialogue: list[dict[str, Any]] = []
        for step in seg_steps:
            # User turn
            user_text = _build_user_text(step, text_only=text_only)
            dialogue.append({
                "role": "user",
                "content": user_text,
                "timestamp": f"step_{step['step_num']:04d}",
                "attachments": _build_attachment(step, text_only=text_only),
            })
            # Assistant turn
            assistant_text = _build_assistant_text(step)
            dialogue.append({
                "role": "assistant",
                "content": assistant_text,
                "timestamp": f"step_{step['step_num']:04d}",
                "attachments": [],
            })

        sess_obj: dict[str, Any] = {
            "_v2_session_id": sid,
            "dialogue": dialogue,
        }
        # S00 carries its own memory_points in the session object
        if sid == "S00":
            sess_obj["memory_points"] = [_convert_mp(mp) for mp in mp_by_session.get(sid, [])]
        session_objects.append(sess_obj)

    sample_json = {
        "uuid": sample_uuid,
        "sample_id": sample_id,
        "sessions": session_objects,
        # Metadata (not consumed by loader, but useful for debugging)
        "_source": rec.get("source", ""),
        "_env": rec.get("env", ""),
        "_traj_id": rec.get("traj_id", ""),
        "_task": rec.get("task", ""),
        "_total_steps": total_steps,
    }

    # --- Build stage4_memory_points.jsonl row ---
    stage4_sessions: list[dict[str, Any]] = []
    for sid, _ in sessions:
        if sid == "S00":
            continue  # S00 is embedded in domain_a_v2.json
        mps = mp_by_session.get(sid, [])
        if mps:
            stage4_sessions.append({
                "session_id": sid,
                "memory_points": [_convert_mp(mp) for mp in mps],
            })

    stage4_row = {
        "uuid": sample_uuid,
        "sample_id": sample_id,
        "memory_sessions": stage4_sessions,
    }

    # --- Build stage4b_qa_checkpoints.jsonl row ---
    session_ids = [sid for sid, _ in sessions]
    checkpoints: list[dict[str, Any]] = []

    # Probe checkpoints: a probe whose after_step_num equals a session's last
    # step becomes a checkpoint covering all sessions seen so far.
    probe_by_after_step = {p["after_step_num"]: p for p in probes}
    cumulative_sessions: list[str] = []
    for sid, seg_steps in sessions:
        cumulative_sessions.append(sid)
        max_step_in_session = max(s["step_num"] for s in seg_steps)
        probe = probe_by_after_step.get(max_step_in_session)
        if probe is None:
            continue
        questions = [
            _convert_question(q, mp_content_map)
            for q in probe.get("questions", [])
        ]
        if questions:
            checkpoints.append({
                "checkpoint_id": f"probe_{probe['probe_id']}",
                "covered_sessions": list(cumulative_sessions),
                "questions": questions,
            })

    # Post-trajectory checkpoint (covers all sessions)
    if post_qa:
        post_questions = [
            _convert_question(q, mp_content_map)
            for q in post_qa
        ]
        if post_questions:
            checkpoints.append({
                "checkpoint_id": "post_trajectory",
                "covered_sessions": session_ids,
                "questions": post_questions,
            })

    qa_row = {
        "uuid": sample_uuid,
        "sample_id": sample_id,
        "checkpoints": checkpoints,
    }

    return sample_json, stage4_row, qa_row
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def convert_files(
    input_paths: list[Path],
    output_dir: Path,
    *,
    text_only: bool = True,
) -> None:
    """Read VisTrajQA session files and write the three domain_a_v2 files.

    Writes, under *output_dir* (created if missing):
      - domain_a_v2.json            (JSON list of converted samples)
      - stage4_memory_points.jsonl  (one row per sample)
      - stage4b_qa_checkpoints.jsonl (one row per sample)

    Ends by round-tripping the output through the framework loader via
    _validate, which re-raises on failure.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    all_samples: list[dict[str, Any]] = []
    all_stage4: list[dict[str, Any]] = []
    all_qa: list[dict[str, Any]] = []

    for path in input_paths:
        print(f"Reading {path} ...")
        with path.open(encoding="utf-8") as fh:
            # One JSON record per non-empty line (JSONL input).
            # NOTE(review): line_num is currently unused (no per-line error context).
            for line_num, line in enumerate(fh, 1):
                line = line.strip()
                if not line:
                    continue
                rec = json.loads(line)
                sample_json, stage4_row, qa_row = convert_one_session(
                    rec, text_only=text_only,
                )
                all_samples.append(sample_json)
                all_stage4.append(stage4_row)
                all_qa.append(qa_row)

    # Write domain_a_v2.json
    domain_path = output_dir / "domain_a_v2.json"
    domain_path.write_text(
        json.dumps(all_samples, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print(f" → {domain_path} ({len(all_samples)} samples)")

    # Write stage4_memory_points.jsonl
    stage4_path = output_dir / "stage4_memory_points.jsonl"
    with stage4_path.open("w", encoding="utf-8") as fh:
        for row in all_stage4:
            fh.write(json.dumps(row, ensure_ascii=False) + "\n")
    print(f" → {stage4_path} ({len(all_stage4)} rows)")

    # Write stage4b_qa_checkpoints.jsonl
    qa_path = output_dir / "stage4b_qa_checkpoints.jsonl"
    with qa_path.open("w", encoding="utf-8") as fh:
        for row in all_qa:
            fh.write(json.dumps(row, ensure_ascii=False) + "\n")
    print(f" → {qa_path} ({len(all_qa)} rows)")

    # --- Validation ---
    print("\nValidating ...")
    _validate(output_dir)
    print("Done.")
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _validate(output_dir: Path) -> None:
    """Quick validation: load through the eval_framework loader.

    Re-reads the just-written files via load_domain_a_v2_academic and prints
    per-sample counts; any loader exception is reported and re-raised so the
    conversion fails loudly.
    """
    try:
        from eval_framework.datasets.domain_a_v2 import load_domain_a_v2_academic
        bundle = load_domain_a_v2_academic(output_dir)
        print(f" Loaded {len(bundle.samples)} samples successfully")
        for sample in bundle.samples:
            n_sessions = len(sample.sessions)
            n_turns = sum(len(s.turns) for s in sample.sessions)
            n_checkpoints = len(sample.normalized_checkpoints)
            n_questions = sum(len(cp.questions) for cp in sample.normalized_checkpoints)
            # NOTE(review): n_gold is computed but never printed.
            n_gold = len(sample.session_gold_states)
            # Gold points are counted from the final session's cumulative state only.
            n_gold_points = sum(
                len(g.cumulative_gold_memories)
                for g in sample.session_gold_states[-1:]
            )
            print(
                f" {sample.sample_id}: "
                f"{n_sessions} sessions, {n_turns} turns, "
                f"{n_checkpoints} checkpoints, {n_questions} questions, "
                f"{n_gold_points} gold points"
            )
    except Exception as e:
        print(f" Validation failed: {e}")
        raise
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def main() -> None:
    """CLI entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert VisTrajQA sessions → eval_framework domain_a_v2 format",
    )
    parser.add_argument(
        "--input", "-i",
        nargs="+",
        required=True,
        help="Path(s) to sessions-*.jsonl files",
    )
    parser.add_argument(
        "--output", "-o",
        required=True,
        help="Output directory for the three converted files",
    )
    # NOTE(review): this flag is effectively a no-op — it defaults to True and
    # its value is never read below; text_only is derived solely from
    # --multimodal. Kept for CLI backward compatibility.
    parser.add_argument(
        "--text-only",
        action="store_true",
        default=True,
        help="Inline image captions into user turn text (default: true)",
    )
    parser.add_argument(
        "--multimodal",
        action="store_true",
        help="Keep image captions as attachments instead of inlining",
    )
    args = parser.parse_args()

    text_only = not args.multimodal

    input_paths = [Path(p).expanduser().resolve() for p in args.input]
    output_dir = Path(args.output).expanduser().resolve()

    convert_files(input_paths, output_dir, text_only=text_only)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
if __name__ == "__main__":  # pragma: no cover - script entry point
    main()
|
datasets/domain_a_v2.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Domain A v2 academic bundle: dialogue normalization + staged QA / gold state."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Iterator, Mapping
|
| 9 |
+
|
| 10 |
+
from eval_framework.datasets.schemas import NormalizedTurn, normalize_turn
|
| 11 |
+
from eval_framework.pipeline.gold_state import (
|
| 12 |
+
SessionGoldState,
|
| 13 |
+
build_session_gold_states,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclass(frozen=True)
class Stage4Record:
    """One stage4_memory_points.jsonl row: gold memory points per session."""

    uuid: str
    sample_id: str
    # Ordered (session_id, memory_point_dicts) pairs as read from the file.
    memory_sessions: tuple[tuple[str, tuple[Mapping[str, Any], ...]], ...]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass(frozen=True)
class QARecord:
    """One stage4b_qa_checkpoints.jsonl row: raw (not yet normalized) checkpoints."""

    uuid: str
    sample_id: str
    raw_checkpoints: tuple[Mapping[str, Any], ...]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass(frozen=True)
class NormalizedCheckpointQuestion:
    """A single QA item with its gold answer and resolved evidence."""

    question: str
    gold_answer: str
    # All evidence ids; contents hold only the ids that resolved to non-empty text.
    gold_evidence_memory_ids: tuple[str, ...]
    gold_evidence_contents: tuple[str, ...]
    question_type: str
    question_type_abbrev: str
    difficulty: str
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass(frozen=True)
class NormalizedCheckpoint:
    """A QA checkpoint: which sessions it covers plus its questions."""

    checkpoint_id: str
    covered_sessions: tuple[str, ...]
    questions: tuple[NormalizedCheckpointQuestion, ...]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@dataclass(frozen=True)
class DomainAV2Session:
    """One dialogue session: its id and normalized turns."""

    session_id: str
    turns: tuple[NormalizedTurn, ...]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass(frozen=True)
class DomainAV2AcademicSample:
    """A fully joined sample: dialogue, raw side records, and derived views."""

    uuid: str
    sample_id: str
    sessions: tuple[DomainAV2Session, ...]
    # Raw side-file records kept for provenance/debugging.
    stage4: Stage4Record
    qa_record: QARecord
    # Derived views used by evaluation.
    normalized_checkpoints: tuple[NormalizedCheckpoint, ...]
    session_gold_states: tuple[SessionGoldState, ...]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass(frozen=True)
class DomainAV2AcademicBundle:
    """Top-level container returned by load_domain_a_v2_academic."""

    samples: tuple[DomainAV2AcademicSample, ...]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _read_jsonl(path: Path) -> Iterator[dict[str, Any]]:
|
| 72 |
+
with path.open(encoding="utf-8") as fh:
|
| 73 |
+
for line in fh:
|
| 74 |
+
line = line.strip()
|
| 75 |
+
if not line:
|
| 76 |
+
continue
|
| 77 |
+
yield json.loads(line)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def _stage4_from_obj(obj: Mapping[str, Any]) -> Stage4Record:
    """Parse one stage4 JSONL row into a Stage4Record."""
    parsed: list[tuple[str, tuple[Mapping[str, Any], ...]]] = []
    for block in obj.get("memory_sessions") or []:
        points = block.get("memory_points") or []
        # Defend against malformed rows where memory_points is not a list.
        if not isinstance(points, list):
            points = []
        parsed.append((str(block.get("session_id", "")), tuple(points)))
    return Stage4Record(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        memory_sessions=tuple(parsed),
    )
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _qa_from_obj(obj: Mapping[str, Any]) -> QARecord:
    """Parse one QA-checkpoints JSONL row into a QARecord."""
    checkpoints = obj.get("checkpoints") or []
    return QARecord(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        # Non-list payloads are treated as having no checkpoints.
        raw_checkpoints=tuple(checkpoints) if isinstance(checkpoints, list) else (),
    )
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _normalize_checkpoint_question(
    raw: Mapping[str, Any],
    memory_content_map: Mapping[str, str],
) -> NormalizedCheckpointQuestion:
    """Normalize one raw checkpoint question, resolving evidence contents.

    Evidence entries without a ``memory_id`` are skipped; ids that do not
    resolve to non-empty content are kept in the id list only.
    """
    ids: list[str] = []
    contents: list[str] = []
    evidence = raw.get("evidence") or []
    if isinstance(evidence, list):
        for entry in evidence:
            if not (isinstance(entry, dict) and "memory_id" in entry):
                continue
            memory_id = str(entry["memory_id"])
            ids.append(memory_id)
            resolved = memory_content_map.get(memory_id, "")
            if resolved:
                contents.append(resolved)
    return NormalizedCheckpointQuestion(
        question=str(raw.get("question", "")),
        gold_answer=str(raw.get("answer", "")),
        gold_evidence_memory_ids=tuple(ids),
        gold_evidence_contents=tuple(contents),
        question_type=str(raw.get("question_type", "")),
        question_type_abbrev=str(raw.get("question_type_abbrev", "")),
        difficulty=str(raw.get("difficulty", "")),
    )
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _normalize_checkpoints(
    raw_checkpoints: tuple[Mapping[str, Any], ...],
    memory_content_map: Mapping[str, str],
) -> tuple[NormalizedCheckpoint, ...]:
    """Normalize every raw checkpoint, tolerating malformed fields."""
    normalized: list[NormalizedCheckpoint] = []
    for raw_cp in raw_checkpoints:
        raw_questions = raw_cp.get("questions") or []
        raw_covered = raw_cp.get("covered_sessions") or []
        questions = tuple(
            _normalize_checkpoint_question(q, memory_content_map)
            for q in (raw_questions if isinstance(raw_questions, list) else [])
            if isinstance(q, Mapping)
        )
        covered = tuple(
            str(sid)
            for sid in (raw_covered if isinstance(raw_covered, list) else [])
        )
        normalized.append(
            NormalizedCheckpoint(
                checkpoint_id=str(raw_cp.get("checkpoint_id", "")),
                covered_sessions=covered,
                questions=questions,
            )
        )
    return tuple(normalized)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _dialogue_turns(sample_id: str, session_id: str, dialogue: list[Any]) -> tuple[NormalizedTurn, ...]:
    """Normalize one session's dialogue entries into NormalizedTurn records.

    Attachment captions are appended to the turn text (caption inlining), so
    the produced turns carry no separate attachments.  Non-dict entries are
    skipped; their index position is still consumed by enumerate.
    """
    normalized: list[NormalizedTurn] = []
    for idx, raw in enumerate(dialogue):
        if not isinstance(raw, dict):
            continue
        body = str(raw.get("content", ""))
        caption_texts: list[str] = []
        raw_attachments = raw.get("attachments") or []
        if isinstance(raw_attachments, list):
            for attachment in raw_attachments:
                if isinstance(attachment, dict):
                    cap = attachment.get("caption", "")
                    caption_texts.append(cap if isinstance(cap, str) else str(cap))
        if caption_texts:
            body = body + "\n\n" + "\n".join(caption_texts)
        ts_raw = raw.get("timestamp")
        ts = ts_raw if isinstance(ts_raw, str) else (str(ts_raw) if ts_raw is not None else None)
        normalized.append(
            normalize_turn({
                "sample_id": sample_id,
                "session_id": session_id,
                "turn_index": idx,
                "role": str(raw.get("role", "user")),
                "text": body,
                "attachments": [],
                "timestamp": ts,
            })
        )
    return tuple(normalized)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def load_domain_a_v2_academic(data_dir: Path) -> DomainAV2AcademicBundle:
    """Load the three-file domain_a_v2 bundle from *data_dir*.

    Expects:
      - domain_a_v2.json             (list of samples with dialogue sessions)
      - stage4_memory_points.jsonl   (per-sample post-S00 gold memory points)
      - stage4b_qa_checkpoints.jsonl (per-sample QA checkpoints)

    Raises ValueError when domain_a_v2.json is not a JSON list, and KeyError
    when a sample lacks its stage4 or QA row.
    """
    data_dir = data_dir.resolve()
    main_path = data_dir / "domain_a_v2.json"
    stage4_path = data_dir / "stage4_memory_points.jsonl"
    qa_path = data_dir / "stage4b_qa_checkpoints.jsonl"

    raw_samples = json.loads(main_path.read_text(encoding="utf-8"))
    if not isinstance(raw_samples, list):
        raise ValueError("domain_a_v2.json must be a list")

    # Index both JSONL side files by sample_id for joining.
    stage4_by_id: dict[str, Stage4Record] = {}
    for obj in _read_jsonl(stage4_path):
        rec = _stage4_from_obj(obj)
        stage4_by_id[rec.sample_id] = rec

    qa_by_id: dict[str, QARecord] = {}
    for obj in _read_jsonl(qa_path):
        rec = _qa_from_obj(obj)
        qa_by_id[rec.sample_id] = rec

    built: list[DomainAV2AcademicSample] = []
    for item in raw_samples:
        if not isinstance(item, dict):
            continue
        sample_id = str(item["sample_id"])
        uuid = str(item["uuid"])
        stage4 = stage4_by_id.get(sample_id)
        qa = qa_by_id.get(sample_id)
        if stage4 is None or qa is None:
            raise KeyError(f"missing stage4 or QA row for sample_id={sample_id}")

        stage4_map = {sid: pts for sid, pts in stage4.memory_sessions}

        sessions_raw = item.get("sessions") or []
        if not isinstance(sessions_raw, list):
            sessions_raw = []

        session_blocks: list[DomainAV2Session] = []
        ordered_ids: list[str] = []
        s00_points: tuple[Mapping[str, Any], ...] = ()

        for sess in sessions_raw:
            if not isinstance(sess, dict):
                continue
            sid = str(sess.get("_v2_session_id", ""))
            if not sid:
                continue
            ordered_ids.append(sid)
            dialogue = sess.get("dialogue") or []
            if not isinstance(dialogue, list):
                dialogue = []
            session_blocks.append(
                DomainAV2Session(
                    session_id=sid,
                    turns=_dialogue_turns(sample_id, sid, dialogue),
                )
            )
            # S00 embeds its gold memory points directly in the session object.
            if sid == "S00":
                mps = sess.get("memory_points") or []
                if isinstance(mps, list):
                    s00_points = tuple(mps)

        gold_states = build_session_gold_states(
            ordered_ids,
            s00_memory_points=s00_points,
            stage4_by_session_id=stage4_map,
        )

        # Build memory_id -> memory_content map from all sources
        # (S00-embedded points plus every stage4 session block).
        memory_content_map: dict[str, str] = {}
        for mp_raw in s00_points:
            if isinstance(mp_raw, Mapping):
                mid = mp_raw.get("memory_id")
                mc = mp_raw.get("memory_content")
                if mid is not None and mc is not None:
                    memory_content_map[str(mid)] = str(mc)
        for _sid, pts in stage4.memory_sessions:
            for mp_raw in pts:
                if isinstance(mp_raw, Mapping):
                    mid = mp_raw.get("memory_id")
                    mc = mp_raw.get("memory_content")
                    if mid is not None and mc is not None:
                        memory_content_map[str(mid)] = str(mc)

        built.append(
            DomainAV2AcademicSample(
                uuid=uuid,
                sample_id=sample_id,
                sessions=tuple(session_blocks),
                stage4=stage4,
                qa_record=qa,
                normalized_checkpoints=_normalize_checkpoints(
                    qa.raw_checkpoints, memory_content_map
                ),
                session_gold_states=gold_states,
            )
        )

    return DomainAV2AcademicBundle(samples=tuple(built))
|
datasets/schemas.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Normalized runtime schemas shared across adapters, pipeline, and evaluators."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Any, Mapping
|
| 7 |
+
|
| 8 |
+
# Delta operations are plain strings; MemoryDeltaRecord validates them against
# _VALID_DELTA_OPS in __post_init__.
MemoryDeltaOp = str

# The closed vocabulary of operations a memory delta may carry.
_VALID_DELTA_OPS: frozenset[str] = frozenset(
    {"add", "update", "keep", "suppress", "archive"}
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(frozen=True)
class Attachment:
    """Caption-first attachment; image_id is optional for caption-only items."""

    caption: str
    # Attachment kind; only "image_caption" appears in this file's producers.
    type: str = "image_caption"
    image_id: str | None = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass(frozen=True)
class NormalizedTurn:
    """One normalized dialogue turn, keyed by (sample_id, session_id, turn_index)."""

    sample_id: str
    session_id: str
    turn_index: int
    role: str
    text: str
    attachments: tuple[Attachment, ...] = ()
    timestamp: str | None = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def normalize_turn(raw: Mapping[str, Any]) -> NormalizedTurn:
    """Build a NormalizedTurn from a raw dict, keeping caption-only attachments.

    Required keys (KeyError if absent): sample_id, session_id, turn_index,
    role, text.  Attachment entries that are not dicts are dropped; empty or
    missing image ids become None.
    """
    parsed: list[Attachment] = []
    for entry in raw.get("attachments") or []:
        if not isinstance(entry, dict):
            continue
        cap_raw = entry.get("caption", "")
        type_raw = entry.get("type", "image_caption")
        image_raw = entry.get("image_id")
        if image_raw is None or image_raw == "":
            image_id: str | None = None
        else:
            image_id = str(image_raw)
        parsed.append(
            Attachment(
                caption=cap_raw if isinstance(cap_raw, str) else str(cap_raw),
                type=type_raw if isinstance(type_raw, str) else str(type_raw),
                image_id=image_id,
            )
        )
    ts_raw = raw.get("timestamp")
    if isinstance(ts_raw, str):
        ts: str | None = ts_raw
    elif ts_raw is None:
        ts = None
    else:
        ts = str(ts_raw)
    return NormalizedTurn(
        sample_id=str(raw["sample_id"]),
        session_id=str(raw["session_id"]),
        turn_index=int(raw["turn_index"]),
        role=str(raw["role"]),
        text=str(raw["text"]),
        attachments=tuple(parsed),
        timestamp=ts,
    )
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass(frozen=True)
class MemorySnapshotRecord:
    """A memory item as captured in a backend snapshot."""

    memory_id: str
    text: str
    session_id: str
    status: str
    source: str | None = None
    # Backend-native identifiers, preserved for traceability.
    raw_backend_id: str | None = None
    raw_backend_type: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@dataclass(frozen=True)
class MemoryDeltaRecord:
    """A single memory change reported for one session.

    Raises ValueError at construction when *op* is outside _VALID_DELTA_OPS.
    """

    session_id: str
    op: MemoryDeltaOp
    text: str
    linked_previous: tuple[str, ...] = ()
    raw_backend_id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Reject ops outside the closed vocabulary early.
        if self.op not in _VALID_DELTA_OPS:
            raise ValueError(f"invalid memory delta op: {self.op!r}")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass(frozen=True)
class RetrievalItem:
    """One ranked retrieval hit."""

    rank: int
    memory_id: str
    text: str
    score: float
    raw_backend_id: str | None = None
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@dataclass(frozen=True)
class RetrievalRecord:
    """A retrieval call: the query, requested top_k, and the ranked items."""

    query: str
    top_k: int
    items: list[RetrievalItem]
    # Backend-specific trace payload, kept verbatim for debugging.
    raw_trace: dict[str, Any] = field(default_factory=dict)
|
docs/DATA_CONVERSION.md
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VisTrajQA → Eval Framework 数据适配指南
|
| 2 |
+
|
| 3 |
+
## 概述
|
| 4 |
+
|
| 5 |
+
`convert_vistrajqa.py` 将 VisTrajQA 的 `sessions-*.jsonl` 转换为 eval_framework 所需的 domain_a_v2 三文件格式,从而可以用 Mem-Gallery / A-Mem / MemoryOS 等 baseline 进行统一评测。
|
| 6 |
+
|
| 7 |
+
## 快速使用
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
# 转换所有数据源(text-only 模式,默认)
|
| 11 |
+
python -m eval_framework.datasets.convert_vistrajqa \
|
| 12 |
+
--input data/generated/sessions-vab.jsonl \
|
| 13 |
+
data/generated/sessions-eb-nav.jsonl \
|
| 14 |
+
data/generated/sessions-arena.jsonl \
|
| 15 |
+
data/generated/sessions-eb-alfred.jsonl \
|
| 16 |
+
data/generated/sessions-infini-thor.jsonl \
|
| 17 |
+
--output eval_framework/converted/all
|
| 18 |
+
|
| 19 |
+
# 只转换某个数据源
|
| 20 |
+
python -m eval_framework.datasets.convert_vistrajqa \
|
| 21 |
+
--input data/generated/sessions-vab.jsonl \
|
| 22 |
+
--output eval_framework/converted/vab
|
| 23 |
+
|
| 24 |
+
# multimodal 模式(image caption 作为 attachment 而非内联文本)
|
| 25 |
+
python -m eval_framework.datasets.convert_vistrajqa \
|
| 26 |
+
--input data/generated/sessions-vab.jsonl \
|
| 27 |
+
--output eval_framework/converted/vab-mm \
|
| 28 |
+
--multimodal
|
| 29 |
+
|
| 30 |
+
# 转换后直接跑 eval
|
| 31 |
+
python -m eval_framework.cli \
|
| 32 |
+
--dataset eval_framework/converted/all \
|
| 33 |
+
--baseline FUMemory \
|
| 34 |
+
--output-dir eval_framework/results/FUMemory
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
## 转换映射
|
| 38 |
+
|
| 39 |
+
### 数据结构映射
|
| 40 |
+
|
| 41 |
+
```
|
| 42 |
+
VisTrajQA session → eval_framework sample
|
| 43 |
+
├── session_id → sample_id
|
| 44 |
+
├── step_plan[] → sessions[].dialogue[] (user + assistant turns)
|
| 45 |
+
├── probes[] → checkpoints[] (probe checkpoints)
|
| 46 |
+
├── post_trajectory_qa[] → checkpoints[-1] (post-trajectory checkpoint)
|
| 47 |
+
└── memory_points[] → gold memory points (S00 embedded + stage4)
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### Session 切分
|
| 51 |
+
|
| 52 |
+
一条 VisTrajQA 轨迹(如 30 步,4 个 probe 在 step 6/12/18/24)按 probe 边界切分为 5 个 session:
|
| 53 |
+
|
| 54 |
+
```
|
| 55 |
+
步骤 1-6 → S00 (probe 1 在此 session 结束后触发)
|
| 56 |
+
步骤 7-12 → S01 (probe 2)
|
| 57 |
+
步骤 13-18 → S02 (probe 3)
|
| 58 |
+
步骤 19-24 → S03 (probe 4)
|
| 59 |
+
步骤 25-30 → S04 (post-trajectory QA 在全部 session 结束后触发)
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
这样保证 eval_framework 的 runner 在每个 session 完成后恰好触发对应的 checkpoint。
|
| 63 |
+
|
| 64 |
+
### Turn 构建
|
| 65 |
+
|
| 66 |
+
每个 step 生成 2 个 dialogue turn:
|
| 67 |
+
|
| 68 |
+
| Turn | Role | 内容 |
|
| 69 |
+
|------|------|------|
|
| 70 |
+
| User turn | `user` | OBSERVATION + FEEDBACK + IMAGE caption(text-only 模式) |
|
| 71 |
+
| Assistant turn | `assistant` | THOUGHT + ACTION |
|
| 72 |
+
|
| 73 |
+
**text-only 模式**(默认):image caption 直接写入 user turn 文本,格式为 `IMAGE: <caption>`。适用于所有 text-only baseline。
|
| 74 |
+
|
| 75 |
+
**multimodal 模式**(`--multimodal`):image caption 作为 `attachment` 附加,不写入正文。适用于 MMMemory 等多模态 baseline。
|
| 76 |
+
|
| 77 |
+
### Memory Point 映射
|
| 78 |
+
|
| 79 |
+
| VisTrajQA 字段 | eval_framework 字段 | 说明 |
|
| 80 |
+
|----------------|---------------------|------|
|
| 81 |
+
| `mp_id` | `memory_id` | 如 `mp_S04_1` |
|
| 82 |
+
| `content` | `memory_content` | 一句话事实描述 |
|
| 83 |
+
| `type` | `memory_type` | `event_memory` / `state_memory` / `spatial_memory` |
|
| 84 |
+
| `source` | `memory_source` | `primary` (文本) / `secondary` (推断) |
|
| 85 |
+
| `is_update` | `is_update` | 是否为更新型记忆 |
|
| 86 |
+
| `original_memories` | `original_memories` | 被替换的旧内容列表 |
|
| 87 |
+
| `importance` | `importance` | 0.4 / 0.6 / 0.8 / 1.0 |
|
| 88 |
+
| `update_type` | `update_type` | `status_update` / `location_change` / ... |
|
| 89 |
+
|
| 90 |
+
Memory point 按 `step_num` 分配到对应 session:
|
| 91 |
+
- S00 的 memory points 嵌入在 `domain_a_v2.json` 的 session 对象中
|
| 92 |
+
- 其他 session 的 memory points 写入 `stage4_memory_points.jsonl`
|
| 93 |
+
|
| 94 |
+
### QA / Checkpoint 映射
|
| 95 |
+
|
| 96 |
+
**Probe checkpoint**:每个 probe 生成一个 checkpoint,`covered_sessions` 为该 probe 及之前所有 session。
|
| 97 |
+
|
| 98 |
+
**Post-trajectory checkpoint**:覆盖全部 session,包含 9 类 QA。
|
| 99 |
+
|
| 100 |
+
| VisTrajQA QA type | eval_framework question_type | 缩写 |
|
| 101 |
+
|----|----|-----|
|
| 102 |
+
| FR | factual_recall | FR |
|
| 103 |
+
| DU | dynamic_update | DU |
|
| 104 |
+
| MB | memory_boundary | MB |
|
| 105 |
+
| TR | temporal_reasoning | TR |
|
| 106 |
+
| KR | knowledge_reasoning | KR |
|
| 107 |
+
| VFR | visual_factual_recall | VFR |
|
| 108 |
+
| VS | visual_search | VS |
|
| 109 |
+
| VU | visual_update | VU |
|
| 110 |
+
| CMR | cross_modal_reasoning | CMR |
|
| 111 |
+
|
| 112 |
+
Evidence 字段从 `["mp_S04_1"]`(字符串列表)转换为 `[{"memory_id": "mp_S04_1"}]`(字典列表)以匹配 eval_framework 格式。
|
| 113 |
+
|
| 114 |
+
## 输出文件
|
| 115 |
+
|
| 116 |
+
```
|
| 117 |
+
eval_framework/converted/all/
|
| 118 |
+
├── domain_a_v2.json # 主对话数据 (JSON array)
|
| 119 |
+
├── stage4_memory_points.jsonl # 每 session 的 gold memory points
|
| 120 |
+
└── stage4b_qa_checkpoints.jsonl # checkpoint QA 题目
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
## 评测维度与 VisTrajQA 的对应
|
| 124 |
+
|
| 125 |
+
| eval_framework 维度 | 测量内容 | 对应 VisTrajQA 特性 |
|
| 126 |
+
|-----|-----|-----|
|
| 127 |
+
| Memory Recall | 记忆系统存储了多少 gold points | 直接对应,所有 MP 类型 |
|
| 128 |
+
| Memory Correctness | 存储的记忆是否正确 | 检测 hallucination |
|
| 129 |
+
| Update Handling | 更新型记忆是否正确替换 | 对应 `is_update=true` 的 MP |
|
| 130 |
+
| Interference Rejection | 干扰信息是否被过滤 | VisTrajQA 无 interference 标注,此维度为空 |
|
| 131 |
+
| QA Accuracy | 问答正确率 | 对应 9 类 QA (FR/DU/MB/TR/KR/VFR/VS/VU/CMR) |
|
| 132 |
+
| Evidence Coverage | 回答引用了多少 gold evidence | 对应 evidence memory_point_ids |
|
| 133 |
+
|
| 134 |
+
> **注意**:VisTrajQA 没有 interference(干扰信息)标注,因此 eval_framework 的 Interference Rejection 维度在评测结果中会为空值。MB(Memory Boundary)类型的题目在 QA 层面测试了类似能力。
|
| 135 |
+
|
| 136 |
+
## 注意事项
|
| 137 |
+
|
| 138 |
+
1. **text-only baseline(FU/ST/LT/GA/MG/RF)**:使用默认 `--text-only`,image caption 内联到用户消息文本中
|
| 139 |
+
2. **multimodal baseline(MM/MMFU/NG/AUGUSTUS)**:使用 `--multimodal`,caption 作为 attachment
|
| 140 |
+
3. **caption 质量**:text-only baseline 对图像的理解完全依赖 caption 质量。如果 `image_caption` 为空,用户 turn 中不会有任何视觉信息
|
| 141 |
+
4. **Arena 数据**:observation 恒为空字符串,视觉信息完全来自 image_caption
|
| 142 |
+
5. **转换器会自动验证**:运行后会调用 `load_domain_a_v2_academic` 检验输出是否合法
|
docs/EXPERIMENTS.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 实验记录
|
| 2 |
+
|
| 3 |
+
## 实验环境
|
| 4 |
+
|
| 5 |
+
- **数据**:VisTrajQA VAB smoke(1 sample, vab_minecraft, 30 步, 5 sessions, 45 QA)
|
| 6 |
+
- **转换模式**:text-only(image caption 内联到 user turn 文本)
|
| 7 |
+
- **Judge / Answer 模型**:从 `.env` 读取 `OPENAI_MODEL`
|
| 8 |
+
- **日期**:2026-04-15
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## 总表
|
| 13 |
+
|
| 14 |
+
### Memory 维度
|
| 15 |
+
|
| 16 |
+
| Baseline | Recall ↑ | Update Recall ↑ | Correctness ↑ | Hallucination ↓ | Irrelevant ↓ | Update Score ↑ |
|
| 17 |
+
|----------|----------|-----------------|---------------|-----------------|--------------|----------------|
|
| 18 |
+
| Dummy | 95.0% | 90.0% | 50.0% | 5.0% | 45.0% | 71.4% |
|
| 19 |
+
| FUMemory | 95.0% | 90.0% | 80.0% | 10.0% | 10.0% | 69.0% |
|
| 20 |
+
| STMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
|
| 21 |
+
| LTMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 69.0% |
|
| 22 |
+
| GAMemory | 92.1% | 90.0% | 62.5% | 22.5% | 15.0% | 71.4% |
|
| 23 |
+
| MGMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
|
| 24 |
+
| RFMemory | 95.0% | 90.0% | 90.0% | 6.7% | 3.3% | 71.4% |
|
| 25 |
+
| MMMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
|
| 26 |
+
| MMFUMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
|
| 27 |
+
| NGMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
|
| 28 |
+
| AUGUSTUSMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
|
| 29 |
+
| UniversalRAGMemory | 95.0% | 90.0% | 80.0% | 10.0% | 10.0% | 71.4% |
|
| 30 |
+
| Mem0 | 47.9% | 56.7% | 57.3% | 18.0% | 24.7% | 81.0% |
|
| 31 |
+
| Mem0-Graph | 48.2% | 56.7% | 37.9% | 14.6% | 38.1% | 71.4% |
|
| 32 |
+
| SimpleMem | 95.0% | 90.0% | 48.3% | 5.0% | 46.7% | 71.4% |
|
| 33 |
+
| Omni-SimpleMem | — | — | — | — | — | — |
|
| 34 |
+
| MemVerse | 95.0% | 90.0% | 65.0% | 23.3% | 11.7% | 71.4% |
|
| 35 |
+
| Zep | — | — | — | — | — | — |
|
| 36 |
+
| A-Mem | 95.0% | 90.0% | 56.7% | 5.0% | 38.3% | 71.4% |
|
| 37 |
+
| MemoryOS | 47.1% | 41.0% | 33.8% | 3.5% | 62.8% | 50.0% |
|
| 38 |
+
|
| 39 |
+
### QA 维度
|
| 40 |
+
|
| 41 |
+
| Baseline | QA Correct ↑ | QA Hallucination ↓ | QA Omission ↓ | Evidence Coverage ↑ |
|
| 42 |
+
|----------|--------------|---------------------|---------------|---------------------|
|
| 43 |
+
| Dummy | 44.4% | 24.4% | 31.1% | 47.6% |
|
| 44 |
+
| FUMemory | 57.8% | 40.0% | 2.2% | 65.0% |
|
| 45 |
+
| STMemory | 31.1% | 33.3% | 35.6% | 32.0% |
|
| 46 |
+
| LTMemory | 40.0% | 26.7% | 33.3% | 51.5% |
|
| 47 |
+
| GAMemory | 37.8% | 20.0% | 42.2% | 51.5% |
|
| 48 |
+
| MGMemory | 60.0% | 35.6% | 4.4% | 68.9% |
|
| 49 |
+
| RFMemory | 60.0% | 37.8% | 2.2% | 63.1% |
|
| 50 |
+
| MMMemory | 57.8% | 40.0% | 2.2% | 66.0% |
|
| 51 |
+
| MMFUMemory | 60.0% | 37.8% | 2.2% | 67.0% |
|
| 52 |
+
| NGMemory | 62.2% | 24.4% | 13.3% | 71.8% |
|
| 53 |
+
| AUGUSTUSMemory | 64.4% | 22.2% | 13.3% | 68.9% |
|
| 54 |
+
| UniversalRAGMemory | 40.0% | 28.9% | 31.1% | 49.5% |
|
| 55 |
+
| Mem0 | 26.7% | 28.9% | 44.4% | 31.1% |
|
| 56 |
+
| Mem0-Graph | 31.1% | 20.0% | 48.9% | 32.0% |
|
| 57 |
+
| SimpleMem | 57.8% | 22.2% | 20.0% | 35.9% |
|
| 58 |
+
| Omni-SimpleMem | — | — | — | — |
|
| 59 |
+
| MemVerse | 40.0% | 28.9% | 31.1% | 40.8% |
|
| 60 |
+
| Zep | — | — | — | — |
|
| 61 |
+
| A-Mem | 46.7% | 20.0% | 33.3% | 50.5% |
|
| 62 |
+
| MemoryOS | 28.9% | 24.4% | 46.7% | 37.9% |
|
| 63 |
+
|
| 64 |
+
### QA 分类型正确率
|
| 65 |
+
|
| 66 |
+
| Baseline | FR | DU | MB | TR | KR | VFR | VS | VU | CMR |
|
| 67 |
+
|----------|----|----|----|----|----|----|----|----|-----|
|
| 68 |
+
| Dummy | 5/5 | 0/5 | 4/5 | 2/5 | 2/5 | 4/5 | 0/5 | 1/5 | 2/5 |
|
| 69 |
+
| FUMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 2/5 | 0/5 |
|
| 70 |
+
| STMemory | 0/5 | 0/5 | 5/5 | 1/5 | 3/5 | 2/5 | 0/5 | 0/5 | 3/5 |
|
| 71 |
+
| LTMemory | 2/5 | 1/5 | 5/5 | 1/5 | 3/5 | 4/5 | 0/5 | 0/5 | 2/5 |
|
| 72 |
+
| GAMemory | 4/5 | 0/5 | 5/5 | 0/5 | 2/5 | 4/5 | 0/5 | 0/5 | 2/5 |
|
| 73 |
+
| MGMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 1/5 | 2/5 |
|
| 74 |
+
| RFMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 1/5 | 2/5 |
|
| 75 |
+
| MMMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 2/5 | 0/5 |
|
| 76 |
+
| MMFUMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 2/5 | 1/5 |
|
| 77 |
+
| NGMemory | 5/5 | 3/5 | 5/5 | 3/5 | 3/5 | 4/5 | 0/5 | 2/5 | 3/5 |
|
| 78 |
+
| AUGUSTUSMemory | 5/5 | 2/5 | 5/5 | 4/5 | 3/5 | 4/5 | 1/5 | 2/5 | 3/5 |
|
| 79 |
+
| UniversalRAGMemory | 2/5 | 1/5 | 5/5 | 1/5 | 3/5 | 4/5 | 0/5 | 0/5 | 2/5 |
|
| 80 |
+
| Mem0 | 2/5 | 1/5 | 5/5 | 0/5 | 0/5 | 2/5 | 0/5 | 0/5 | 2/5 |
|
| 81 |
+
| Mem0-Graph | 2/5 | 0/5 | 5/5 | 0/5 | 1/5 | 3/5 | 0/5 | 0/5 | 3/5 |
|
| 82 |
+
| SimpleMem | 5/5 | 3/5 | 4/5 | 3/5 | 2/5 | 2/5 | 0/5 | 3/5 | 4/5 |
|
| 83 |
+
| Omni-SimpleMem | — | — | — | — | — | — | — | — | — |
|
| 84 |
+
| MemVerse | 5/5 | 0/5 | 5/5 | 0/5 | 2/5 | 3/5 | 0/5 | 0/5 | 3/5 |
|
| 85 |
+
| Zep | — | — | — | — | — | — | — | — | — |
|
| 86 |
+
| A-Mem | 3/5 | 3/5 | 5/5 | 0/5 | 2/5 | 4/5 | 0/5 | 0/5 | 4/5 |
|
| 87 |
+
| MemoryOS | 4/5 | 1/5 | 5/5 | 0/5 | 0/5 | 1/5 | 0/5 | 0/5 | 2/5 |
|
docs/GUIDE.md
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Eval Framework 使用指南
|
| 2 |
+
|
| 3 |
+
## 1. 整体架构
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
eval_framework/
|
| 7 |
+
├── cli.py # 入口:CLI 解析 + 三阶段编排 (Pipeline → Eval → Aggregate)
|
| 8 |
+
├── config.py # EvalConfig 数据类
|
| 9 |
+
├── openai_compat.py # GPT-5 系列 max_tokens→max_completion_tokens 兼容补丁
|
| 10 |
+
├── datasets/
|
| 11 |
+
│ ├── schemas.py # 运行时共享数据结构 (NormalizedTurn, MemorySnapshotRecord, RetrievalRecord 等)
|
| 12 |
+
│ └── domain_a_v2.py # domain_a_v2 数据集加载器
|
| 13 |
+
├── memory_adapters/
|
| 14 |
+
│ ├── base.py # MemoryAdapter 抽象基类 (7 个接口方法)
|
| 15 |
+
│ ├── registry.py # Baseline 注册表 + Mem-Gallery 默认配置覆盖
|
| 16 |
+
│ ├── memgallery_native.py # Mem-Gallery 11 种内置 baseline 的统一适配器
|
| 17 |
+
│ ├── amem.py # A-Mem 外部 baseline 适配器
|
| 18 |
+
│ ├── memoryos.py # MemoryOS 外部 baseline 适配器
|
| 19 |
+
│ └── export_utils.py # 快照/检索结果归一化工具
|
| 20 |
+
├── pipeline/
|
| 21 |
+
│ ├── runner.py # 按 session 顺序喂入对话 → 生成 snapshot/delta → 触发 QA
|
| 22 |
+
│ ├── qa_runner.py # 对每个 checkpoint question 做 retrieve + answer
|
| 23 |
+
│ ├── gold_state.py # Gold memory points 累积构建
|
| 24 |
+
│ └── records.py # PipelineSessionRecord / PipelineCheckpointQARecord
|
| 25 |
+
├── evaluators/
|
| 26 |
+
│ ├── extraction.py # Session 级评估:Recall + Correctness + Update + Interference
|
| 27 |
+
│ ├── qa.py # Checkpoint QA 评估:Answer 正确性 + Evidence 覆盖率
|
| 28 |
+
│ └── aggregate.py # 聚合所有 session/QA 评估到 baseline 级汇总指标
|
| 29 |
+
└── judges/
|
| 30 |
+
├── llm_client.py # OpenAI 兼容 LLM 调用 + JSON 解析 + 重试 + 并发控制
|
| 31 |
+
└── prompts.py # 6 套 LLM judge prompt 模板
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## 2. 运行流程
|
| 35 |
+
|
| 36 |
+
整个 eval 分三个阶段(`cli.py: run_eval()`):
|
| 37 |
+
|
| 38 |
+
### Stage 1 — Pipeline(串行,适配器有状态)
|
| 39 |
+
|
| 40 |
+
```
|
| 41 |
+
for each sample:
|
| 42 |
+
adapter = create_adapter(baseline_name)
|
| 43 |
+
adapter.reset()
|
| 44 |
+
for each session in sample.sessions:
|
| 45 |
+
for each turn in session.turns:
|
| 46 |
+
adapter.ingest_turn(turn) # 喂入一条对话
|
| 47 |
+
adapter.end_session(session_id) # 触发 session 后处理(如 GA 反思、RF 优化)
|
| 48 |
+
snapshot = adapter.snapshot_memories() # 拍快照
|
| 49 |
+
delta = adapter.export_memory_delta() # 导出本 session 增量
|
| 50 |
+
→ PipelineSessionRecord
|
| 51 |
+
|
| 52 |
+
# 当某个 checkpoint 的 covered_sessions 全部完成时触发 QA
|
| 53 |
+
for each question in checkpoint:
|
| 54 |
+
retrieval = adapter.retrieve(question, top_k=5)
|
| 55 |
+
answer = answer_fn(question, retrieval) # 可注入外部 LLM 回答
|
| 56 |
+
→ PipelineCheckpointQARecord
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Pipeline 结束后写入 checkpoint 文件 `pipeline_sessions.jsonl` + `pipeline_qa.jsonl`,支持 `--eval-only` 跳过此阶段直接从 checkpoint 恢复。
|
| 60 |
+
|
| 61 |
+
### Stage 2 — Eval(并行,ThreadPoolExecutor)
|
| 62 |
+
|
| 63 |
+
- **Session 评估**(`evaluators/extraction.py`)— 每个 session 4+ 次 LLM 调用:
|
| 64 |
+
1. **Recall**:本 session 的 gold points 中有多少被 delta 覆盖?
|
| 65 |
+
2. **Correctness**:每条 delta 记忆是 correct / hallucination / irrelevant?
|
| 66 |
+
3. **Update handling**:每个 update gold point → updated / both / outdated
|
| 67 |
+
4. **Interference rejection**:每个 interference gold point → rejected / memorized
|
| 68 |
+
|
| 69 |
+
- **QA 评估**(`evaluators/qa.py`)— 每个 question 2 次 LLM 调用:
|
| 70 |
+
1. **Answer 正确性**:Correct / Hallucination / Omission
|
| 71 |
+
2. **Evidence 覆盖率**:cited memories 覆盖了多少 gold evidence points
|
| 72 |
+
|
| 73 |
+
### Stage 3 — Aggregate
|
| 74 |
+
|
| 75 |
+
将所有 session 和 QA 级别的评估结果聚合为 6 个维度的 baseline 级指标:
|
| 76 |
+
|
| 77 |
+
| 维度 | 聚合方式 | 关键指标 |
|
| 78 |
+
|------|---------|---------|
|
| 79 |
+
| Memory Recall | 按 session 平均 | `avg_recall`, `avg_update_recall` |
|
| 80 |
+
| Memory Correctness | 按 session 平均 | `avg_correctness`, `avg_hallucination` |
|
| 81 |
+
| Update Handling | 跨 session 池化 | `score` (updated=1.0, both=0.5, outdated=0.0) |
|
| 82 |
+
| Interference Rejection | 跨 session 池化 | `score` (rejected/total) |
|
| 83 |
+
| Question Answering | 跨 question 池化 | `correct_ratio`, `hallucination_ratio`, `omission_ratio` |
|
| 84 |
+
| Evidence Coverage | 跨 question 池化 | `hit_rate` |
|
| 85 |
+
|
| 86 |
+
输出文件:
|
| 87 |
+
- `session_records.jsonl` — 每条含 pipeline 数据 + eval 结果
|
| 88 |
+
- `qa_records.jsonl` — 同上
|
| 89 |
+
- `aggregate_metrics.json` — baseline 级汇总
|
| 90 |
+
|
| 91 |
+
## 3. 支持的 Baselines
|
| 92 |
+
|
| 93 |
+
### 3.1 Mem-Gallery 内置(11 种)
|
| 94 |
+
|
| 95 |
+
通过 `MemGalleryNativeAdapter` 统一包装,需要在 `eval_framework/` 同级目录放置 `memengine/` 和 `default_config/`(从 Mem-Gallery 的 `benchmark/` 目录复制)。
|
| 96 |
+
|
| 97 |
+
| Baseline | 类型 | 特性 | 额外依赖 |
|
| 98 |
+
|----------|------|------|---------|
|
| 99 |
+
| `FUMemory` | text-only | 全量存储(FIFO 截断) | — |
|
| 100 |
+
| `STMemory` | text-only | 短期记忆 | — |
|
| 101 |
+
| `LTMemory` | text-only | 长期记忆,embedding 检索 | sentence-transformers |
|
| 102 |
+
| `GAMemory` | text-only | 带 importance judge + 自反思 | LLM API |
|
| 103 |
+
| `MGMemory` | text-only | 多层存储(working/FIFO/recall/archival) | LLM API, sentence-transformers |
|
| 104 |
+
| `RFMemory` | text-only | 带 reflection optimizer | LLM API |
|
| 105 |
+
| `MMMemory` | multimodal | 多模态记忆 | torch |
|
| 106 |
+
| `MMFUMemory` | multimodal | 多模态全量存储 | torch |
|
| 107 |
+
| `NGMemory` | multimodal | 知识图谱节点存储 | torch |
|
| 108 |
+
| `AUGUSTUSMemory` | multimodal | 概念抽取 + 图谱 | LLM API, torch |
|
| 109 |
+
| `UniversalRAGMemory` | multimodal | RAG routing + 存储 | LLM API |
|
| 110 |
+
|
| 111 |
+
### 3.2 外部适配器
|
| 112 |
+
|
| 113 |
+
| Baseline | 来源 | 安装方式 | 需要外部服务 |
|
| 114 |
+
|----------|------|---------|-------------|
|
| 115 |
+
| `Mem0` | [mem0ai/mem0](https://github.com/mem0ai/mem0) | `pip install mem0ai` | 否(内置 Qdrant + SQLite) |
|
| 116 |
+
| `Mem0-Graph` | 同上(graph 模式) | `pip install "mem0ai[graph]"` | 需要 Neo4j |
|
| 117 |
+
| `SimpleMem` | [aiming-lab/SimpleMem](https://github.com/aiming-lab/SimpleMem) | clone + requirements | 否 |
|
| 118 |
+
| `Omni-SimpleMem` | 同上(omni 模式) | 同上 | 否 |
|
| 119 |
+
| `Zep` | [getzep/zep](https://github.com/getzep/zep) | `pip install zep-python` | 需要 Zep server |
|
| 120 |
+
| `A-Mem` | [A-Mem](https://arxiv.org/abs/2504.19413) | clone 源码 | 否 |
|
| 121 |
+
| `MemoryOS` | [MemoryOS](https://github.com/memodb-io/memobase) | clone 源码 | 否 |
|
| 122 |
+
|
| 123 |
+
**论文来源:**
|
| 124 |
+
|
| 125 |
+
| Baseline | 论文 | GitHub |
|
| 126 |
+
|----------|------|--------|
|
| 127 |
+
| Mem0 / Mem0-Graph | [arXiv:2504.19413](https://arxiv.org/abs/2504.19413) | https://github.com/mem0ai/mem0 |
|
| 128 |
+
| SimpleMem | [arXiv:2601.02553](https://arxiv.org/abs/2601.02553) | https://github.com/aiming-lab/SimpleMem |
|
| 129 |
+
| Omni-SimpleMem | [arXiv:2604.01007](https://arxiv.org/abs/2604.01007) | https://github.com/aiming-lab/SimpleMem |
|
| 130 |
+
| MemVerse | [arXiv:2512.03627](https://arxiv.org/abs/2512.03627) | https://github.com/KnowledgeXLab/MemVerse |
|
| 131 |
+
| Memobase | — | https://github.com/memodb-io/memobase |
|
| 132 |
+
| Supermemory | — | https://github.com/supermemoryai/supermemory |
|
| 133 |
+
| Zep | [arXiv:2501.13956](https://arxiv.org/abs/2501.13956) | https://github.com/getzep/zep |
|
| 134 |
+
|
| 135 |
+
### 3.3 添加新 Baseline
|
| 136 |
+
|
| 137 |
+
实现 `MemoryAdapter` 的 7 个抽象方法:
|
| 138 |
+
|
| 139 |
+
```python
|
| 140 |
+
class MyAdapter(MemoryAdapter):
|
| 141 |
+
def reset(self) -> None: ...
|
| 142 |
+
def ingest_turn(self, turn: NormalizedTurn) -> None: ...
|
| 143 |
+
def end_session(self, session_id: str) -> None: ...
|
| 144 |
+
def snapshot_memories(self) -> list[MemorySnapshotRecord]: ...
|
| 145 |
+
def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]: ...
|
| 146 |
+
def retrieve(self, query: str, top_k: int) -> RetrievalRecord: ...
|
| 147 |
+
def get_capabilities(self) -> dict[str, Any]: ...
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
然后在 `registry.py` 的 `EXTERNAL_ADAPTER_REGISTRY` 中注册。
|
| 151 |
+
|
| 152 |
+
## 4. 数据适配
|
| 153 |
+
|
| 154 |
+
### 4.1 数据集格式(domain_a_v2)
|
| 155 |
+
|
| 156 |
+
加载器 `load_domain_a_v2_academic(data_dir)` 要求 `data_dir` 下有三个文件:
|
| 157 |
+
|
| 158 |
+
```
|
| 159 |
+
data_dir/
|
| 160 |
+
├── domain_a_v2.json # 主对话数据(JSON array)
|
| 161 |
+
├── stage4_memory_points.jsonl # 每 session 的 gold memory points
|
| 162 |
+
└── stage4b_qa_checkpoints.jsonl # checkpoint QA 题目
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
**`domain_a_v2.json`** 中每个 sample 结构:
|
| 166 |
+
|
| 167 |
+
```json
|
| 168 |
+
{
|
| 169 |
+
"uuid": "unique-id",
|
| 170 |
+
"sample_id": "sample_001",
|
| 171 |
+
"sessions": [
|
| 172 |
+
{
|
| 173 |
+
"_v2_session_id": "S00",
|
| 174 |
+
"dialogue": [
|
| 175 |
+
{
|
| 176 |
+
"role": "user",
|
| 177 |
+
"content": "Hello...",
|
| 178 |
+
"timestamp": "2025-01-01T10:00:00",
|
| 179 |
+
"attachments": [{"caption": "photo of...", "type": "image_caption"}]
|
| 180 |
+
},
|
| 181 |
+
{"role": "assistant", "content": "Hi..."}
|
| 182 |
+
],
|
| 183 |
+
"memory_points": [...] // 仅 S00 需要
|
| 184 |
+
},
|
| 185 |
+
{"_v2_session_id": "S01", "dialogue": [...]}
|
| 186 |
+
]
|
| 187 |
+
}
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
**`stage4_memory_points.jsonl`** 每行一个 sample:
|
| 191 |
+
|
| 192 |
+
```json
|
| 193 |
+
{
|
| 194 |
+
"uuid": "...", "sample_id": "sample_001",
|
| 195 |
+
"memory_sessions": [
|
| 196 |
+
{
|
| 197 |
+
"session_id": "S01",
|
| 198 |
+
"memory_points": [
|
| 199 |
+
{
|
| 200 |
+
"memory_id": "m001",
|
| 201 |
+
"memory_content": "User prefers dark mode",
|
| 202 |
+
"memory_type": "preference",
|
| 203 |
+
"memory_source": "normal",
|
| 204 |
+
"is_update": false,
|
| 205 |
+
"original_memories": [],
|
| 206 |
+
"importance": 0.8
|
| 207 |
+
}
|
| 208 |
+
]
|
| 209 |
+
}
|
| 210 |
+
]
|
| 211 |
+
}
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
**`stage4b_qa_checkpoints.jsonl`** 每行一个 sample:
|
| 215 |
+
|
| 216 |
+
```json
|
| 217 |
+
{
|
| 218 |
+
"uuid": "...", "sample_id": "sample_001",
|
| 219 |
+
"checkpoints": [
|
| 220 |
+
{
|
| 221 |
+
"checkpoint_id": "cp01",
|
| 222 |
+
"covered_sessions": ["S00", "S01"],
|
| 223 |
+
"questions": [
|
| 224 |
+
{
|
| 225 |
+
"question": "What theme does the user prefer?",
|
| 226 |
+
"answer": "Dark mode",
|
| 227 |
+
"question_type": "preference_recall",
|
| 228 |
+
"question_type_abbrev": "pref",
|
| 229 |
+
"difficulty": "easy",
|
| 230 |
+
"evidence": [{"memory_id": "m001"}]
|
| 231 |
+
}
|
| 232 |
+
]
|
| 233 |
+
}
|
| 234 |
+
]
|
| 235 |
+
}
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
### 4.2 适配自有数据
|
| 239 |
+
|
| 240 |
+
若要接入新数据源,有两条路径:
|
| 241 |
+
|
| 242 |
+
**路径 A:转换为 domain_a_v2 格式**(推荐)
|
| 243 |
+
- 将原始对话整理为上述三文件格式
|
| 244 |
+
- 直接使用现有 CLI 运行
|
| 245 |
+
|
| 246 |
+
**路径 B:编写新的 dataset loader**
|
| 247 |
+
- 在 `datasets/` 下新建加载器,返回 `DomainAV2AcademicBundle`(或等价结构)
|
| 248 |
+
- 在 `cli.py` 的 `run_eval()` 中通过 `load_domain_bundle` 参数注入
|
| 249 |
+
|
| 250 |
+
### 4.3 关键数据结构
|
| 251 |
+
|
| 252 |
+
每条对话 turn 会被归一化为 `NormalizedTurn`:
|
| 253 |
+
|
| 254 |
+
```python
|
| 255 |
+
NormalizedTurn(
|
| 256 |
+
sample_id="sample_001",
|
| 257 |
+
session_id="S01",
|
| 258 |
+
turn_index=0,
|
| 259 |
+
role="user", # "user" | "assistant"
|
| 260 |
+
text="Hello...",
|
| 261 |
+
attachments=(Attachment(caption="...", type="image_caption"),),
|
| 262 |
+
timestamp="2025-01-01T10:00:00",
|
| 263 |
+
)
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
Memory 的 gold 标注支持三种来源标记:
|
| 267 |
+
- `normal` — 正常记忆点
|
| 268 |
+
- `interference` — 干扰信息(不应被记忆)
|
| 269 |
+
- `is_update=True` — 更新型记忆(应替换旧记忆)
|
| 270 |
+
|
| 271 |
+
## 5. 环境配置(uv)
|
| 272 |
+
|
| 273 |
+
### 5.1 安装 uv
|
| 274 |
+
|
| 275 |
+
```bash
|
| 276 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
### 5.2 初始化项目环境
|
| 280 |
+
|
| 281 |
+
```bash
|
| 282 |
+
cd /data1/toby/nips26
|
| 283 |
+
|
| 284 |
+
# 创建虚拟环境
|
| 285 |
+
uv venv .venv --python 3.11
|
| 286 |
+
source .venv/bin/activate
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
### 5.3 安装核心依赖
|
| 290 |
+
|
| 291 |
+
```bash
|
| 292 |
+
# 最小依赖(可跑 FUMemory/STMemory 等纯文本 baseline)
|
| 293 |
+
uv pip install openai tenacity
|
| 294 |
+
|
| 295 |
+
# embedding 检索类 baseline(LTMemory, GAMemory, MGMemory 等)
|
| 296 |
+
uv pip install sentence-transformers
|
| 297 |
+
|
| 298 |
+
# 多模态 baseline(MMMemory, NGMemory, AUGUSTUSMemory 等)
|
| 299 |
+
uv pip install torch torchvision transformers
|
| 300 |
+
|
| 301 |
+
# 外部 baseline(A-Mem, MemoryOS)— 按各自文档安装额外依赖
|
| 302 |
+
# A-Mem 需要其源码目录下的 requirements
|
| 303 |
+
# MemoryOS 需要 memoryos 包
|
| 304 |
+
```
|
| 305 |
+
|
| 306 |
+
### 5.4 环境变量(.env 文件)
|
| 307 |
+
|
| 308 |
+
在项目根目录 (`nips26/`) 创建 `.env` 文件,框架会自动加载:
|
| 309 |
+
|
| 310 |
+
```bash
|
| 311 |
+
# .env
|
| 312 |
+
# 必需 — LLM API(pipeline 答题 + judge 评估统一使用)
|
| 313 |
+
OPENAI_API_KEY=sk-...
|
| 314 |
+
OPENAI_BASE_URL=https://api.openai.com/v1 # 或兼容端点
|
| 315 |
+
OPENAI_MODEL=gpt-4o
|
| 316 |
+
|
| 317 |
+
# 可选
|
| 318 |
+
OPENAI_TEMPERATURE=0.0
|
| 319 |
+
OPENAI_MAX_TOKENS=1024
|
| 320 |
+
OPENAI_TIMEOUT=120
|
| 321 |
+
JUDGE_TEMPERATURE=0.0 # judge 专用温度
|
| 322 |
+
LLM_MAX_CONCURRENT=5 # LLM 并发上限
|
| 323 |
+
```
|
| 324 |
+
|
| 325 |
+
### 5.5 Mem-Gallery 本地依赖
|
| 326 |
+
|
| 327 |
+
Mem-Gallery 内置 baseline 需要将其源码放到 `eval_framework/` 的同级目录:
|
| 328 |
+
|
| 329 |
+
```bash
|
| 330 |
+
# 假设 Mem-Gallery repo 在 /path/to/Mem-Gallery
|
| 331 |
+
cp -r /path/to/Mem-Gallery/benchmark/memengine /data1/toby/nips26/
|
| 332 |
+
cp -r /path/to/Mem-Gallery/benchmark/default_config /data1/toby/nips26/
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
最终目录结构应为:
|
| 336 |
+
|
| 337 |
+
```
|
| 338 |
+
nips26/
|
| 339 |
+
├── eval_framework/
|
| 340 |
+
├── memengine/ # Mem-Gallery 记忆引擎
|
| 341 |
+
└── default_config/ # Mem-Gallery 默认配置
|
| 342 |
+
```
|
| 343 |
+
|
| 344 |
+
## 6. 运行示例
|
| 345 |
+
|
| 346 |
+
### 基本运行
|
| 347 |
+
|
| 348 |
+
```bash
|
| 349 |
+
# 运行单个 baseline
|
| 350 |
+
python -m eval_framework.cli \
|
| 351 |
+
--dataset /path/to/domain_a_v2_data/ \
|
| 352 |
+
--baseline FUMemory \
|
| 353 |
+
--output-dir eval_framework/results/FUMemory
|
| 354 |
+
|
| 355 |
+
# smoke 模式(只跑第 1 个 sample,快速验证)
|
| 356 |
+
python -m eval_framework.cli \
|
| 357 |
+
--dataset /path/to/domain_a_v2_data/ \
|
| 358 |
+
--baseline FUMemory \
|
| 359 |
+
--output-dir eval_framework/results/FUMemory_smoke \
|
| 360 |
+
--smoke
|
| 361 |
+
|
| 362 |
+
# dry-run(不实际运行,打印配置)
|
| 363 |
+
python -m eval_framework.cli \
|
| 364 |
+
--dataset /path/to/domain_a_v2_data/ \
|
| 365 |
+
--baseline FUMemory \
|
| 366 |
+
--dry-run
|
| 367 |
+
|
| 368 |
+
# 仅重跑 eval 阶段(从 checkpoint 恢复,pipeline 不重跑)
|
| 369 |
+
python -m eval_framework.cli \
|
| 370 |
+
--dataset /path/to/domain_a_v2_data/ \
|
| 371 |
+
--baseline FUMemory \
|
| 372 |
+
--output-dir eval_framework/results/FUMemory \
|
| 373 |
+
--eval-only
|
| 374 |
+
|
| 375 |
+
# 调整 eval 并发数
|
| 376 |
+
python -m eval_framework.cli \
|
| 377 |
+
--dataset /path/to/domain_a_v2_data/ \
|
| 378 |
+
--baseline MGMemory \
|
| 379 |
+
--output-dir eval_framework/results/MGMemory \
|
| 380 |
+
--max-eval-workers 10
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
### 批量跑所有 baseline
|
| 384 |
+
|
| 385 |
+
```bash
|
| 386 |
+
DATASET="/path/to/domain_a_v2_data"
|
| 387 |
+
for baseline in FUMemory STMemory LTMemory GAMemory MGMemory RFMemory A-Mem MemoryOS; do
|
| 388 |
+
echo "=== Running $baseline ==="
|
| 389 |
+
python -m eval_framework.cli \
|
| 390 |
+
--dataset "$DATASET" \
|
| 391 |
+
--baseline "$baseline" \
|
| 392 |
+
--output-dir "eval_framework/results/$baseline"
|
| 393 |
+
done
|
| 394 |
+
```
|
| 395 |
+
|
| 396 |
+
### 输出文件说明
|
| 397 |
+
|
| 398 |
+
运行完成后 `output-dir` 下包含:
|
| 399 |
+
|
| 400 |
+
```
|
| 401 |
+
results/FUMemory/
|
| 402 |
+
├── pipeline_sessions.jsonl # Stage 1 checkpoint — session 级 pipeline 结果
|
| 403 |
+
├── pipeline_qa.jsonl # Stage 1 checkpoint — QA 级 pipeline 结果
|
| 404 |
+
├── session_records.jsonl # 最终 session 结果(含 eval)
|
| 405 |
+
├── qa_records.jsonl # 最终 QA 结果(含 eval)
|
| 406 |
+
└── aggregate_metrics.json # baseline 级汇总指标
|
| 407 |
+
```
|
| 408 |
+
|
| 409 |
+
## 7. LLM API 开销估算
|
| 410 |
+
|
| 411 |
+
每个 sample 的 LLM 调用量:
|
| 412 |
+
|
| 413 |
+
| 来源 | 调用次数 |
|
| 414 |
+
|------|---------|
|
| 415 |
+
| Pipeline answer(每个 QA question) | N_questions |
|
| 416 |
+
| Session Recall judge | N_sessions |
|
| 417 |
+
| Session Correctness judge | N_sessions |
|
| 418 |
+
| Update judge | N_update_points(逐条) |
|
| 419 |
+
| Interference judge | N_interference_points(逐条) |
|
| 420 |
+
| QA Answer judge | N_questions |
|
| 421 |
+
| QA Evidence judge | N_questions |
|
| 422 |
+
|
| 423 |
+
典型场景下一个 sample 约 20-50 次 LLM 调用。通过 `LLM_MAX_CONCURRENT` 控制并发避免 rate limit。
|
docs/OUTPUT_FORMAT.md
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Eval Framework 输出格式
|
| 2 |
+
|
| 3 |
+
## 输出目录结构
|
| 4 |
+
|
| 5 |
+
运行完成后 `--output-dir` 下包含 5 个文件:
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
output-dir/
|
| 9 |
+
├── pipeline_sessions.jsonl # Stage 1 checkpoint — pipeline 中间结果(session 级)
|
| 10 |
+
├── pipeline_qa.jsonl # Stage 1 checkpoint — pipeline 中间结果(QA 级)
|
| 11 |
+
├── session_records.jsonl # 最终结果:session pipeline 数据 + eval 评判
|
| 12 |
+
├── qa_records.jsonl # 最终结果:QA pipeline 数据 + eval 评判
|
| 13 |
+
└── aggregate_metrics.json # 最终结果:baseline 级别汇总指标
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## 文件详解
|
| 17 |
+
|
| 18 |
+
### 1. `session_records.jsonl`
|
| 19 |
+
|
| 20 |
+
每行一个 session,包含 pipeline 原始数据和 `eval` 评判结果:
|
| 21 |
+
|
| 22 |
+
```json
|
| 23 |
+
{
|
| 24 |
+
"sample_id": "vab_minecraft_...",
|
| 25 |
+
"sample_uuid": "uuid-...",
|
| 26 |
+
"session_id": "S01",
|
| 27 |
+
"memory_snapshot": [
|
| 28 |
+
{
|
| 29 |
+
"memory_id": "3",
|
| 30 |
+
"text": "user: OBSERVATION: ...\nassistant: THOUGHT: ...",
|
| 31 |
+
"session_id": "S01",
|
| 32 |
+
"status": "active",
|
| 33 |
+
"source": "FUMemory",
|
| 34 |
+
"raw_backend_id": "3",
|
| 35 |
+
"raw_backend_type": "linear",
|
| 36 |
+
"metadata": {}
|
| 37 |
+
}
|
| 38 |
+
],
|
| 39 |
+
"memory_delta": [
|
| 40 |
+
{
|
| 41 |
+
"session_id": "S01",
|
| 42 |
+
"op": "add",
|
| 43 |
+
"text": "user: OBSERVATION: ...",
|
| 44 |
+
"linked_previous": [],
|
| 45 |
+
"raw_backend_id": "3",
|
| 46 |
+
"metadata": {"baseline": "FUMemory"}
|
| 47 |
+
}
|
| 48 |
+
],
|
| 49 |
+
"gold_state": {
|
| 50 |
+
"session_id": "S01",
|
| 51 |
+
"cumulative_gold_memories": [...],
|
| 52 |
+
"session_new_memories": [...],
|
| 53 |
+
"session_update_memories": [...],
|
| 54 |
+
"session_interference_memories": []
|
| 55 |
+
},
|
| 56 |
+
"eval": {
|
| 57 |
+
"session_id": "S01",
|
| 58 |
+
"recall": 0.8,
|
| 59 |
+
"covered_count": 4,
|
| 60 |
+
"num_gold": 5,
|
| 61 |
+
"update_recall": 1.0,
|
| 62 |
+
"update_covered_count": 2,
|
| 63 |
+
"update_total": 2,
|
| 64 |
+
"recall_reasoning": "4 of 5 gold points are covered...",
|
| 65 |
+
"correctness_rate": 0.75,
|
| 66 |
+
"num_memories": 8,
|
| 67 |
+
"num_correct": 6,
|
| 68 |
+
"num_hallucination": 1,
|
| 69 |
+
"num_irrelevant": 1,
|
| 70 |
+
"correctness_reasoning": "...",
|
| 71 |
+
"correctness_records": [
|
| 72 |
+
{"id": 1, "label": "correct"},
|
| 73 |
+
{"id": 2, "label": "hallucination"}
|
| 74 |
+
],
|
| 75 |
+
"update_score": 1.0,
|
| 76 |
+
"update_num_updated": 2,
|
| 77 |
+
"update_num_both": 0,
|
| 78 |
+
"update_num_outdated": 0,
|
| 79 |
+
"update_total_items": 2,
|
| 80 |
+
"update_records": [
|
| 81 |
+
{"memory_id": "mp_S08_3", "label": "updated", "reasoning": "..."}
|
| 82 |
+
],
|
| 83 |
+
"interference_score": null,
|
| 84 |
+
"interference_num_rejected": 0,
|
| 85 |
+
"interference_num_memorized": 0,
|
| 86 |
+
"interference_total_items": 0,
|
| 87 |
+
"interference_records": []
|
| 88 |
+
}
|
| 89 |
+
}
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
**eval 字段说明:**
|
| 93 |
+
|
| 94 |
+
| 字段 | 含义 |
|
| 95 |
+
|------|------|
|
| 96 |
+
| `recall` | 本 session gold points 被 delta 覆盖的比例 (0-1) |
|
| 97 |
+
| `update_recall` | update 类型 gold points 的覆盖比例 |
|
| 98 |
+
| `correctness_rate` | delta 中正确记忆的比例 |
|
| 99 |
+
| `num_hallucination` | delta 中幻觉记忆数量 |
|
| 100 |
+
| `num_irrelevant` | delta 中无关记忆数量 |
|
| 101 |
+
| `update_score` | 更新处理得分 (updated=1.0, both=0.5, outdated=0.0) |
|
| 102 |
+
| `interference_score` | 干扰拒绝得分 (rejected=1.0, memorized=0.0) |
|
| 103 |
+
|
| 104 |
+
### 2. `qa_records.jsonl`
|
| 105 |
+
|
| 106 |
+
每行一个 QA question,包含检索结果、模型回答和评判:
|
| 107 |
+
|
| 108 |
+
```json
|
| 109 |
+
{
|
| 110 |
+
"sample_id": "vab_minecraft_...",
|
| 111 |
+
"sample_uuid": "uuid-...",
|
| 112 |
+
"checkpoint_id": "probe_e980c238",
|
| 113 |
+
"question": "What was in the agent's inventory at step 1?",
|
| 114 |
+
"gold_answer": "At step 1, the agent's inventory was empty.",
|
| 115 |
+
"gold_evidence_memory_ids": ["mp_S04_1"],
|
| 116 |
+
"gold_evidence_contents": ["The agent started with empty inventory"],
|
| 117 |
+
"question_type": "factual_recall",
|
| 118 |
+
"question_type_abbrev": "FR",
|
| 119 |
+
"difficulty": "easy",
|
| 120 |
+
"retrieval": {
|
| 121 |
+
"query": "What was in the agent's inventory at step 1?",
|
| 122 |
+
"top_k": 5,
|
| 123 |
+
"items": [
|
| 124 |
+
{
|
| 125 |
+
"rank": 0,
|
| 126 |
+
"memory_id": "memgallery:string_bundle",
|
| 127 |
+
"text": "user: OBSERVATION: Your Inventory: ...",
|
| 128 |
+
"score": 1.0,
|
| 129 |
+
"raw_backend_id": null
|
| 130 |
+
}
|
| 131 |
+
],
|
| 132 |
+
"raw_trace": {"baseline": "FUMemory"}
|
| 133 |
+
},
|
| 134 |
+
"generated_answer": "The agent's inventory was empty at step 1.",
|
| 135 |
+
"cited_memories": ["user: OBSERVATION: Inventory: nothing"],
|
| 136 |
+
"eval": {
|
| 137 |
+
"answer_label": "Correct",
|
| 138 |
+
"answer_reasoning": "The response matches the reference answer...",
|
| 139 |
+
"answer_is_valid": true,
|
| 140 |
+
"evidence_hit_rate": 1.0,
|
| 141 |
+
"evidence_covered_count": 1,
|
| 142 |
+
"num_evidence": 1,
|
| 143 |
+
"evidence_reasoning": "The cited memory covers the gold evidence...",
|
| 144 |
+
"num_cited_memories": 1
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
**eval 字段说明:**
|
| 150 |
+
|
| 151 |
+
| 字段 | 含义 |
|
| 152 |
+
|------|------|
|
| 153 |
+
| `answer_label` | `Correct` / `Hallucination` / `Omission` |
|
| 154 |
+
| `answer_is_valid` | 评判是否成功(非 LLM 错误) |
|
| 155 |
+
| `evidence_hit_rate` | cited memories 覆盖了多少 gold evidence (0-1) |
|
| 156 |
+
| `evidence_covered_count` | 被覆盖的 gold evidence 数量 |
|
| 157 |
+
| `num_cited_memories` | 模型回答时引用的记忆条数 |
|
| 158 |
+
|
| 159 |
+
### 3. `aggregate_metrics.json`
|
| 160 |
+
|
| 161 |
+
baseline 级别的 6 维汇总指标:
|
| 162 |
+
|
| 163 |
+
```json
|
| 164 |
+
{
|
| 165 |
+
"baseline_id": "FUMemory",
|
| 166 |
+
"memory_recall": {
|
| 167 |
+
"avg_recall": 0.72,
|
| 168 |
+
"avg_update_recall": 0.65,
|
| 169 |
+
"num_sessions_with_recall": 110,
|
| 170 |
+
"num_sessions_with_update": 85,
|
| 171 |
+
"total_covered": 320,
|
| 172 |
+
"total_gold": 445
|
| 173 |
+
},
|
| 174 |
+
"memory_correctness": {
|
| 175 |
+
"avg_correctness": 0.81,
|
| 176 |
+
"avg_hallucination": 0.08,
|
| 177 |
+
"avg_irrelevant": 0.11,
|
| 178 |
+
"num_sessions": 110,
|
| 179 |
+
"total_memories": 1200,
|
| 180 |
+
"total_correct": 972,
|
| 181 |
+
"total_hallucination": 96,
|
| 182 |
+
"total_irrelevant": 132
|
| 183 |
+
},
|
| 184 |
+
"update_handling": {
|
| 185 |
+
"score": 0.65,
|
| 186 |
+
"num_updated": 52,
|
| 187 |
+
"num_both": 18,
|
| 188 |
+
"num_outdated": 15,
|
| 189 |
+
"num_total": 85
|
| 190 |
+
},
|
| 191 |
+
"interference_rejection": {
|
| 192 |
+
"score": 0.0,
|
| 193 |
+
"num_rejected": 0,
|
| 194 |
+
"num_memorized": 0,
|
| 195 |
+
"num_total": 0
|
| 196 |
+
},
|
| 197 |
+
"question_answering": {
|
| 198 |
+
"correct_ratio": 0.58,
|
| 199 |
+
"hallucination_ratio": 0.22,
|
| 200 |
+
"omission_ratio": 0.20,
|
| 201 |
+
"num_total": 990,
|
| 202 |
+
"num_valid": 990
|
| 203 |
+
},
|
| 204 |
+
"evidence_coverage": {
|
| 205 |
+
"hit_rate": 0.43,
|
| 206 |
+
"num_covered": 425,
|
| 207 |
+
"num_total": 990
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
**6 个维度:**
|
| 213 |
+
|
| 214 |
+
| 维度 | 聚合方式 | 核心指标 | 方向 |
|
| 215 |
+
|------|---------|---------|------|
|
| 216 |
+
| Memory Recall | 按 session 平均 | `avg_recall` | ↑ |
|
| 217 |
+
| Memory Correctness | 按 session 平均 | `avg_correctness`, `avg_hallucination` | ↑, ↓ |
|
| 218 |
+
| Update Handling | 跨 session 池化 | `score` | ↑ |
|
| 219 |
+
| Interference Rejection | 跨 session 池化 | `score` | ↑ |
|
| 220 |
+
| Question Answering | 跨 question 池化 | `correct_ratio`, `hallucination_ratio` | ↑, ↓ |
|
| 221 |
+
| Evidence Coverage | 跨 question 池化 | `hit_rate` | ↑ |
|
| 222 |
+
|
| 223 |
+
### 4. `pipeline_sessions.jsonl` / `pipeline_qa.jsonl`
|
| 224 |
+
|
| 225 |
+
Stage 1 的 checkpoint 文件,结构与 `session_records.jsonl` / `qa_records.jsonl` 相同但**不含 `eval` 字段**。
|
| 226 |
+
|
| 227 |
+
用途:`--eval-only` 模式跳过 pipeline 直接从 checkpoint 恢复,只重跑 eval 阶段。典型场景:
|
| 228 |
+
|
| 229 |
+
```bash
|
| 230 |
+
# 首次完整运行
|
| 231 |
+
python -m eval_framework.cli --dataset ... --baseline FUMemory --output-dir results/FU
|
| 232 |
+
|
| 233 |
+
# 换 judge 模型重评(不重跑 pipeline)
|
| 234 |
+
OPENAI_MODEL=gpt-4o-mini python -m eval_framework.cli \
|
| 235 |
+
--dataset ... --baseline FUMemory --output-dir results/FU --eval-only
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## 结果分析示例
|
| 239 |
+
|
| 240 |
+
```python
|
| 241 |
+
import json
|
| 242 |
+
|
| 243 |
+
# 读取汇总
|
| 244 |
+
with open("results/FUMemory/aggregate_metrics.json") as f:
|
| 245 |
+
agg = json.load(f)
|
| 246 |
+
print(f"Recall: {agg['memory_recall']['avg_recall']:.2%}")
|
| 247 |
+
print(f"QA Correct: {agg['question_answering']['correct_ratio']:.2%}")
|
| 248 |
+
|
| 249 |
+
# 按 QA type 分析正确率
|
| 250 |
+
qa_by_type = {}
|
| 251 |
+
with open("results/FUMemory/qa_records.jsonl") as f:
|
| 252 |
+
for line in f:
|
| 253 |
+
rec = json.loads(line)
|
| 254 |
+
qt = rec["question_type_abbrev"]
|
| 255 |
+
label = rec["eval"]["answer_label"]
|
| 256 |
+
qa_by_type.setdefault(qt, []).append(label)
|
| 257 |
+
|
| 258 |
+
for qt, labels in sorted(qa_by_type.items()):
|
| 259 |
+
correct = sum(1 for l in labels if l == "Correct")
|
| 260 |
+
print(f" {qt}: {correct}/{len(labels)} = {correct/len(labels):.0%}")
|
| 261 |
+
```
|
evaluators/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Session- and checkpoint-level evaluators using batch LLM judge."""
|
| 2 |
+
|
| 3 |
+
from eval_framework.evaluators.aggregate import aggregate_metrics
|
| 4 |
+
from eval_framework.evaluators.extraction import evaluate_extraction
|
| 5 |
+
from eval_framework.evaluators.qa import evaluate_checkpoint_qa
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"aggregate_metrics",
|
| 9 |
+
"evaluate_checkpoint_qa",
|
| 10 |
+
"evaluate_extraction",
|
| 11 |
+
]
|
evaluators/aggregate.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Roll up per-session and per-QA evaluations into baseline-level summaries.
|
| 2 |
+
|
| 3 |
+
Recall & correctness: per-session average (not pooled cumulative).
|
| 4 |
+
Interference: pooled across sessions.
|
| 5 |
+
QA & evidence: pooled across questions.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from collections.abc import Mapping, Sequence
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _safe_div(a: float, b: float) -> float:
|
| 14 |
+
return a / b if b else 0.0
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def aggregate_metrics(
    baseline_id: str,
    *,
    session_evaluations: Sequence[Mapping[str, object]] = (),
    qa_evaluations: Sequence[Mapping[str, object]] = (),
) -> dict[str, object]:
    """Aggregate all per-session and per-QA evaluations.

    Recall and correctness are averaged per session; update handling,
    interference rejection, QA answer labels, and evidence coverage are
    pooled across their individual items.
    """

    def ratio(num: float, den: float) -> float:
        # Safe division: 0.0 whenever the denominator is zero/falsy.
        return num / den if den else 0.0

    # Per-session scores, averaged at the end.
    recall_vals: list[float] = []
    update_recall_vals: list[float] = []
    correctness_vals: list[float] = []
    hallucination_vals: list[float] = []
    irrelevant_vals: list[float] = []

    # Pooled update-handling counters.
    pooled_updated = 0
    pooled_both = 0
    pooled_outdated = 0
    pooled_update_items = 0

    # Pooled interference counters.
    pooled_rejected = 0
    pooled_memorized = 0
    pooled_interf_items = 0

    # Detail counters kept for reference in the output.
    sum_gold = 0
    sum_covered = 0
    sum_memories = 0
    sum_correct = 0
    sum_hallucination = 0
    sum_irrelevant = 0

    for sess in session_evaluations:
        # Recall: one score per session (None means "no gold points").
        recall_val = sess.get("recall")
        if recall_val is not None:
            recall_vals.append(float(recall_val))

        upd_val = sess.get("update_recall")
        if upd_val is not None:
            update_recall_vals.append(float(upd_val))

        # Correctness: one score per session.
        corr_val = sess.get("correctness_rate")
        if corr_val is not None:
            correctness_vals.append(float(corr_val))

        mem_count = int(sess.get("num_memories", 0))
        if mem_count > 0:
            hallucination_vals.append(float(sess.get("num_hallucination", 0)) / mem_count)
            irrelevant_vals.append(float(sess.get("num_irrelevant", 0)) / mem_count)

        # Reference-only detail counters.
        covered_val = sess.get("covered_count")
        if covered_val is not None:
            sum_covered += int(covered_val)
            sum_gold += int(sess.get("num_gold", 0))
        sum_memories += mem_count
        sum_correct += int(sess.get("num_correct", 0))
        sum_hallucination += int(sess.get("num_hallucination", 0))
        sum_irrelevant += int(sess.get("num_irrelevant", 0))

        # Update handling: pooled across sessions.
        pooled_updated += int(sess.get("update_num_updated", 0))
        pooled_both += int(sess.get("update_num_both", 0))
        pooled_outdated += int(sess.get("update_num_outdated", 0))
        pooled_update_items += int(sess.get("update_total_items", 0))

        # Interference rejection: pooled across sessions.
        pooled_rejected += int(sess.get("interference_num_rejected", 0))
        pooled_memorized += int(sess.get("interference_num_memorized", 0))
        pooled_interf_items += int(sess.get("interference_total_items", 0))

    # --- QA: pooled across questions ---
    label_counts = {"Correct": 0, "Hallucination": 0, "Omission": 0}
    qa_total = 0
    ev_covered = 0
    ev_total = 0

    for qa in qa_evaluations:
        qa_total += 1
        label = qa.get("answer_label")
        if label in ("Correct", "Hallucination", "Omission"):
            label_counts[label] += 1

        cov = qa.get("evidence_covered_count")
        if cov is not None:
            ev_covered += int(cov)
            ev_total += int(qa.get("num_evidence", 0))

    # Only answers with a recognized label count as "valid" judgements.
    qa_valid = sum(label_counts.values())

    return {
        "baseline_id": baseline_id,
        "memory_recall": {
            "avg_recall": ratio(sum(recall_vals), len(recall_vals)),
            "avg_update_recall": ratio(sum(update_recall_vals), len(update_recall_vals)),
            "num_sessions_with_recall": len(recall_vals),
            "num_sessions_with_update": len(update_recall_vals),
            "total_covered": sum_covered,
            "total_gold": sum_gold,
        },
        "memory_correctness": {
            "avg_correctness": ratio(sum(correctness_vals), len(correctness_vals)),
            "avg_hallucination": ratio(sum(hallucination_vals), len(hallucination_vals)),
            "avg_irrelevant": ratio(sum(irrelevant_vals), len(irrelevant_vals)),
            "num_sessions": len(correctness_vals),
            "total_memories": sum_memories,
            "total_correct": sum_correct,
            "total_hallucination": sum_hallucination,
            "total_irrelevant": sum_irrelevant,
        },
        "update_handling": {
            # Scoring rule: updated=1.0, both=0.5, outdated=0.0.
            "score": ratio(pooled_updated * 1.0 + pooled_both * 0.5, pooled_update_items),
            "num_updated": pooled_updated,
            "num_both": pooled_both,
            "num_outdated": pooled_outdated,
            "num_total": pooled_update_items,
        },
        "interference_rejection": {
            "score": ratio(pooled_rejected, pooled_interf_items),
            "num_rejected": pooled_rejected,
            "num_memorized": pooled_memorized,
            "num_total": pooled_interf_items,
        },
        "question_answering": {
            "correct_ratio": ratio(label_counts["Correct"], qa_valid),
            "hallucination_ratio": ratio(label_counts["Hallucination"], qa_valid),
            "omission_ratio": ratio(label_counts["Omission"], qa_valid),
            "num_total": qa_total,
            "num_valid": qa_valid,
        },
        "evidence_coverage": {
            "hit_rate": ratio(ev_covered, ev_total),
            "num_covered": ev_covered,
            "num_total": ev_total,
        },
    }
|
evaluators/extraction.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unified session evaluation: recall + correctness (includes update & interference).
|
| 2 |
+
|
| 3 |
+
Per session, 2 LLM calls — both scoped to THIS SESSION's memory delta only:
|
| 4 |
+
Call 1 — Recall: how many of this session's gold points are covered by the
|
| 5 |
+
session's memory delta (add/update ops)?
|
| 6 |
+
Call 2 — Correctness: is each delta memory correct, hallucinated, or irrelevant?
|
| 7 |
+
(reference = this session's gold points + interference)
|
| 8 |
+
|
| 9 |
+
Aggregate: per-session recall/correctness averaged across sessions.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from eval_framework.judges import (
|
| 15 |
+
evaluate_correctness_batch,
|
| 16 |
+
evaluate_interference_single,
|
| 17 |
+
evaluate_recall_batch,
|
| 18 |
+
evaluate_update_single,
|
| 19 |
+
)
|
| 20 |
+
from eval_framework.pipeline.records import PipelineSessionRecord
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _delta_to_text(session: PipelineSessionRecord) -> str:
|
| 24 |
+
"""Only the memories added or updated in THIS session (not the full snapshot)."""
|
| 25 |
+
lines: list[str] = []
|
| 26 |
+
idx = 0
|
| 27 |
+
for d in session.memory_delta:
|
| 28 |
+
if d.op in ("add", "update"):
|
| 29 |
+
idx += 1
|
| 30 |
+
lines.append(f"[{idx}] {d.text}")
|
| 31 |
+
return "\n".join(lines)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _delta_texts(session: PipelineSessionRecord) -> list[str]:
|
| 35 |
+
"""Text list of memories added or updated in THIS session."""
|
| 36 |
+
return [d.text for d in session.memory_delta if d.op in ("add", "update")]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _build_recall_gold_points(session: PipelineSessionRecord) -> list[str]:
|
| 40 |
+
"""Current session's new + update gold points only (NOT cumulative)."""
|
| 41 |
+
out: list[str] = []
|
| 42 |
+
for g in session.gold_state.session_new_memories:
|
| 43 |
+
out.append(f"[normal] {g.memory_content}")
|
| 44 |
+
for g in session.gold_state.session_update_memories:
|
| 45 |
+
out.append(f"[update] {g.memory_content}")
|
| 46 |
+
return out
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _build_correctness_gold_points(session: PipelineSessionRecord) -> list[str]:
|
| 50 |
+
"""Current session's new + update + interference gold points as reference."""
|
| 51 |
+
out: list[str] = []
|
| 52 |
+
for g in session.gold_state.session_new_memories:
|
| 53 |
+
out.append(f"[normal] {g.memory_content}")
|
| 54 |
+
for g in session.gold_state.session_update_memories:
|
| 55 |
+
out.append(f"[update] {g.memory_content}")
|
| 56 |
+
for g in session.gold_state.session_interference_memories:
|
| 57 |
+
out.append(f"[interference] {g.memory_content}")
|
| 58 |
+
return out
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def evaluate_extraction(
    session: PipelineSessionRecord,
    **_kwargs: object,
) -> dict[str, object]:
    """Unified session evaluation: recall, correctness, update, interference.

    Scope is THIS session only: the memory delta (add/update ops) is judged
    against this session's gold points, not the cumulative history; the
    aggregate step averages the per-session scores.

    LLM calls: 1 recall batch + 1 correctness batch + one call per update
    gold point + one call per interference gold point.

    Args:
        session: Pipeline record carrying ``memory_delta`` and ``gold_state``.
        **_kwargs: Ignored; kept for evaluator-interface compatibility.

    Returns:
        Flat dict of recall / correctness / update-handling / interference
        fields in the shape consumed by ``aggregate_metrics``. Ratio fields
        are ``None`` when undefined (e.g. no gold points of that kind).
    """
    delta_str = _delta_to_text(session)
    delta_texts = _delta_texts(session)
    gold_state = session.gold_state
    interference_total = len(gold_state.session_interference_memories)

    def _count(records: list[dict[str, object]], label: str) -> int:
        # Number of judge records carrying the given label.
        return sum(1 for r in records if r.get("label") == label)

    # --- Call 1: Recall (this session's gold points vs this session's delta) ---
    recall_gold = _build_recall_gold_points(session)

    if not recall_gold:
        # Nothing to recall: skip the LLM call entirely.
        recall_result: dict[str, object] = {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0,
            "reasoning": "No new gold points in this session.",
        }
    elif not delta_str.strip():
        # Gold points exist but the system stored nothing: recall is 0.
        recall_result = {
            "covered_count": 0, "update_covered_count": 0,
            "total": len(recall_gold),
            "update_total": sum(1 for p in recall_gold if p.startswith("[update]")),
            "reasoning": "No add/update memories in this session's delta.",
        }
    else:
        recall_result = evaluate_recall_batch(delta_str, recall_gold)

    covered = recall_result.get("covered_count")
    upd_covered = recall_result.get("update_covered_count")
    total_gold = recall_result.get("total", len(recall_gold))
    upd_total = recall_result.get("update_total", 0)

    # Ratios computed once here; None when undefined (no gold points, or the
    # judge returned no count). This replaces the earlier dead pre-assignments
    # that were unconditionally overwritten by this same computation.
    recall = None
    update_recall = None
    if recall_gold:
        if covered is not None and total_gold:
            recall = float(covered) / float(total_gold)
        if upd_covered is not None and upd_total:
            update_recall = float(upd_covered) / float(upd_total)

    # --- Call 2: Correctness (delta memories vs this session's golds) ---
    correctness_gold = _build_correctness_gold_points(session)
    correctness_result = evaluate_correctness_batch(
        delta_texts, correctness_gold, interference_total
    )
    correctness_records = correctness_result.get("results", [])

    num_correct = _count(correctness_records, "correct")
    num_hallucination = _count(correctness_records, "hallucination")
    num_irrelevant = _count(correctness_records, "irrelevant")
    num_memories = len(delta_texts)
    correctness_rate = float(num_correct) / float(num_memories) if num_memories else 0.0

    # --- Calls 3..N: Update handling (one LLM call per update gold point) ---
    update_records: list[dict[str, object]] = []
    for gold in gold_state.session_update_memories:
        verdict = evaluate_update_single(
            delta_str,
            new_content=gold.memory_content,
            old_contents=list(gold.original_memories),
        )
        update_records.append({
            "memory_id": gold.memory_id,
            "label": verdict["label"],
            "reasoning": verdict["reasoning"],
        })

    num_updated = _count(update_records, "updated")
    num_both = _count(update_records, "both")
    num_outdated = _count(update_records, "outdated")
    update_total_items = len(update_records)
    # Scoring rule: updated=1.0, both=0.5, outdated=0.0; None when no items.
    update_score = (
        (num_updated * 1.0 + num_both * 0.5) / update_total_items
        if update_total_items else None
    )

    # --- Calls N+1..: Interference rejection (one LLM call per gold point) ---
    interference_records: list[dict[str, object]] = []
    for gold in gold_state.session_interference_memories:
        verdict = evaluate_interference_single(
            delta_str,
            interference_content=gold.memory_content,
        )
        interference_records.append({
            "memory_id": gold.memory_id,
            "label": verdict["label"],
            "reasoning": verdict["reasoning"],
        })

    num_rejected = _count(interference_records, "rejected")
    num_memorized = _count(interference_records, "memorized")
    interference_total_items = len(interference_records)
    # Scoring rule: rejected=1.0, memorized=0.0; None when no items.
    interference_score = (
        float(num_rejected) / interference_total_items
        if interference_total_items else None
    )

    return {
        "session_id": session.session_id,
        "recall": recall,
        "covered_count": covered,
        "num_gold": total_gold,
        "update_recall": update_recall,
        "update_covered_count": upd_covered,
        "update_total": upd_total,
        "recall_reasoning": recall_result.get("reasoning", ""),
        "correctness_rate": correctness_rate,
        "num_memories": num_memories,
        "num_correct": num_correct,
        "num_hallucination": num_hallucination,
        "num_irrelevant": num_irrelevant,
        "correctness_reasoning": correctness_result.get("reasoning", ""),
        "correctness_records": correctness_records,
        # Update handling
        "update_score": update_score,
        "update_num_updated": num_updated,
        "update_num_both": num_both,
        "update_num_outdated": num_outdated,
        "update_total_items": update_total_items,
        "update_records": update_records,
        # Interference rejection
        "interference_score": interference_score,
        "interference_num_rejected": num_rejected,
        "interference_num_memorized": num_memorized,
        "interference_total_items": interference_total_items,
        "interference_records": interference_records,
    }
|
evaluators/qa.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Checkpoint QA evaluation: answer quality + batch evidence coverage.
|
| 2 |
+
|
| 3 |
+
Two dimensions:
|
| 4 |
+
1. Answer evaluation: Correct / Hallucination / Omission (1 LLM call)
|
| 5 |
+
2. Evidence coverage: how many gold evidence points are covered by the
|
| 6 |
+
memories the model actually *cited* when answering? (1 LLM call)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from eval_framework.judges import evaluate_evidence_batch, evaluate_qa_llm
|
| 12 |
+
from eval_framework.pipeline.records import PipelineCheckpointQARecord
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def evaluate_checkpoint_qa(
    record: PipelineCheckpointQARecord,
    **_kwargs: object,
) -> dict[str, object]:
    """LLM-judged QA evaluation: answer correctness + evidence coverage."""

    # Text of the memories the model actually cited when answering. Legacy
    # records without cited_memories fall back to the full retrieval list.
    if record.cited_memories:
        cited_str = "\n".join(
            f"[{idx + 1}] {mem}" for idx, mem in enumerate(record.cited_memories)
        )
    else:
        cited_str = "\n".join(
            f"[{item.rank}] {item.text}" for item in record.retrieval.items
        )

    # --- Answer-quality judgement (one LLM call) ---
    if record.gold_evidence_contents:
        gold_evidence_str = "\n".join(record.gold_evidence_contents)
    else:
        gold_evidence_str = "No evidence available."
    answer_result = evaluate_qa_llm(
        question=record.question,
        reference_answer=record.gold_answer,
        key_memory_points=gold_evidence_str,
        system_response=record.generated_answer,
    )
    answer_label = answer_result.get("evaluation_result")

    # --- Evidence coverage (one batch LLM call) ---
    # Checked against the cited memories only, not the full retrieval.
    gold_contents = list(record.gold_evidence_contents)
    evidence_result: dict[str, object] = {
        "covered_count": 0, "total": len(gold_contents), "reasoning": ""
    }
    if gold_contents and cited_str.strip():
        evidence_result = evaluate_evidence_batch(cited_str, gold_contents)

    covered = evidence_result.get("covered_count")
    total_ev = evidence_result.get("total", len(gold_contents))
    # covered is None when the judge call failed; treat as zero coverage.
    evidence_hit_rate = (
        float(covered) / float(total_ev) if covered is not None and total_ev else 0.0
    )

    return {
        "answer_label": answer_label,
        "answer_reasoning": answer_result.get("reasoning", ""),
        "answer_is_valid": answer_label in ("Correct", "Hallucination", "Omission"),
        "evidence_hit_rate": evidence_hit_rate,
        "evidence_covered_count": covered,
        "num_evidence": total_ev,
        "evidence_reasoning": evidence_result.get("reasoning", ""),
        "num_cited_memories": len(record.cited_memories),
    }
|
judges/__init__.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Judge stack: batch LLM evaluation.
|
| 2 |
+
|
| 3 |
+
Session: 2 calls (recall + correctness) + per-item calls for update/interference.
|
| 4 |
+
QA: 2 calls (answer + evidence).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from eval_framework.judges.llm_client import llm_request_for_json
|
| 10 |
+
from eval_framework.judges.prompts import (
|
| 11 |
+
CORRECTNESS_BATCH_PROMPT,
|
| 12 |
+
EVIDENCE_BATCH_PROMPT,
|
| 13 |
+
INTERFERENCE_EVAL_PROMPT,
|
| 14 |
+
QA_EVALUATION_PROMPT,
|
| 15 |
+
RECALL_BATCH_PROMPT,
|
| 16 |
+
UPDATE_EVAL_PROMPT,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
# Public judge API: session-level batch evaluators, per-item update/interference
# evaluators, QA-level evaluators, and the raw JSON-returning LLM helper.
__all__ = [
    "evaluate_recall_batch",
    "evaluate_correctness_batch",
    "evaluate_update_single",
    "evaluate_interference_single",
    "evaluate_evidence_batch",
    "evaluate_qa_llm",
    "llm_request_for_json",
]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def evaluate_recall_batch(
    extracted_memories_str: str,
    gold_points_tagged: list[str],
) -> dict[str, object]:
    """One LLM call: how many gold points are covered? Distinguishes update sub-score.

    Args:
        extracted_memories_str: Newline-joined memories the system stored.
        gold_points_tagged: List of "[normal] content" or "[update] content" strings.

    Returns:
        {covered_count, update_covered_count, total, update_total, reasoning}.
        The count fields are None when the LLM judge call fails, so aggregation
        can distinguish "judge failed" from "nothing covered".
    """
    total = len(gold_points_tagged)
    update_total = sum(1 for p in gold_points_tagged if p.startswith("[update]"))

    if not extracted_memories_str.strip():
        return {
            "covered_count": 0, "update_covered_count": 0,
            "total": total, "update_total": update_total,
            "reasoning": "No extracted memories.",
        }
    if not gold_points_tagged:
        return {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0, "reasoning": "No gold points.",
        }

    numbered = "\n".join(f"[{i+1}] {p}" for i, p in enumerate(gold_points_tagged))
    prompt = RECALL_BATCH_PROMPT.format(memories=extracted_memories_str, gold_points=numbered)
    try:
        result = llm_request_for_json(prompt)
        # Clamp to [0, total]: a misbehaving judge may report negative or
        # inflated counts; previously only the upper bound was enforced.
        covered = min(max(int(result.get("covered_count", 0)), 0), total)
        upd_covered = min(max(int(result.get("update_covered_count", 0)), 0), update_total)
        return {
            "covered_count": covered,
            "update_covered_count": upd_covered,
            "total": total,
            "update_total": update_total,
            "reasoning": result.get("reasoning", ""),
        }
    except Exception as e:
        return {
            "covered_count": None, "update_covered_count": None,
            "total": total, "update_total": update_total,
            "reasoning": f"LLM error: {e}",
        }
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def evaluate_correctness_batch(
    snapshot_memories: list[str],
    gold_points_tagged: list[str],
    interference_total: int,
) -> dict[str, object]:
    """One LLM call: is each snapshot memory correct? Includes interference detection.

    Args:
        snapshot_memories: Memories the system currently stores.
        gold_points_tagged: List of "[normal] content", "[update] content",
            "[interference] content" strings.
        interference_total: Number of [interference] gold points (upper bound
            for the memorized count).

    Returns:
        {results: [{id, label}], interference_memorized_count, interference_total,
        reasoning}. interference_memorized_count is None when the judge call fails.
    """
    if not snapshot_memories:
        return {
            "results": [],
            "interference_memorized_count": 0,
            "interference_total": interference_total,
            "reasoning": "No snapshot memories.",
        }

    numbered_memories = "\n".join(f"[{i+1}] {m}" for i, m in enumerate(snapshot_memories))
    numbered_golds = "\n".join(f"- {p}" for p in gold_points_tagged) if gold_points_tagged else "(no ground-truth)"

    prompt = CORRECTNESS_BATCH_PROMPT.format(memories=numbered_memories, gold_points=numbered_golds)
    try:
        result = llm_request_for_json(prompt)
        valid_labels = {"correct", "hallucination", "irrelevant"}
        cleaned = []
        for r in result.get("results", []):
            label = str(r.get("label", "irrelevant")).lower().strip()
            # Unknown labels default to the harmless bucket.
            if label not in valid_labels:
                label = "irrelevant"
            cleaned.append({"id": r.get("id"), "label": label})
        # Clamp to [0, interference_total]: guard against judge over/under-counts
        # (previously only the upper bound was enforced).
        interf_mem = min(max(int(result.get("interference_memorized_count", 0)), 0), interference_total)
        return {
            "results": cleaned,
            "interference_memorized_count": interf_mem,
            "interference_total": interference_total,
            "reasoning": result.get("reasoning", ""),
        }
    except Exception as e:
        return {
            "results": [],
            "interference_memorized_count": None,
            "interference_total": interference_total,
            "reasoning": f"LLM error: {e}",
        }
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def evaluate_update_single(
    delta_memories_str: str,
    new_content: str,
    old_contents: list[str],
) -> dict[str, object]:
    """One LLM call: how did the system handle a single memory update?

    Returns {label: "updated"|"both"|"outdated", reasoning}; label is None
    when the judge call fails.
    """
    if old_contents:
        old_str = "\n".join(f"- {o}" for o in old_contents)
    else:
        old_str = "(none)"
    prompt = UPDATE_EVAL_PROMPT.format(
        memories=delta_memories_str,
        new_content=new_content,
        old_contents=old_str,
    )
    try:
        raw = llm_request_for_json(prompt)
        label = str(raw.get("label", "outdated")).strip().lower()
        # Unrecognized labels collapse to the most pessimistic outcome.
        if label not in {"updated", "both", "outdated"}:
            label = "outdated"
        return {"label": label, "reasoning": raw.get("reasoning", "")}
    except Exception as exc:
        return {"label": None, "reasoning": f"LLM error: {exc}"}
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def evaluate_interference_single(
    delta_memories_str: str,
    interference_content: str,
) -> dict[str, object]:
    """One LLM call: did the system incorrectly memorize an interference point?

    Returns {label: "rejected"|"memorized", reasoning}; label is None when the
    judge call fails.
    """
    prompt = INTERFERENCE_EVAL_PROMPT.format(
        memories=delta_memories_str,
        interference_content=interference_content,
    )
    try:
        raw = llm_request_for_json(prompt)
        label = str(raw.get("label", "memorized")).strip().lower()
        # Anything unrecognized counts as the failure case ("memorized").
        if label not in {"rejected", "memorized"}:
            label = "memorized"
        return {"label": label, "reasoning": raw.get("reasoning", "")}
    except Exception as exc:
        return {"label": None, "reasoning": f"LLM error: {exc}"}
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def evaluate_evidence_batch(
    retrieved_memories_str: str,
    evidence_points: list[str],
) -> dict[str, object]:
    """One LLM call: how many gold evidence points are covered by retrieval?

    Args:
        retrieved_memories_str: Newline-joined retrieved/cited memories.
        evidence_points: Gold evidence point contents.

    Returns:
        {covered_count, total, reasoning}; covered_count is None when the
        judge call fails.
    """
    total = len(evidence_points)
    if not retrieved_memories_str.strip():
        return {"covered_count": 0, "total": total, "reasoning": "No retrieved memories."}
    if not evidence_points:
        return {"covered_count": 0, "total": 0, "reasoning": "No evidence points."}

    numbered = "\n".join(f"[{i+1}] {p}" for i, p in enumerate(evidence_points))
    prompt = EVIDENCE_BATCH_PROMPT.format(retrieved_memories=retrieved_memories_str, gold_evidence_points=numbered)
    try:
        result = llm_request_for_json(prompt)
        # Clamp to [0, total]: previously only the upper bound was enforced,
        # letting a misbehaving judge report a negative count.
        covered = min(max(int(result.get("covered_count", 0)), 0), total)
        return {
            "covered_count": covered,
            "total": total,
            "reasoning": result.get("reasoning", ""),
        }
    except Exception as e:
        return {"covered_count": None, "total": total, "reasoning": f"LLM error: {e}"}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def evaluate_qa_llm(
    question: str,
    reference_answer: str,
    key_memory_points: str,
    system_response: str,
) -> dict[str, object]:
    """LLM judge: classify the QA response as Correct/Hallucination/Omission.

    An empty/whitespace response short-circuits to "Omission" without an LLM
    call; a judge failure yields evaluation_result None.
    """
    if not system_response.strip():
        return {"evaluation_result": "Omission", "reasoning": "Empty system response."}

    prompt = QA_EVALUATION_PROMPT.format(
        question=question, reference_answer=reference_answer,
        key_memory_points=key_memory_points, response=system_response,
    )
    try:
        verdict = llm_request_for_json(prompt)
    except Exception as exc:
        return {"evaluation_result": None, "reasoning": f"LLM judge error: {exc}"}
    label = verdict.get("evaluation_result", "Omission")
    # Any label outside the expected set is treated as an omission.
    if label not in ("Correct", "Hallucination", "Omission"):
        label = "Omission"
    return {"evaluation_result": label, "reasoning": verdict.get("reasoning", "")}
|
judges/llm_client.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenAI LLM client for judge calls with retry logic and concurrency control."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import logging
|
| 9 |
+
import threading
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
from dotenv import load_dotenv
|
| 14 |
+
from tenacity import retry, stop_after_attempt, wait_random_exponential, before_sleep_log
|
| 15 |
+
|
| 16 |
+
# Load .env from project root (walk up from this file to find it)
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
load_dotenv(_PROJECT_ROOT / ".env")

logger = logging.getLogger(__name__)

# Closed code fence: ```json ... ``` (non-greedy so multiple fences are scanned).
_JSON_FENCE_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL)
# Open fence with no closing ``` — occurs when model output was truncated mid-JSON.
_JSON_FENCE_OPEN_RE = re.compile(r"```(?:json)?\s*\n?(.*)", re.DOTALL)

# Lazily-built process-wide singletons; _client creation is guarded by
# _client_lock (double-checked locking in _get_client).
_client: Any = None
_client_lock = threading.Lock()
# Global throttle for concurrent judge calls; sized by LLM_MAX_CONCURRENT (default 5).
_semaphore: threading.Semaphore | None = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _build_client() -> Any:
    """Construct an OpenAI client from OPENAI_API_KEY / OPENAI_BASE_URL env vars."""
    # Imported lazily so the module loads even when openai is not installed.
    from openai import OpenAI

    api_key = os.getenv("OPENAI_API_KEY", "")
    base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
    return OpenAI(api_key=api_key, base_url=base_url)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _get_client() -> Any:
    """Return the shared OpenAI client, building it once (double-checked locking)."""
    global _client
    if _client is not None:
        return _client
    with _client_lock:
        # Re-check under the lock: another thread may have built it meanwhile.
        if _client is None:
            _client = _build_client()
    return _client
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _get_semaphore() -> threading.Semaphore:
    """Return the shared judge-call semaphore (size from LLM_MAX_CONCURRENT, default 5)."""
    global _semaphore
    if _semaphore is None:
        limit = int(os.getenv("LLM_MAX_CONCURRENT", "5"))
        _semaphore = threading.Semaphore(limit)
    return _semaphore
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _common_params() -> dict[str, Any]:
    """Build shared chat-completion kwargs from environment variables.

    Reasoning models (``gpt-5*``, ``o*``) take ``max_completion_tokens``
    instead of ``max_tokens`` and reject any non-default ``temperature``, so
    both parameters are gated on the model name.

    Env vars: OPENAI_MODEL, OPENAI_MAX_TOKENS, JUDGE_TEMPERATURE (default 0.0),
    OPENAI_TIMEOUT.
    """
    params: dict[str, Any] = {}
    model = os.getenv("OPENAI_MODEL") or ""
    is_reasoning_model = model.startswith("gpt-5") or model.startswith("o")

    max_tok = os.getenv("OPENAI_MAX_TOKENS")
    if max_tok:
        if is_reasoning_model:
            params["max_completion_tokens"] = int(max_tok)
        else:
            params["max_tokens"] = int(max_tok)

    # Fix: reasoning models only accept the default temperature — sending
    # 0.0 makes the API reject the request, so skip the parameter for them.
    if not is_reasoning_model:
        params["temperature"] = float(os.getenv("JUDGE_TEMPERATURE", "0.0"))

    timeout = os.getenv("OPENAI_TIMEOUT")
    if timeout:
        params["timeout"] = int(timeout)
    return params
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@retry(
    wait=wait_random_exponential(min=2, max=60),
    stop=stop_after_attempt(8),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def llm_request_for_json(prompt: str) -> dict[str, Any]:
    """Send a prompt to the LLM and parse the JSON block from the response.

    Respects the global concurrency limit (LLM_MAX_CONCURRENT env var, default 5).
    Retries with exponential backoff (up to 8 attempts) via tenacity — note the
    retry also covers the parse failure below, giving the model another chance
    to produce valid JSON.

    Raises:
        ValueError: if no JSON object can be extracted from the model output.
    """
    # Hold the semaphore only for the network call; parsing happens outside.
    # `with` replaces the manual acquire()/try/finally/release() pattern.
    with _get_semaphore():
        client = _get_client()
        model = os.getenv("OPENAI_MODEL") or "gpt-4o"

        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            **_common_params(),
        )
        content = response.choices[0].message.content or ""

    parsed = _extract_json(content)
    if parsed is None:
        raise ValueError(f"No JSON block found in model output: {content[:500]}")
    return parsed
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _extract_json(content: str) -> dict[str, Any] | None:
    """Try to extract a JSON object from model output, with truncation repair."""
    # 1. Properly closed fence: ```json ... ```
    for fence in _JSON_FENCE_RE.finditer(content):
        body = fence.group(1).strip()
        if not body.startswith("{"):
            continue
        try:
            return json.loads(body)
        except json.JSONDecodeError:
            continue

    # 2. Fence that was never closed (output cut off before the trailing ```).
    open_fence = _JSON_FENCE_OPEN_RE.search(content)
    if open_fence:
        body = open_fence.group(1).strip()
        if body.startswith("{"):
            fixed = _repair_truncated_json(body)
            if fixed is not None:
                return fixed

    # 3. Bare JSON with no fences at all.
    bare = content.strip()
    if bare.startswith("{"):
        try:
            return json.loads(bare)
        except json.JSONDecodeError:
            fixed = _repair_truncated_json(bare)
            if fixed is not None:
                return fixed

    return None
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _repair_truncated_json(text: str) -> dict[str, Any] | None:
    """Best-effort repair of truncated JSON by closing open brackets/braces."""
    # Strip incomplete trailing tokens left behind by mid-output truncation.
    text = re.sub(r',\s*"[^"]*$', "", text)   # dangling partial key
    text = re.sub(r',\s*\{[^}]*$', "", text)  # dangling partial object
    text = re.sub(r',\s*$', "", text)         # dangling comma

    # Balance the brackets/braces by appending the missing closers.
    missing_braces = text.count("{") - text.count("}")
    missing_brackets = text.count("[") - text.count("]")
    closers = "]" * max(missing_brackets, 0) + "}" * max(missing_braces, 0)

    try:
        parsed = json.loads(text + closers)
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict):
        return None
    logger.warning("Repaired truncated JSON (appended %r)", closers)
    return parsed
|
judges/prompts.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM judge prompt templates — batch mode.
|
| 2 |
+
|
| 3 |
+
Each session: 2 LLM calls (recall + correctness).
|
| 4 |
+
Each QA question: 2 LLM calls (answer + evidence).
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
# ---------------------------------------------------------------------------
# Session Call 1: Recall (Gold -> Snapshot)
# ---------------------------------------------------------------------------
# Placeholders filled via str.format: {memories}, {gold_points}. Literal braces
# in the JSON example are escaped as {{ }}.
RECALL_BATCH_PROMPT = """You are a **Memory Recall Evaluator**.
Determine how many of the **Expected Memory Points** are covered by the system's **Extracted Memories**.

# Inputs

1. **Extracted Memories** (what the system actually stored):
{memories}

2. **Expected Memory Points** (numbered, each tagged [normal] or [update]):
{gold_points}

# Instructions

- Go through each Expected Memory Point and check whether the Extracted Memories contain information that covers it.
- Semantic matching is acceptable; exact wording is NOT required.
- Count **total** covered points AND separately count how many **[update]** tagged points are covered.

# Output

```json
{{
  "covered_count": <int>,
  "update_covered_count": <int>,
  "total": <int>,
  "update_total": <int>,
  "reasoning": "Brief summary"
}}
```
"""
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
# Session Call 2: Correctness (Snapshot -> Gold)
# ---------------------------------------------------------------------------
# Placeholders filled via str.format: {memories}, {gold_points}.
CORRECTNESS_BATCH_PROMPT = """You are a **Memory Correctness Evaluator**.
Evaluate whether each memory stored by the system is factually correct.

# Inputs

1. **System Memories** (numbered, what the system actually stored):
{memories}

2. **Ground-Truth Reference Points** (tagged [normal], [update], or [interference]):
{gold_points}

# Instructions

For **each** System Memory, classify it as one of:
- **correct**: The memory is factually accurate and consistent with the [normal] or [update] ground-truth points.
- **hallucination**: The memory contains fabricated or incorrect information, OR it contains content from [interference] points (information that should NOT have been memorized).
- **irrelevant**: The memory is not wrong per se, but is trivial filler or not related to any ground-truth point.

**IMPORTANT**: If a System Memory matches or contains information from an [interference] tagged point, it MUST be classified as **hallucination**, because the system should have ignored that information.

Also count how many [interference] ground-truth points appear in the System Memories.

# Output

```json
{{
  "results": [
    {{"id": 1, "label": "correct|hallucination|irrelevant"}},
    {{"id": 2, "label": "correct|hallucination|irrelevant"}}
  ],
  "interference_memorized_count": <int>,
  "interference_total": <int>,
  "reasoning": "Brief justification"
}}
```
"""
|
| 79 |
+
|
| 80 |
+
# ---------------------------------------------------------------------------
# Session: Update handling (per update gold point)
# ---------------------------------------------------------------------------
# Placeholders filled via str.format: {memories}, {new_content}, {old_contents}.
UPDATE_EVAL_PROMPT = """You are a **Memory Update Evaluator**.
Determine how a memory system handled an information update.

# Inputs

1. **System Memories** (what the system currently stores after this session):
{memories}

2. **Updated Fact** (the NEW correct information):
{new_content}

3. **Outdated Fact(s)** (the OLD information that should have been replaced):
{old_contents}

# Instructions

Check the System Memories and classify the update handling as one of:

- **updated**: The system stores ONLY the new/updated information. The outdated fact is no longer present. This is the ideal outcome.
- **both**: The system stores BOTH the new and the old information. The update was partially handled — the new fact was added but the old was not removed.
- **outdated**: The system stores ONLY the old/outdated information. The update was missed entirely — the new fact is absent.

Use semantic matching, not exact wording.

# Output

```json
{{
  "label": "updated|both|outdated",
  "reasoning": "Brief justification"
}}
```
"""
|
| 116 |
+
|
| 117 |
+
# ---------------------------------------------------------------------------
# Session: Interference rejection (per interference gold point)
# ---------------------------------------------------------------------------
# Placeholders filled via str.format: {memories}, {interference_content}.
INTERFERENCE_EVAL_PROMPT = """You are a **Memory Interference Evaluator**.
Determine whether a memory system incorrectly stored information that should have been ignored.

# Inputs

1. **System Memories** (what the system currently stores after this session):
{memories}

2. **Interference Content** (information that should NOT have been memorized):
{interference_content}

# Instructions

Check whether the System Memories contain the interference content (or its semantic equivalent). Classify as:

- **rejected**: The interference content is NOT present in the system memories. The system correctly ignored it.
- **memorized**: The interference content IS present (or semantically equivalent) in the system memories. The system incorrectly stored it.

Use semantic matching, not exact wording.

# Output

```json
{{
  "label": "rejected|memorized",
  "reasoning": "Brief justification"
}}
```
"""
|
| 149 |
+
|
| 150 |
+
# ---------------------------------------------------------------------------
# QA Evidence Coverage (batch per question)
# ---------------------------------------------------------------------------
# Placeholders filled via str.format: {retrieved_memories}, {gold_evidence_points}.
EVIDENCE_BATCH_PROMPT = """You are an **Evidence Retrieval Evaluator**.
Determine how many of the **Gold Evidence Points** are covered by the system's **Retrieved Memories** when answering a question.

# Inputs

1. **Retrieved Memories** (what the system retrieved to answer the question):
{retrieved_memories}

2. **Gold Evidence Points** (key facts needed for the correct answer, numbered):
{gold_evidence_points}

# Instructions

Go through each Gold Evidence Point and check whether the Retrieved Memories contain information that covers it. Semantic matching is acceptable; exact wording is NOT required.

Count how many Gold Evidence Points are **fully covered or logically implied** by the Retrieved Memories.

# Output

```json
{{
  "covered_count": <int>,
  "total": <int>,
  "reasoning": "Brief summary"
}}
```
"""
|
| 180 |
+
|
| 181 |
+
# ---------------------------------------------------------------------------
# QA Answer evaluation
# ---------------------------------------------------------------------------
# Placeholders filled via str.format: {question}, {reference_answer},
# {key_memory_points}, {response}.
QA_EVALUATION_PROMPT = """You are an **evaluation expert for AI memory system question answering**.
Based **only** on the provided **"Question"**, **"Reference Answer"**, and **"Key Memory Points"**, strictly evaluate the **accuracy** of the **"Memory System Response."** Classify it as one of **"Correct"**, **"Hallucination"**, or **"Omission."** Do **not** use any external knowledge or subjective inference.

# Evaluation Criteria

### 1. Correct
* The response accurately answers the question and is **semantically equivalent** to the Reference Answer.
* No contradictions with Key Memory Points or Reference Answer.
* Synonyms, paraphrasing, and reasonable summarization are acceptable.

### 2. Hallucination
* The response includes information that **contradicts** the Reference Answer or Key Memory Points.
* When the Reference Answer is *unknown/uncertain*, yet the response provides a specific fact.

### 3. Omission
* The response is **incomplete** compared to the Reference Answer.
* It states "don't know" or "no related memory" even though relevant information exists.
* For multi-element questions, missing **any** element counts as Omission.

## Priority Rules
* Both missing info AND fabricated info -> **Hallucination**.
* No fabrication but missing info -> **Omission**.
* Fully equivalent -> **Correct**.

# Information

* **Question:** {question}
* **Reference Answer:** {reference_answer}
* **Key Memory Points:** {key_memory_points}
* **Memory System Response:** {response}

# Output

```json
{{
  "reasoning": "Concise evaluation rationale",
  "evaluation_result": "Correct | Hallucination | Omission"
}}
```
"""
|
memory_adapters/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Memory system adapters for the eval framework."""
|
| 2 |
+
|
| 3 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 4 |
+
from eval_framework.memory_adapters.memgallery_native import (
|
| 5 |
+
MemGalleryNativeAdapter,
|
| 6 |
+
instantiate_memgallery_memory,
|
| 7 |
+
)
|
| 8 |
+
from eval_framework.memory_adapters.registry import (
|
| 9 |
+
EXTERNAL_ADAPTER_KEYS,
|
| 10 |
+
EXTERNAL_ADAPTER_REGISTRY,
|
| 11 |
+
MEMGALLERY_NATIVE_BASELINES,
|
| 12 |
+
MEMGALLERY_NATIVE_REGISTRY,
|
| 13 |
+
create_external_adapter,
|
| 14 |
+
create_memgallery_adapter,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"EXTERNAL_ADAPTER_KEYS",
|
| 19 |
+
"EXTERNAL_ADAPTER_REGISTRY",
|
| 20 |
+
"MEMGALLERY_NATIVE_BASELINES",
|
| 21 |
+
"MEMGALLERY_NATIVE_REGISTRY",
|
| 22 |
+
"MemoryAdapter",
|
| 23 |
+
"MemGalleryNativeAdapter",
|
| 24 |
+
"create_external_adapter",
|
| 25 |
+
"create_memgallery_adapter",
|
| 26 |
+
"instantiate_memgallery_memory",
|
| 27 |
+
]
|
memory_adapters/amem.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adapter for the external A-Mem baseline."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import importlib
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Callable
|
| 10 |
+
|
| 11 |
+
from eval_framework.datasets.schemas import (
|
| 12 |
+
MemoryDeltaRecord,
|
| 13 |
+
MemorySnapshotRecord,
|
| 14 |
+
NormalizedTurn,
|
| 15 |
+
RetrievalItem,
|
| 16 |
+
RetrievalRecord,
|
| 17 |
+
)
|
| 18 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 19 |
+
|
| 20 |
+
_BACKEND_ID = "A-Mem"
|
| 21 |
+
|
| 22 |
+
INTEGRATION_ERROR = (
|
| 23 |
+
f"{_BACKEND_ID} backend unavailable."
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AMemAdapter(MemoryAdapter):
    """Thin wrapper around A-Mem's robust memory system.

    The backend can be supplied three ways: an already-constructed ``backend``
    object, a zero-argument ``backend_factory``, or (default) by importing
    ``memory_layer_robust.RobustAgenticMemorySystem`` from ``source_root``.
    Construction failures are recorded rather than raised; every operation
    that needs the backend raises a RuntimeError carrying the recorded error.
    """

    def __init__(
        self,
        *,
        backend: Any | None = None,
        backend_factory: Callable[[], Any] | None = None,
        source_root: str | os.PathLike[str] | None = None,
        model_name: str = "all-MiniLM-L6-v2",
        llm_backend: str = "openai",
        llm_model: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        sglang_host: str = "http://localhost",
        sglang_port: int = 30000,
    ) -> None:
        # Where the A-Mem checkout lives; used only when we must build the
        # backend ourselves via _build_backend_factory.
        self._source_root = Path(source_root).resolve() if source_root else self._default_source_root()
        # Env var takes precedence over the hard-coded default model name.
        resolved_llm_model = llm_model or os.getenv("OPENAI_MODEL") or "gpt-5.1"
        self._backend: Any | None = None
        self._backend_factory = backend_factory
        self._integration_error: str | None = None
        self._session_id = ""
        # memory_id set from the previous snapshot; used to diff deltas.
        self._prev_snapshot_ids: set[str] = set()
        # note_id (str) -> session_id that ingested it.
        self._note_session_map: dict[str, str] = {}

        if backend is not None:
            self._backend = backend
        else:
            try:
                if self._backend_factory is None:
                    self._backend_factory = self._build_backend_factory(
                        model_name=model_name,
                        llm_backend=llm_backend,
                        llm_model=resolved_llm_model,
                        api_key=api_key,
                        api_base=api_base,
                        sglang_host=sglang_host,
                        sglang_port=sglang_port,
                    )
                self._backend = self._backend_factory()
            # Deliberately broad: any import/construction failure is stored
            # and surfaced later via _runtime_error() / get_capabilities().
            except Exception as exc:
                self._integration_error = str(exc)

    @staticmethod
    def _default_source_root() -> Path:
        """Locate the default A-Mem checkout relative to this file."""
        here = Path(__file__).resolve()
        # memory_adapters/ -> eval_framework/ -> our/ -> Benchmark/
        # NOTE(review): "data_pipline" looks misspelled but presumably matches
        # the actual on-disk directory name — confirm before renaming.
        return (here.parents[2].parent / "data_pipline" / "A-mem").resolve()

    def _build_backend_factory(
        self,
        *,
        model_name: str,
        llm_backend: str,
        llm_model: str,
        api_key: str | None,
        api_base: str | None,
        sglang_host: str,
        sglang_port: int,
    ) -> Callable[[], Any]:
        """Return a factory that imports and constructs the A-Mem backend.

        Raises:
            RuntimeError: if ``self._source_root`` does not exist.
            Exception: whatever ``importlib.import_module`` raises when the
                A-Mem module cannot be imported (caught by ``__init__``).
        """
        if not self._source_root.is_dir():
            raise RuntimeError(
                f"{_BACKEND_ID}: source root not found at {self._source_root}"
            )
        src = str(self._source_root)
        # Make the A-Mem checkout importable; idempotent across instances.
        if src not in sys.path:
            sys.path.insert(0, src)
        mod = importlib.import_module("memory_layer_robust")
        backend_cls = getattr(mod, "RobustAgenticMemorySystem")
        # Env vars fill in credentials that were not passed explicitly.
        return lambda: backend_cls(
            model_name=model_name,
            llm_backend=llm_backend,
            llm_model=llm_model,
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            api_base=api_base or os.getenv("OPENAI_BASE_URL"),
            sglang_host=sglang_host,
            sglang_port=sglang_port,
        )

    def _runtime_error(self) -> RuntimeError:
        """Build (not raise) the error reported when the backend is missing."""
        detail = self._integration_error or INTEGRATION_ERROR
        return RuntimeError(
            f"{_BACKEND_ID}: backend unavailable — {detail}"
        )

    def reset(self) -> None:
        """Rebuild the backend (when a factory exists) and clear bookkeeping.

        NOTE(review): when a pre-built ``backend`` was injected without a
        factory, the backend instance — and any memories it holds — survives
        reset; only adapter-side bookkeeping is cleared. Confirm intended.
        """
        if self._backend_factory is None and self._backend is None:
            raise self._runtime_error()
        if self._backend_factory is not None:
            self._backend = self._backend_factory()
        self._prev_snapshot_ids = set()
        self._note_session_map = {}
        self._session_id = ""

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Feed one conversation turn into A-Mem as a single note."""
        backend = self._require_backend()
        self._session_id = turn.session_id
        text = self._turn_text(turn)
        note_id = backend.add_note(text, time=turn.timestamp)
        # Remember which session produced the note so snapshots can be
        # attributed even after later sessions are ingested.
        self._note_session_map[str(note_id)] = turn.session_id

    def end_session(self, session_id: str) -> None:
        """Record the session boundary; A-Mem needs no per-session flush."""
        self._require_backend()
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return every note currently held by the backend, enriched."""
        backend = self._require_backend()
        rows: list[MemorySnapshotRecord] = []
        for note_id, note in getattr(backend, "memories", {}).items():
            # Fall back to the most recent session when the note was not
            # ingested through this adapter instance.
            sid = self._note_session_map.get(str(note_id), self._session_id)
            content = str(getattr(note, "content", ""))
            context = getattr(note, "context", "")
            keywords = list(getattr(note, "keywords", []) or [])
            tags = list(getattr(note, "tags", []) or [])
            # Include A-Mem enrichments in the snapshot text so that the
            # eval captures what the system actually processed, not just
            # the raw input.
            enriched_parts = [content]
            if context:
                enriched_parts.append(f"[context] {context}")
            if keywords:
                enriched_parts.append(f"[keywords] {', '.join(keywords)}")
            if tags:
                enriched_parts.append(f"[tags] {', '.join(tags)}")
            rows.append(
                MemorySnapshotRecord(
                    memory_id=str(getattr(note, "id", note_id)),
                    text="\n".join(enriched_parts),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id=str(getattr(note, "id", note_id)),
                    raw_backend_type="a_mem_note",
                    metadata={
                        "timestamp": getattr(note, "timestamp", None),
                        "context": context,
                        "keywords": keywords,
                        "tags": tags,
                        "links": list(getattr(note, "links", []) or []),
                    },
                )
            )
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current snapshot against previous snapshot."""
        self._require_backend()
        current_snapshot = self.snapshot_memories()
        deltas: list[MemoryDeltaRecord] = []
        current_ids: set[str] = set()

        for snap in current_snapshot:
            current_ids.add(snap.memory_id)
            # Only notes unseen by the previous call count as additions;
            # removals/updates are not modeled here.
            if snap.memory_id not in self._prev_snapshot_ids:
                deltas.append(
                    MemoryDeltaRecord(
                        session_id=session_id,
                        op="add",
                        text=snap.text,
                        linked_previous=(),
                        raw_backend_id=snap.raw_backend_id,
                        metadata={
                            "baseline": _BACKEND_ID,
                            "backend_type": snap.raw_backend_type,
                        },
                    )
                )

        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Run retrieval via the backend retriever, with a raw-text fallback.

        The retriever is assumed to return indices into the backend's memory
        store in insertion order; scores are synthesized as 1/(rank+1) since
        the backend does not expose similarity scores here.
        """
        backend = self._require_backend()
        items: list[RetrievalItem] = []
        # NOTE(review): assumes dict iteration order of backend.memories
        # matches the index space used by retriever.search — confirm.
        memories = list(getattr(backend, "memories", {}).values())
        retriever = getattr(backend, "retriever", None)
        if retriever is not None and hasattr(retriever, "search"):
            for rank, idx in enumerate(retriever.search(query, top_k)):
                if 0 <= int(idx) < len(memories):
                    note = memories[int(idx)]
                    items.append(
                        RetrievalItem(
                            rank=rank,
                            memory_id=str(getattr(note, "id", idx)),
                            text=str(getattr(note, "content", "")),
                            score=1.0 / float(rank + 1),
                            raw_backend_id=str(getattr(note, "id", idx)),
                        )
                    )
        # Fallback: a single opaque bundle from A-Mem's raw related-memory API.
        if not items and hasattr(backend, "find_related_memories_raw"):
            raw = backend.find_related_memories_raw(query, k=top_k)
            if raw:
                items.append(
                    RetrievalItem(
                        rank=0,
                        memory_id="a_mem:bundle",
                        text=str(raw),
                        score=1.0,
                        raw_backend_id=None,
                    )
                )
        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items[:top_k],
            raw_trace={"baseline": _BACKEND_ID},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe adapter availability and behavior limits for the runner."""
        available = self._backend is not None or self._backend_factory is not None
        return {
            "backend": _BACKEND_ID,
            "baseline": _BACKEND_ID,
            "available": available and self._integration_error is None,
            "integration_status": "integrated" if available and self._integration_error is None else "unavailable",
            "integration_error": self._integration_error or INTEGRATION_ERROR,
            "delta_granularity": "ingest_turn_only",
            "snapshot_mode": "full_store",
        }

    def _require_backend(self) -> Any:
        """Return the live backend or raise the recorded integration error."""
        if self._backend is None:
            raise self._runtime_error()
        return self._backend

    @staticmethod
    def _turn_text(turn: NormalizedTurn) -> str:
        """Render a turn (plus attachment captions) as one note text."""
        parts = [f"{turn.role}: {turn.text}"]
        for att in turn.attachments:
            parts.append(f"[{att.type}] {att.caption}")
        return "\n".join(parts)
|
memory_adapters/amem_v2.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adapter for A-Mem (new API: agentic_memory.AgenticMemorySystem)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
load_dotenv(Path(__file__).resolve().parents[2] / ".env")
|
| 13 |
+
|
| 14 |
+
from eval_framework.datasets.schemas import (
|
| 15 |
+
MemoryDeltaRecord,
|
| 16 |
+
MemorySnapshotRecord,
|
| 17 |
+
NormalizedTurn,
|
| 18 |
+
RetrievalItem,
|
| 19 |
+
RetrievalRecord,
|
| 20 |
+
)
|
| 21 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 22 |
+
|
| 23 |
+
_DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/A-Mem")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class AMemV2Adapter(MemoryAdapter):
    """Adapter for A-Mem (new agentic_memory API).

    Puts the A-Mem checkout on ``sys.path``, constructs an
    ``AgenticMemorySystem``, and exposes it through the shared
    ``MemoryAdapter`` surface. Deltas are computed by snapshot diffing.
    """

    def __init__(
        self,
        *,
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,
    ) -> None:
        # Make the A-Mem checkout importable before touching its package.
        checkout = Path(source_root or _DEFAULT_SOURCE).resolve()
        if str(checkout) not in sys.path:
            sys.path.insert(0, str(checkout))

        from agentic_memory.memory_system import AgenticMemorySystem

        self._cls = AgenticMemorySystem
        self._backend: Any = None
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()
        self._init_backend()

    def _init_backend(self) -> None:
        """Construct a fresh backend instance from environment settings."""
        self._backend = self._cls(
            model_name="all-MiniLM-L6-v2",
            llm_backend="openai",
            llm_model=os.getenv("OPENAI_MODEL") or "gpt-4o",
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    def reset(self) -> None:
        """Replace the backend with a fresh one and clear the delta baseline."""
        self._init_backend()
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one turn (with attachment captions) as a single A-Mem note."""
        self._session_id = turn.session_id
        pieces = [f"{turn.role}: {turn.text}"]
        pieces.extend(f"\n[{att.type}] {att.caption}" for att in turn.attachments)
        self._backend.add_note("".join(pieces), time=turn.timestamp)

    def end_session(self, session_id: str) -> None:
        """Record the session boundary; no backend-side flush is required."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return every backend note, enriched with context/keyword tags."""
        snapshots: list[MemorySnapshotRecord] = []
        for note_id, note in self._backend.memories.items():
            body = str(getattr(note, "content", ""))
            ctx = getattr(note, "context", "")
            kw = list(getattr(note, "keywords", []) or [])
            segments = [body]
            if ctx:
                segments.append(f"[context] {ctx}")
            if kw:
                segments.append(f"[keywords] {', '.join(kw)}")
            snapshots.append(
                MemorySnapshotRecord(
                    memory_id=str(note_id),
                    text="\n".join(segments),
                    session_id=self._session_id,
                    status="active",
                    source="A-Mem",
                    raw_backend_id=str(note_id),
                    raw_backend_type="a_mem_note",
                    metadata={},
                )
            )
        return snapshots

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Report notes that appeared since the previous call as 'add' deltas."""
        snapshot = self.snapshot_memories()
        seen_now = {rec.memory_id for rec in snapshot}
        fresh: list[MemoryDeltaRecord] = []
        for rec in snapshot:
            if rec.memory_id in self._prev_snapshot_ids:
                continue
            fresh.append(
                MemoryDeltaRecord(
                    session_id=session_id, op="add", text=rec.text,
                    linked_previous=(), raw_backend_id=rec.raw_backend_id,
                    metadata={"baseline": "A-Mem"},
                )
            )
        self._prev_snapshot_ids = seen_now
        return fresh

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search the backend; fall back to its raw related-memory bundle."""
        hits: list[RetrievalItem] = []
        try:
            found = self._backend.search(query, k=top_k)
            for rank, entry in enumerate(found[:top_k]):
                if isinstance(entry, dict):
                    body = entry.get("content", str(entry))
                    key = entry.get("id", str(rank))
                    weight = float(entry.get("score", 1.0 / (rank + 1)))
                else:
                    body = str(entry)
                    key = str(rank)
                    weight = 1.0 / (rank + 1)
                hits.append(RetrievalItem(
                    rank=rank, memory_id=str(key), text=body,
                    score=weight, raw_backend_id=str(key),
                ))
        except Exception:
            # Fallback to raw search
            try:
                bundle = self._backend.find_related_memories_raw(query, k=top_k)
                if bundle:
                    hits.append(RetrievalItem(
                        rank=0, memory_id="bundle", text=str(bundle),
                        score=1.0, raw_backend_id=None,
                    ))
            except Exception:
                pass

        return RetrievalRecord(
            query=query, top_k=top_k, items=hits[:top_k],
            raw_trace={"baseline": "A-Mem"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter's behavior limits for the runner."""
        return {
            "backend": "A-Mem",
            "baseline": "A-Mem",
            "available": self._backend is not None,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
|
memory_adapters/base.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Abstract memory adapter API for eval baselines."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
from eval_framework.datasets.schemas import (
|
| 9 |
+
MemoryDeltaRecord,
|
| 10 |
+
MemorySnapshotRecord,
|
| 11 |
+
NormalizedTurn,
|
| 12 |
+
RetrievalRecord,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class MemoryAdapter(ABC):
    """Baseline-agnostic adapter surface used by the eval pipeline.

    Concrete subclasses wrap one memory backend each and normalize its
    inputs/outputs into the shared schema records so the pipeline can treat
    every baseline identically.
    """

    @abstractmethod
    def reset(self) -> None:
        """Clear backend state and any adapter-side bookkeeping."""

    @abstractmethod
    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Feed one conversation turn into the memory system."""

    @abstractmethod
    def end_session(self, session_id: str) -> None:
        """Notify the adapter that a session boundary was reached (optional for many backends)."""

    @abstractmethod
    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return a normalized view of memories observable in the backend."""

    @abstractmethod
    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export memory changes for the given session since the last call."""

    @abstractmethod
    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Run retrieval for ``query`` and normalize up to ``top_k`` results."""

    @abstractmethod
    def get_capabilities(self) -> dict[str, Any]:
        """Describe adapter behavior limits (deltas, snapshots, backend id)."""
|
memory_adapters/dummy.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dummy memory adapter for end-to-end pipeline testing.
|
| 2 |
+
|
| 3 |
+
Stores all ingested turns as raw text and retrieves by simple word-overlap scoring.
|
| 4 |
+
No external dependencies required.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
from eval_framework.datasets.schemas import (
|
| 12 |
+
MemoryDeltaRecord,
|
| 13 |
+
MemorySnapshotRecord,
|
| 14 |
+
NormalizedTurn,
|
| 15 |
+
RetrievalItem,
|
| 16 |
+
RetrievalRecord,
|
| 17 |
+
)
|
| 18 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DummyAdapter(MemoryAdapter):
    """Minimal adapter that stores turns verbatim — for pipeline testing."""

    def __init__(self) -> None:
        # Each record: {"id": str, "text": str, "session_id": str}.
        self._memories: list[dict[str, str]] = []
        self._session_id = ""
        self._prev_ids: set[str] = set()

    def reset(self) -> None:
        """Forget every stored turn and all delta bookkeeping."""
        self._memories = []
        self._session_id = ""
        self._prev_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one turn verbatim, folding attachment captions into the text."""
        self._session_id = turn.session_id
        pieces = [f"{turn.role}: {turn.text}"]
        pieces.extend(f"\n[{att.type}] {att.caption}" for att in turn.attachments)
        self._memories.append(
            {
                "id": str(len(self._memories)),
                "text": "".join(pieces),
                "session_id": turn.session_id,
            }
        )

    def end_session(self, session_id: str) -> None:
        """Record the session id; no other boundary handling is needed."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Expose every stored record as an active snapshot row."""
        snapshot: list[MemorySnapshotRecord] = []
        for record in self._memories:
            snapshot.append(
                MemorySnapshotRecord(
                    memory_id=record["id"],
                    text=record["text"],
                    session_id=record["session_id"],
                    status="active",
                    source="Dummy",
                    raw_backend_id=record["id"],
                    raw_backend_type="dummy",
                    metadata={},
                )
            )
        return snapshot

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Report records added since the previous call as 'add' deltas."""
        seen_now = {record["id"] for record in self._memories}
        added = seen_now - self._prev_ids
        out: list[MemoryDeltaRecord] = []
        for record in self._memories:
            if record["id"] not in added:
                continue
            out.append(
                MemoryDeltaRecord(
                    session_id=session_id,
                    op="add",
                    text=record["text"],
                    linked_previous=(),
                    raw_backend_id=record["id"],
                    metadata={"baseline": "Dummy"},
                )
            )
        self._prev_ids = seen_now
        return out

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Rank stored records by word overlap with the query (stable order)."""
        query_words = set(query.lower().split())
        ranked = sorted(
            (
                (len(query_words & set(record["text"].lower().split())), record)
                for record in self._memories
            ),
            key=lambda pair: pair[0],
            reverse=True,
        )
        # Normalize by the query length so scores land in [0, 1].
        denom = max(len(query.split()), 1)
        results = [
            RetrievalItem(
                rank=pos,
                memory_id=record["id"],
                text=record["text"],
                score=float(hits) / denom,
                raw_backend_id=record["id"],
            )
            for pos, (hits, record) in enumerate(ranked[:top_k])
        ]
        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=results,
            raw_trace={"baseline": "Dummy"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Static capability description; the dummy backend is always available."""
        return {
            "backend": "Dummy",
            "baseline": "Dummy",
            "available": True,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }
|
memory_adapters/export_utils.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Helpers to map turns, backend memory dicts, and recall outputs into shared schemas."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Mapping
|
| 6 |
+
|
| 7 |
+
from eval_framework.datasets.schemas import (
|
| 8 |
+
MemorySnapshotRecord,
|
| 9 |
+
NormalizedTurn,
|
| 10 |
+
RetrievalItem,
|
| 11 |
+
RetrievalRecord,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def turn_to_observation_dict(turn: NormalizedTurn) -> dict[str, Any]:
    """Build a Mem-Gallery store observation from a normalized turn.

    The observation text is the turn's text followed by one
    ``[type] caption`` line per attachment; ``timestamp`` is included only
    when the turn carries a truthy one.
    """
    body_lines = [turn.text] + [f"[{att.type}] {att.caption}" for att in turn.attachments]
    observation: dict[str, Any] = {"text": "\n".join(body_lines)}
    if turn.timestamp:
        observation["timestamp"] = turn.timestamp
    observation["dialogue_id"] = f"{turn.session_id}:{turn.turn_index}"
    return observation
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def memory_element_text(element: Mapping[str, Any]) -> str:
    """Best-effort text extraction from a Mem-Gallery memory dict.

    A list-valued ``text`` is space-joined and returned as-is (no image
    caption appended); otherwise the string form of ``text`` is used, with a
    ``[image] <caption>`` line appended when a dict-shaped image with a
    truthy caption is present.
    """
    raw = element.get("text", "")
    if isinstance(raw, list):
        return " ".join(str(piece) for piece in raw)
    text = "" if raw is None else str(raw)
    image = element.get("image")
    cap = image.get("caption") if isinstance(image, dict) else None
    if cap:
        text = f"{text}\n[image] {cap}".strip()
    return text
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def linear_element_to_snapshot(
    element: Mapping[str, Any],
    *,
    memory_id: str,
    session_id: str,
    source: str,
    status: str = "active",
) -> MemorySnapshotRecord:
    """Map a linear-storage memory dict into MemorySnapshotRecord.

    The backend id is taken from the element's ``counter_id`` when present,
    falling back to ``memory_id``.
    """
    counter = element.get("counter_id")
    backend_id = memory_id if counter is None else str(counter)
    return MemorySnapshotRecord(
        memory_id=memory_id,
        text=memory_element_text(element),
        session_id=session_id,
        status=status,
        source=source,
        raw_backend_id=backend_id,
        raw_backend_type="linear",
        metadata={},
    )
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def normalize_recall_to_retrieval(
    query: str,
    top_k: int,
    raw: Any,
    *,
    raw_trace: dict[str, Any] | None = None,
) -> RetrievalRecord:
    """Normalize Mem-Gallery recall outputs into RetrievalRecord.

    Accepts three shapes of ``raw``: a plain string (one opaque bundle), a
    list of memory dicts / arbitrary rows (ranked items), or anything else
    (stringified into one bundle).
    """
    trace = dict(raw_trace or {})

    def _bundle(memory_id: str, text: str) -> RetrievalItem:
        # One opaque result wrapping a non-list payload.
        return RetrievalItem(
            rank=0,
            memory_id=memory_id,
            text=text,
            score=1.0,
            raw_backend_id=None,
        )

    items: list[RetrievalItem] = []
    if isinstance(raw, str):
        items.append(_bundle("memgallery:string_bundle", raw))
    elif isinstance(raw, list):
        for pos, entry in enumerate(raw[: max(0, top_k)]):
            if isinstance(entry, dict):
                counter = entry.get("counter_id")
                items.append(
                    RetrievalItem(
                        rank=pos,
                        memory_id=str(pos if counter is None else counter),
                        text=memory_element_text(entry),
                        score=float(entry.get("score", 1.0)),
                        raw_backend_id=None if counter is None else str(counter),
                    )
                )
            else:
                items.append(
                    RetrievalItem(
                        rank=pos,
                        memory_id=str(pos),
                        text=str(entry),
                        score=1.0,
                        raw_backend_id=None,
                    )
                )
    else:
        items.append(_bundle("memgallery:object_bundle", str(raw)))

    return RetrievalRecord(query=query, top_k=top_k, items=items[:top_k], raw_trace=trace)
|
memory_adapters/mem0_adapter.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adapters for Mem0 and Mem0-Graph baselines."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import uuid as _uuid
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
|
| 12 |
+
load_dotenv(Path(__file__).resolve().parents[2] / ".env")
|
| 13 |
+
|
| 14 |
+
from eval_framework.datasets.schemas import (
|
| 15 |
+
MemoryDeltaRecord,
|
| 16 |
+
MemorySnapshotRecord,
|
| 17 |
+
NormalizedTurn,
|
| 18 |
+
RetrievalItem,
|
| 19 |
+
RetrievalRecord,
|
| 20 |
+
)
|
| 21 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Mem0Adapter(MemoryAdapter):
    """Adapter for Mem0 (vector mode) and Mem0-Graph (graph mode).

    Wraps a ``mem0.Memory`` instance behind the framework's ``MemoryAdapter``
    interface. With ``use_graph=True`` a Kuzu graph store is configured and
    extracted relations are exposed alongside the vector memories.
    """

    def __init__(self, *, use_graph: bool = False, **kwargs: Any) -> None:
        # Imported lazily so the framework loads even without mem0 installed.
        from mem0 import Memory

        # A fresh user id per adapter instance isolates state between runs.
        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()

        config: dict[str, Any] = {
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv("OPENAI_MODEL") or "gpt-4o",
                    "api_key": os.getenv("OPENAI_API_KEY") or "",
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": "text-embedding-3-small",
                    "api_key": os.getenv("OPENAI_API_KEY") or "",
                    "embedding_dims": 1536,
                },
            },
        }

        base_url = os.getenv("OPENAI_BASE_URL")
        if base_url:
            config["llm"]["config"]["openai_base_url"] = base_url
            config["embedder"]["config"]["openai_base_url"] = base_url

        if use_graph:
            config["graph_store"] = {
                "provider": "kuzu",
                "config": {
                    "url": "/tmp/mem0_kuzu_eval",
                },
            }

        self._memory = Memory.from_config(config)
        self._use_graph = use_graph

    def _baseline_label(self) -> str:
        """Canonical baseline name for traces/metadata (mode-dependent)."""
        return "Mem0-Graph" if self._use_graph else "Mem0"

    def reset(self) -> None:
        """Drop all memories for the current user and rotate to a new user id."""
        self._memory.delete_all(user_id=self._user_id)
        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Feed one dialogue turn (plus attachment captions) into Mem0."""
        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        for att in turn.attachments:
            text += f"\n[{att.type}] {att.caption}"
        # Truncate to avoid excessively long inputs that break graph entity extraction
        text = text[:2000]
        try:
            self._memory.add(
                messages=[{"role": turn.role, "content": text}],
                user_id=self._user_id,
            )
        except Exception:
            # Graph mode can fail on entity embedding; fall back silently
            pass

    def end_session(self, session_id: str) -> None:
        """Record the session id; Mem0 needs no explicit session close."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Dump every memory Mem0 currently holds for the active user.

        Vector memories are always included; in graph mode, extracted
        relations are appended as synthetic ``rel_<i>`` rows.
        """
        all_mems = self._memory.get_all(user_id=self._user_id)
        rows: list[MemorySnapshotRecord] = []

        # Vector results (standard mode)
        results = all_mems.get("results", []) if isinstance(all_mems, dict) else all_mems
        for mem in results:
            mid = str(mem.get("id", ""))
            text = str(mem.get("memory", ""))
            rows.append(MemorySnapshotRecord(
                memory_id=mid, text=text,
                session_id=self._session_id, status="active",
                source="Mem0", raw_backend_id=mid,
                raw_backend_type="mem0_vector", metadata={},
            ))

        # Graph relations (graph mode)
        relations = all_mems.get("relations", []) if isinstance(all_mems, dict) else []
        for i, rel in enumerate(relations):
            if isinstance(rel, dict):
                src = rel.get("source", "")
                rtype = rel.get("relationship", "")
                tgt = rel.get("target") or rel.get("destination", "")
                text = f"{src} → {rtype} → {tgt}"
                mid = f"rel_{i}"
                rows.append(MemorySnapshotRecord(
                    memory_id=mid, text=text,
                    session_id=self._session_id, status="active",
                    source="Mem0-Graph", raw_backend_id=mid,
                    raw_backend_type="mem0_graph_relation", metadata=rel,
                ))

        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Diff the current snapshot against the previous one and emit adds.

        Mem0 exposes no per-write deltas, so only newly-seen memory ids are
        reported; deletions and in-place updates are not tracked.
        """
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=s.text,
                linked_previous=(),
                raw_backend_id=s.raw_backend_id,
                # Fix: label graph-mode deltas consistently with retrieve()
                # and get_capabilities() instead of hard-coding "Mem0".
                metadata={"baseline": self._baseline_label()},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search Mem0 and normalize vector hits (and graph relations) to a record."""
        results = self._memory.search(query=query, user_id=self._user_id, limit=top_k)
        items: list[RetrievalItem] = []

        # Vector results
        search_results = results.get("results", []) if isinstance(results, dict) else results
        for i, r in enumerate(search_results[:top_k]):
            items.append(RetrievalItem(
                rank=len(items),
                memory_id=str(r.get("id", i)),
                text=str(r.get("memory", "")),
                # Fall back to a rank-based pseudo-score when mem0 omits one.
                score=float(r.get("score", 1.0 / (i + 1))),
                raw_backend_id=str(r.get("id", "")),
            ))

        # Graph relations
        relations = results.get("relations", []) if isinstance(results, dict) else []
        for rel in relations:
            if isinstance(rel, dict) and len(items) < top_k:
                src = rel.get("source", "")
                rtype = rel.get("relationship", "")
                tgt = rel.get("target") or rel.get("destination", "")
                items.append(RetrievalItem(
                    rank=len(items),
                    memory_id=f"rel_{len(items)}",
                    text=f"{src} → {rtype} → {tgt}",
                    # Relations carry no backend score; use a fixed high value.
                    score=0.9,
                    raw_backend_id=None,
                ))

        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:top_k],
            raw_trace={"baseline": self._baseline_label()},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this backend's delta/snapshot semantics for the runner."""
        return {
            "backend": self._baseline_label(),
            "baseline": self._baseline_label(),
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
|
memory_adapters/memgallery_native.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Mem-Gallery native baseline wrappers with conservative schema normalization."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
from typing import Any, Callable
|
| 7 |
+
|
| 8 |
+
from eval_framework.datasets.schemas import (
|
| 9 |
+
MemoryDeltaRecord,
|
| 10 |
+
MemorySnapshotRecord,
|
| 11 |
+
NormalizedTurn,
|
| 12 |
+
RetrievalRecord,
|
| 13 |
+
)
|
| 14 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 15 |
+
from eval_framework.memory_adapters.export_utils import (
|
| 16 |
+
linear_element_to_snapshot,
|
| 17 |
+
memory_element_text,
|
| 18 |
+
normalize_recall_to_retrieval,
|
| 19 |
+
turn_to_observation_dict,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _deep_merge_dict(base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
|
| 24 |
+
out = copy.deepcopy(base)
|
| 25 |
+
for key, val in overrides.items():
|
| 26 |
+
if (
|
| 27 |
+
key in out
|
| 28 |
+
and isinstance(out[key], dict)
|
| 29 |
+
and isinstance(val, dict)
|
| 30 |
+
):
|
| 31 |
+
out[key] = _deep_merge_dict(out[key], val)
|
| 32 |
+
else:
|
| 33 |
+
out[key] = copy.deepcopy(val)
|
| 34 |
+
return out
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _default_config_for_baseline(name: str) -> dict[str, Any]:
    """Return a deep copy of the Mem-Gallery default config for *name*.

    Raises ``KeyError`` for unknown baseline names. The copy guards the
    shared default dict against caller mutation.
    """
    import default_config.DefaultMemoryConfig as dmc  # type: ignore[import-not-found]

    attr_by_baseline = {
        "FUMemory": "DEFAULT_FUMEMORY",
        "STMemory": "DEFAULT_STMEMORY",
        "LTMemory": "DEFAULT_LTMEMORY",
        "GAMemory": "DEFAULT_GAMEMORY",
        "MGMemory": "DEFAULT_MGMEMORY",
        "RFMemory": "DEFAULT_RFMEMORY",
        "MMMemory": "DEFAULT_MMMEMORY",
        "MMFUMemory": "DEFAULT_MMFUMEMORY",
        "NGMemory": "DEFAULT_NGMEMORY",
        "AUGUSTUSMemory": "DEFAULT_AUGUSTUSMEMORY",
        "UniversalRAGMemory": "DEFAULT_UNIVERSALRAGMEMORY",
    }
    return copy.deepcopy(getattr(dmc, attr_by_baseline[name]))
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _import_memory_class(name: str) -> Callable[..., Any]:
|
| 58 |
+
modmap = {
|
| 59 |
+
"FUMemory": ("memengine.memory.FUMemory", "FUMemory"),
|
| 60 |
+
"STMemory": ("memengine.memory.STMemory", "STMemory"),
|
| 61 |
+
"LTMemory": ("memengine.memory.LTMemory", "LTMemory"),
|
| 62 |
+
"GAMemory": ("memengine.memory.GAMemory", "GAMemory"),
|
| 63 |
+
"MGMemory": ("memengine.memory.MGMemory", "MGMemory"),
|
| 64 |
+
"RFMemory": ("memengine.memory.RFMemory", "RFMemory"),
|
| 65 |
+
"MMMemory": ("memengine.memory.MMMemory", "MMMemory"),
|
| 66 |
+
"MMFUMemory": ("memengine.memory.MMFUMemory", "MMFUMemory"),
|
| 67 |
+
"NGMemory": ("memengine.memory.NGMemory", "NGMemory"),
|
| 68 |
+
"AUGUSTUSMemory": ("memengine.memory.AUGUSTUSMemory", "AUGUSTUSMemory"),
|
| 69 |
+
"UniversalRAGMemory": ("memengine.memory.UniversalRAGMemory", "UniversalRAGMemory"),
|
| 70 |
+
}
|
| 71 |
+
module_path, cls_name = modmap[name]
|
| 72 |
+
import importlib
|
| 73 |
+
|
| 74 |
+
mod = importlib.import_module(module_path)
|
| 75 |
+
return getattr(mod, cls_name)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def instantiate_memgallery_memory(
    baseline_name: str,
    config: dict[str, Any] | None = None,
) -> Any:
    """Construct a Mem-Gallery memory object with optional config overrides.

    The per-baseline defaults are deep-merged with *config* (overrides win),
    then fed through ``MemoryConfig`` into the resolved memory class.
    """
    from memengine.config.Config import MemoryConfig  # type: ignore[import-not-found]

    overrides = config if config is not None else {}
    merged_cfg = _deep_merge_dict(_default_config_for_baseline(baseline_name), overrides)
    memory_cls = _import_memory_class(baseline_name)
    return memory_cls(MemoryConfig(merged_cfg))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _graph_nodes_to_snapshots(
    storage: Any,
    *,
    session_id: str,
    source: str,
    include_concepts: bool = False,
) -> list[MemorySnapshotRecord]:
    """Normalize graph-storage nodes into snapshot records, in insertion order.

    ``storage`` is expected to expose ``memory_order_map`` (node-id order),
    ``node`` (node-id -> attribute dict) and optionally ``node_concepts``.
    """
    snapshots: list[MemorySnapshotRecord] = []
    insertion_order = getattr(storage, "memory_order_map", []) or []
    concept_map = getattr(storage, "node_concepts", {})
    for fallback_idx, node_id in enumerate(insertion_order):
        node_data = storage.node[node_id]
        # Position in insertion order stands in for a missing counter_id.
        counter_id = node_data.get("counter_id", fallback_idx)
        body = memory_element_text(node_data)
        if include_concepts:
            # For AUGUSTUS: append concept tags extracted by the system
            tags = concept_map.get(node_id, set())
            if tags:
                body = f"{body}\n[concepts] {', '.join(sorted(tags))}"
        snapshots.append(
            MemorySnapshotRecord(
                memory_id=f"n{node_id}",
                text=body,
                session_id=session_id,
                status="active",
                source=source,
                raw_backend_id=str(counter_id),
                raw_backend_type="graph_node",
                metadata={"node_id": node_id},
            )
        )
    return snapshots
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _linear_storage_snapshots(
    storage: Any,
    *,
    session_id: str,
    source: str,
) -> list[MemorySnapshotRecord]:
    """Convert each element of a linear storage's ``memory_list`` to a snapshot.

    The list index stands in for a missing ``counter_id``.
    """
    return [
        linear_element_to_snapshot(
            element,
            memory_id=str(element.get("counter_id", idx)),
            session_id=session_id,
            source=source,
        )
        for idx, element in enumerate(storage.memory_list)
    ]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def collect_memgallery_snapshots(
    memory: Any,
    baseline_name: str,
    session_id: str,
) -> list[MemorySnapshotRecord]:
    """Best-effort snapshot of backend-visible memories.

    Dispatches on ``baseline_name`` because each Mem-Gallery backend exposes
    its state through a different storage shape; unknown backends fall back
    to a generic linear-storage read, or an empty list.
    """
    source = baseline_name
    if baseline_name == "MGMemory":
        out: list[MemorySnapshotRecord] = []
        # store_op/recall_op have their own main_context references;
        # prefer store_op's view as it holds the actual stored data.
        mc = getattr(memory.store_op, "main_context", None) or memory.main_context
        # Recall/archival storages may hang off recall_op or the memory itself
        # depending on the backend version; probe both.
        recall_storage = getattr(memory.recall_op, "recall_storage",
                                 getattr(memory, "recall_storage", None))
        archival_storage = getattr(memory.recall_op, "archival_storage",
                                   getattr(memory, "archival_storage", None))
        storages = [("wm", mc["working_context"]), ("fifo", mc["FIFO_queue"])]
        if recall_storage is not None:
            storages.append(("recall", recall_storage))
        if archival_storage is not None:
            storages.append(("archival", archival_storage))
        for prefix, st in storages:
            for i, m in enumerate(st.memory_list):
                cid = m.get("counter_id", i)
                # Prefix-qualified id keeps ids unique across the tiers.
                mid = f"{prefix}-{cid}"
                # NOTE: linear_element_to_snapshot returns a single record,
                # despite this local being named "rows".
                rows = linear_element_to_snapshot(
                    m,
                    memory_id=mid,
                    session_id=session_id,
                    source=source,
                )
                out.append(rows)
        # The recursive summary is stored as text; the backend may literally
        # store the string "None" for an empty summary, hence the str() check.
        gsum = mc.get("recursive_summary", {}).get("global")
        if gsum and str(gsum) != "None":
            out.append(
                MemorySnapshotRecord(
                    memory_id="recursive_summary",
                    text=str(gsum),
                    session_id=session_id,
                    status="active",
                    source=source,
                    raw_backend_id=None,
                    raw_backend_type="mg_summary",
                    metadata={},
                )
            )
        return out

    if baseline_name == "RFMemory":
        # RF: plain linear storage plus an optional global insight row.
        rows = _linear_storage_snapshots(
            memory.storage, session_id=session_id, source=source
        )
        insight = getattr(memory, "insight", {}).get("global_insight", "")
        if insight:
            rows.append(
                MemorySnapshotRecord(
                    memory_id="rf_insight",
                    text=str(insight),
                    session_id=session_id,
                    status="active",
                    source=source,
                    raw_backend_id=None,
                    raw_backend_type="rf_insight",
                    metadata={},
                )
            )
        return rows

    if baseline_name == "NGMemory":
        return _graph_nodes_to_snapshots(
            memory.storage, session_id=session_id, source=source
        )

    if baseline_name == "AUGUSTUSMemory":
        # AUGUSTUS attaches per-node concept tags; include them in the text.
        return _graph_nodes_to_snapshots(
            memory.contextual_memory, session_id=session_id, source=source,
            include_concepts=True,
        )

    if baseline_name == "UniversalRAGMemory":
        return _linear_storage_snapshots(
            memory.storage, session_id=session_id, source=source
        )

    # Generic fallback: any backend exposing a linear storage is readable.
    if hasattr(memory, "storage") and hasattr(memory.storage, "memory_list"):
        return _linear_storage_snapshots(
            memory.storage, session_id=session_id, source=source
        )

    return []
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class MemGalleryNativeAdapter(MemoryAdapter):
    """Thin wrapper that forwards to Mem-Gallery memories and normalizes I/O."""

    def __init__(self, memory: Any, *, baseline_name: str) -> None:
        # Wrapped Mem-Gallery memory object (already constructed).
        self._memory = memory
        self._baseline_name = baseline_name
        self._session_id = ""
        # Snapshot ids seen at the last delta export (for snapshot diffing).
        self._prev_snapshot_ids: set[str] = set()
        # User turn awaiting its assistant reply before being stored.
        self._pending_user_turn: NormalizedTurn | None = None
        self._session_turns: list[str] = []  # collect turn texts for RF optimize

    @classmethod
    def from_baseline(
        cls,
        baseline_name: str,
        *,
        config: dict[str, Any] | None = None,
    ) -> MemGalleryNativeAdapter:
        """Alternate constructor: build the backend from its baseline name."""
        mem = instantiate_memgallery_memory(baseline_name, config)
        return cls(mem, baseline_name=baseline_name)

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Buffer user turns; store merged user+assistant pair on assistant turn.

        This matches the original Mem-Gallery benchmark behavior where each
        dialogue round (user + assistant) is merged into a single observation
        before calling store().
        """
        self._session_id = turn.session_id
        if turn.role == "user":
            # Flush any prior unpaired user turn, then buffer this one
            if self._pending_user_turn is not None:
                self._store_observation(self._pending_user_turn, assistant_turn=None)
            self._pending_user_turn = turn
        else:
            # Assistant turn: merge with buffered user turn and store
            self._store_observation(self._pending_user_turn, assistant_turn=turn)
            self._pending_user_turn = None

    def _store_observation(
        self,
        user_turn: NormalizedTurn | None,
        assistant_turn: NormalizedTurn | None,
    ) -> None:
        """Build a merged observation dict (matching original benchmark format) and store."""
        parts: list[str] = []
        timestamp = None
        dialogue_id = ""
        if user_turn is not None:
            parts.append(f"user: {user_turn.text}")
            for att in user_turn.attachments:
                parts.append(f"[{att.type}] {att.caption}")
            timestamp = user_turn.timestamp
            dialogue_id = f"{user_turn.session_id}:{user_turn.turn_index}"
        if assistant_turn is not None:
            parts.append(f"assistant: {assistant_turn.text}")
            for att in assistant_turn.attachments:
                parts.append(f"[{att.type}] {att.caption}")
            # User turn metadata wins when both turns are present.
            if timestamp is None:
                timestamp = assistant_turn.timestamp
            if not dialogue_id:
                dialogue_id = f"{assistant_turn.session_id}:{assistant_turn.turn_index}"

        obs: dict[str, Any] = {"text": "\n".join(parts)}
        if timestamp:
            obs["timestamp"] = timestamp
        obs["dialogue_id"] = dialogue_id
        self._memory.store(obs)
        self._session_turns.append(obs["text"])

    def end_session(self, session_id: str) -> None:
        # Flush any remaining unpaired user turn
        if self._pending_user_turn is not None:
            self._store_observation(self._pending_user_turn, assistant_turn=None)
            self._pending_user_turn = None

        # --- Trigger backend-specific post-session processing ---
        # GAMemory: self-reflection generates insights and stores them
        if self._baseline_name == "GAMemory":
            try:
                self._memory.manage("reflect")
            except Exception:
                pass  # reflection may fail if accumulated importance < threshold

        # RFMemory: optimize generates a global insight from the session trial
        if self._baseline_name == "RFMemory" and self._session_turns:
            try:
                trial = "\n".join(self._session_turns)
                self._memory.optimize(new_trial=trial)
            except Exception:
                pass

        self._session_turns = []

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Read the backend's observable storage into snapshot records."""
        sid = self._session_id or ""
        return collect_memgallery_snapshots(
            self._memory, self._baseline_name, sid
        )

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current backend snapshot against previous snapshot.

        This reflects what the backend ACTUALLY stores, not what was fed in.
        For FU/ST/LT/GA/RF (LinearStorage), this is the raw observations added.
        For MGMemory, this includes FIFO items, summaries, and archival entries.
        """
        current_snapshot = self.snapshot_memories()
        prev_ids = self._prev_snapshot_ids
        deltas: list[MemoryDeltaRecord] = []
        current_ids: set[str] = set()

        for snap in current_snapshot:
            current_ids.add(snap.memory_id)
            # Only ids not seen at the previous export count as "add".
            if snap.memory_id not in prev_ids:
                deltas.append(
                    MemoryDeltaRecord(
                        session_id=session_id,
                        op="add",
                        text=snap.text,
                        linked_previous=(),
                        raw_backend_id=snap.raw_backend_id,
                        metadata={
                            "baseline": self._baseline_name,
                            "source": snap.source,
                            "backend_type": snap.raw_backend_type,
                        },
                    )
                )

        self._prev_snapshot_ids = current_ids
        return deltas

    def reset(self) -> None:
        """Reset the backend and clear all adapter-side buffering state."""
        self._memory.reset()
        self._prev_snapshot_ids = set()
        self._pending_user_turn = None
        self._session_turns = []

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Recall from the backend and normalize the result to a RetrievalRecord."""
        raw = self._memory.recall(query)
        trace: dict[str, Any] = {"baseline": self._baseline_name}
        # Some backends expose the ids they retrieved; surface them in the trace.
        ro = getattr(self._memory, "recall_op", None)
        if ro is not None and hasattr(ro, "last_retrieved_ids"):
            trace["last_retrieved_ids"] = list(ro.last_retrieved_ids)
        return normalize_recall_to_retrieval(query, top_k, raw, raw_trace=trace)

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this backend's delta/snapshot semantics for the runner."""
        return {
            "backend": "MemGallery",
            "baseline": self._baseline_name,
            "delta_granularity": "ingest_turn_only",
            "snapshot_mode": "conservative",
            "notes": (
                "Deltas record adapter ingest only; backend-internal rewrite, reflection, "
                "or graph reshaping is not diffed. Snapshots read observable storage where supported."
            ),
        }
|
memory_adapters/memoryos.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adapter for the external MemoryOS baseline."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import importlib
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
import sys
|
| 9 |
+
import tempfile
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Callable
|
| 12 |
+
|
| 13 |
+
from eval_framework.datasets.schemas import (
|
| 14 |
+
MemoryDeltaRecord,
|
| 15 |
+
MemorySnapshotRecord,
|
| 16 |
+
NormalizedTurn,
|
| 17 |
+
RetrievalItem,
|
| 18 |
+
RetrievalRecord,
|
| 19 |
+
)
|
| 20 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 21 |
+
|
| 22 |
+
_BACKEND_ID = "MemoryOS"
|
| 23 |
+
|
| 24 |
+
INTEGRATION_ERROR = (
|
| 25 |
+
f"{_BACKEND_ID} backend unavailable."
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class MemoryOSAdapter(MemoryAdapter):
|
| 30 |
+
"""Thin wrapper around MemoryOS's local Python API."""
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
*,
|
| 35 |
+
backend: Any | None = None,
|
| 36 |
+
backend_factory: Callable[[], Any] | None = None,
|
| 37 |
+
source_root: str | os.PathLike[str] | None = None,
|
| 38 |
+
storage_root: str | os.PathLike[str] | None = None,
|
| 39 |
+
user_id: str = "eval_user",
|
| 40 |
+
assistant_id: str = "eval_assistant",
|
| 41 |
+
llm_model: str | None = None,
|
| 42 |
+
embedding_model_name: str = "all-MiniLM-L6-v2",
|
| 43 |
+
openai_api_key: str | None = None,
|
| 44 |
+
openai_base_url: str | None = None,
|
| 45 |
+
) -> None:
|
| 46 |
+
self._source_root = Path(source_root).resolve() if source_root else self._default_source_root()
|
| 47 |
+
self._storage_root = Path(storage_root).resolve() if storage_root else Path(
|
| 48 |
+
tempfile.mkdtemp(prefix="memoryos_eval_")
|
| 49 |
+
)
|
| 50 |
+
self._user_id = user_id
|
| 51 |
+
self._assistant_id = assistant_id
|
| 52 |
+
self._llm_model = llm_model or os.getenv("OPENAI_MODEL") or "gpt-5.1"
|
| 53 |
+
self._embedding_model_name = embedding_model_name
|
| 54 |
+
self._openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
|
| 55 |
+
self._openai_base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
|
| 56 |
+
self._backend_factory = backend_factory
|
| 57 |
+
self._backend: Any | None = None
|
| 58 |
+
self._integration_error: str | None = None
|
| 59 |
+
self._session_id = ""
|
| 60 |
+
self._prev_snapshot_ids: set[str] = set()
|
| 61 |
+
self._pending_user_turns: list[NormalizedTurn] = []
|
| 62 |
+
|
| 63 |
+
if backend is not None:
|
| 64 |
+
self._backend = backend
|
| 65 |
+
else:
|
| 66 |
+
try:
|
| 67 |
+
if self._backend_factory is None:
|
| 68 |
+
self._backend_factory = self._build_backend_factory()
|
| 69 |
+
self._backend = self._backend_factory()
|
| 70 |
+
except Exception as exc:
|
| 71 |
+
self._integration_error = str(exc)
|
| 72 |
+
|
| 73 |
+
@staticmethod
|
| 74 |
+
def _default_source_root() -> Path:
|
| 75 |
+
here = Path(__file__).resolve()
|
| 76 |
+
# memory_adapters/ -> eval_framework/ -> nips26/ -> baselines/MemoryOS/memoryos-pypi
|
| 77 |
+
return (here.parents[2] / "baselines" / "MemoryOS" / "memoryos-pypi").resolve()
|
| 78 |
+
|
| 79 |
+
def _build_backend_factory(self) -> Callable[[], Any]:
|
| 80 |
+
if not self._source_root.is_dir():
|
| 81 |
+
raise RuntimeError(
|
| 82 |
+
f"{_BACKEND_ID}: source root not found at {self._source_root}"
|
| 83 |
+
)
|
| 84 |
+
src = str(self._source_root)
|
| 85 |
+
if src not in sys.path:
|
| 86 |
+
sys.path.insert(0, src)
|
| 87 |
+
mod = importlib.import_module("memoryos")
|
| 88 |
+
backend_cls = getattr(mod, "Memoryos")
|
| 89 |
+
|
| 90 |
+
def _factory() -> Any:
|
| 91 |
+
run_root = self._storage_root / "runtime"
|
| 92 |
+
shutil.rmtree(run_root, ignore_errors=True)
|
| 93 |
+
run_root.mkdir(parents=True, exist_ok=True)
|
| 94 |
+
return backend_cls(
|
| 95 |
+
user_id=self._user_id,
|
| 96 |
+
openai_api_key=self._openai_api_key or "",
|
| 97 |
+
openai_base_url=self._openai_base_url,
|
| 98 |
+
data_storage_path=str(run_root),
|
| 99 |
+
llm_model=self._llm_model,
|
| 100 |
+
assistant_id=self._assistant_id,
|
| 101 |
+
embedding_model_name=self._embedding_model_name,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
return _factory
|
| 105 |
+
|
| 106 |
+
def _runtime_error(self) -> RuntimeError:
    """Build the standard 'backend unavailable' error, preferring the
    instance-recorded integration error over the module-level default."""
    if self._integration_error:
        detail = self._integration_error
    else:
        detail = INTEGRATION_ERROR
    return RuntimeError(f"{_BACKEND_ID}: backend unavailable — {detail}")
|
| 111 |
+
|
| 112 |
+
def reset(self) -> None:
    """Rebuild the backend (when a factory exists) and clear all per-run state."""
    factory = self._backend_factory
    if factory is None:
        if self._backend is None:
            # Neither a live backend nor a way to construct one.
            raise self._runtime_error()
    else:
        self._backend = factory()
    self._prev_snapshot_ids = set()
    self._pending_user_turns = []
    self._session_id = ""
|
| 120 |
+
|
| 121 |
+
def ingest_turn(self, turn: NormalizedTurn) -> None:
    """Buffer non-assistant turns; flush a (user, assistant) pair on reply."""
    self._require_backend()
    self._session_id = turn.session_id
    if turn.role != "assistant":
        # Hold the turn until the assistant answers (MemoryOS stores pairs).
        self._pending_user_turns.append(turn)
        return
    self._store_pair(turn)
|
| 128 |
+
|
| 129 |
+
def end_session(self, session_id: str) -> None:
    """Flush dangling user turns (no assistant reply) when a session closes."""
    self._require_backend()
    self._session_id = session_id
    if not self._pending_user_turns:
        return
    # Borrow the last buffered turn's timestamp for the synthetic pair.
    last_turn = self._pending_user_turns[-1]
    self._store_memory(
        session_id=session_id,
        user_input=self._joined_user_text(),
        agent_response="",
        timestamp=last_turn.timestamp,
    )
    self._pending_user_turns = []
|
| 141 |
+
|
| 142 |
+
def snapshot_memories(self) -> list[MemorySnapshotRecord]:
    """Flatten every MemoryOS tier into normalized snapshot records.

    Tiers, in emission order: short-term QA pairs, mid-term session pages,
    long-term user profile, long-term user knowledge, and (when the backend
    version exposes it) assistant knowledge. Every record is tagged with
    the most recently seen session id, not the session the memory came from.
    """
    backend = self._require_backend()
    rows: list[MemorySnapshotRecord] = []
    sid = self._session_id

    # Short-term memory: raw QA pairs, identified positionally ("st:<idx>").
    for idx, qa in enumerate(backend.short_term_memory.get_all()):
        rows.append(
            MemorySnapshotRecord(
                memory_id=f"st:{idx}",
                text=self._format_qa_text(qa),
                session_id=sid,
                status="active",
                source=_BACKEND_ID,
                raw_backend_id=f"st:{idx}",
                raw_backend_type="short_term",
                metadata={"timestamp": qa.get("timestamp")},
            )
        )

    # Mid-term memory: pages grouped under MemoryOS-internal session ids.
    # getattr-with-default guards against backend versions lacking `sessions`.
    for internal_session_id, session in getattr(backend.mid_term_memory, "sessions", {}).items():
        for page_idx, page in enumerate(session.get("details", [])):
            rows.append(
                MemorySnapshotRecord(
                    memory_id=f"mt:{internal_session_id}:{page_idx}",
                    text=self._format_qa_text(page),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id=str(page.get("page_id", f"{internal_session_id}:{page_idx}")),
                    raw_backend_type="mid_term_page",
                    metadata={"memoryos_session_id": internal_session_id},
                )
            )

    # Long-term user profile — skipped when empty or the literal string "None".
    user_profile = backend.user_long_term_memory.get_raw_user_profile(backend.user_id)
    if user_profile and str(user_profile).lower() != "none":
        rows.append(
            MemorySnapshotRecord(
                memory_id="lt:user_profile",
                text=str(user_profile),
                session_id=sid,
                status="active",
                source=_BACKEND_ID,
                raw_backend_id="user_profile",
                raw_backend_type="user_profile",
                metadata={},
            )
        )

    # Long-term user knowledge entries.
    for idx, item in enumerate(backend.user_long_term_memory.get_user_knowledge()):
        rows.append(
            MemorySnapshotRecord(
                memory_id=f"lt:user:{idx}",
                text=str(item.get("knowledge", "")),
                session_id=sid,
                status="active",
                source=_BACKEND_ID,
                raw_backend_id=f"user:{idx}",
                raw_backend_type="user_knowledge",
                metadata={"timestamp": item.get("timestamp")},
            )
        )

    # Assistant knowledge is optional across MemoryOS versions; probe defensively.
    assistant_ltm = getattr(backend, "assistant_long_term_memory", None)
    if assistant_ltm is not None and hasattr(assistant_ltm, "get_assistant_knowledge"):
        for idx, item in enumerate(assistant_ltm.get_assistant_knowledge()):
            rows.append(
                MemorySnapshotRecord(
                    memory_id=f"lt:assistant:{idx}",
                    text=str(item.get("knowledge", "")),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id=f"assistant:{idx}",
                    raw_backend_type="assistant_knowledge",
                    metadata={"timestamp": item.get("timestamp")},
                )
            )
    return rows
|
| 221 |
+
|
| 222 |
+
def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
    """Export delta by diffing current snapshot against previous snapshot.

    Only additions are reported: any memory_id not seen in the previous
    snapshot becomes an "add" record. The seen-id set is then advanced.
    """
    self._require_backend()
    snapshot = self.snapshot_memories()
    seen_ids = {record.memory_id for record in snapshot}
    deltas = [
        MemoryDeltaRecord(
            session_id=session_id,
            op="add",
            text=record.text,
            linked_previous=(),
            raw_backend_id=record.raw_backend_id,
            metadata={
                "baseline": _BACKEND_ID,
                "backend_type": record.raw_backend_type,
            },
        )
        for record in snapshot
        if record.memory_id not in self._prev_snapshot_ids
    ]
    self._prev_snapshot_ids = seen_ids
    return deltas
|
| 248 |
+
|
| 249 |
+
def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
    """Query the MemoryOS retriever and normalize results into a RetrievalRecord.

    Mid-term pages, user knowledge, and assistant knowledge are merged, in
    that order, into one ranked list. Scores are synthetic reciprocal-rank
    values (1 / (rank + 1)) because the backend exposes no similarity score.
    Only the first ``top_k`` merged items are returned.
    """
    backend = self._require_backend()
    raw = backend.retriever.retrieve_context(query, user_id=backend.user_id)
    items: list[RetrievalItem] = []

    def _append(kind: str, text: str, raw_backend_id: Any) -> None:
        # Rank and score derive from the merged position across all sources;
        # factoring this out removes three previously duplicated constructors.
        rank = len(items)
        items.append(
            RetrievalItem(
                rank=rank,
                memory_id=f"{kind}:{rank}",
                text=text,
                score=1.0 / float(rank + 1),
                raw_backend_id=raw_backend_id,
            )
        )

    for page in raw.get("retrieved_pages", []):
        _append("page", self._format_qa_text(page), page.get("page_id"))
    for item in raw.get("retrieved_user_knowledge", []):
        _append("user", str(item.get("knowledge", "")), None)
    for item in raw.get("retrieved_assistant_knowledge", []):
        _append("assistant", str(item.get("knowledge", "")), None)

    return RetrievalRecord(
        query=query,
        top_k=top_k,
        items=items[:top_k],
        raw_trace={"baseline": _BACKEND_ID, "retrieved_at": raw.get("retrieved_at")},
    )
|
| 290 |
+
|
| 291 |
+
def get_capabilities(self) -> dict[str, Any]:
    """Report availability and feature flags for this adapter."""
    has_backend = self._backend is not None or self._backend_factory is not None
    healthy = has_backend and self._integration_error is None
    return {
        "backend": _BACKEND_ID,
        "baseline": _BACKEND_ID,
        "available": healthy,
        "integration_status": "integrated" if healthy else "unavailable",
        "integration_error": self._integration_error or INTEGRATION_ERROR,
        "delta_granularity": "ingest_pair_only",
        "snapshot_mode": "short_mid_long_term",
    }
|
| 302 |
+
|
| 303 |
+
def _require_backend(self) -> Any:
    """Return the live backend, raising the standard error when absent."""
    backend = self._backend
    if backend is None:
        raise self._runtime_error()
    return backend
|
| 307 |
+
|
| 308 |
+
def _store_pair(self, assistant_turn: NormalizedTurn) -> None:
    """Persist the buffered user turns together with the assistant reply,
    then clear the buffer."""
    self._store_memory(
        session_id=assistant_turn.session_id,
        user_input=self._joined_user_text(),
        agent_response=self._turn_text(assistant_turn),
        timestamp=assistant_turn.timestamp,
    )
    self._pending_user_turns = []
|
| 317 |
+
|
| 318 |
+
def _store_memory(
    self,
    *,
    session_id: str,
    user_input: str,
    agent_response: str,
    timestamp: str | None,
) -> None:
    """Write one (user, assistant) exchange into the MemoryOS backend,
    tagging it with the originating session id."""
    self._require_backend().add_memory(
        user_input=user_input,
        agent_response=agent_response,
        timestamp=timestamp,
        meta_data={"session_id": session_id},
    )
|
| 333 |
+
|
| 334 |
+
def _joined_user_text(self) -> str:
    """Concatenate all buffered user turns, one rendered turn per line."""
    pending = self._pending_user_turns
    if not pending:
        return ""
    return "\n".join(self._turn_text(entry) for entry in pending)
|
| 338 |
+
|
| 339 |
+
@staticmethod
|
| 340 |
+
def _turn_text(turn: NormalizedTurn) -> str:
|
| 341 |
+
parts = [turn.text]
|
| 342 |
+
for att in turn.attachments:
|
| 343 |
+
parts.append(f"[{att.type}] {att.caption}")
|
| 344 |
+
return "\n".join(parts)
|
| 345 |
+
|
| 346 |
+
@staticmethod
|
| 347 |
+
def _format_qa_text(item: dict[str, Any]) -> str:
|
| 348 |
+
parts = []
|
| 349 |
+
user_text = item.get("user_input", "")
|
| 350 |
+
if user_text:
|
| 351 |
+
parts.append(f"user: {user_text}")
|
| 352 |
+
assistant_text = item.get("agent_response", "")
|
| 353 |
+
if assistant_text:
|
| 354 |
+
parts.append(f"assistant: {assistant_text}")
|
| 355 |
+
if not parts:
|
| 356 |
+
parts.append(str(item))
|
| 357 |
+
return "\n".join(parts)
|
memory_adapters/memverse_adapter.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adapter for MemVerse — uses build_memory for storage + cosine retrieval."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import shutil
|
| 9 |
+
import tempfile
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
|
| 16 |
+
load_dotenv(Path(__file__).resolve().parents[2] / ".env")
|
| 17 |
+
|
| 18 |
+
from eval_framework.datasets.schemas import (
|
| 19 |
+
MemoryDeltaRecord,
|
| 20 |
+
MemorySnapshotRecord,
|
| 21 |
+
NormalizedTurn,
|
| 22 |
+
RetrievalItem,
|
| 23 |
+
RetrievalRecord,
|
| 24 |
+
)
|
| 25 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 26 |
+
|
| 27 |
+
_DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/MemVerse")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MemVerseAdapter(MemoryAdapter):
    """Adapter for MemVerse using build_memory + cosine retrieval.

    Bypasses the async orchestrator/LightRAG and uses the core
    memory building + embedding-based retrieval directly: each ingested
    turn is embedded and summarized via an OpenAI-compatible API, and
    retrieval ranks stored entries by cosine similarity.
    """

    def __init__(
        self,
        *,
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,  # extra adapter kwargs are accepted but unused here
    ) -> None:
        # Make the MemVerse checkout importable; prompt files are also read
        # from it below.
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))

        from openai import OpenAI

        # Single client used for both chat-based extraction and embeddings.
        self._client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_BASE_URL"),
        )
        self._model = os.getenv("OPENAI_MODEL") or "gpt-4o"

        # Working directory for memory files
        self._work_dir = Path(tempfile.mkdtemp(prefix="memverse_eval_"))
        self._root = root
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()
        self._memories: list[dict[str, Any]] = []  # {id, text, embedding, output}
        self._conversation: list[dict[str, Any]] = []
        self._turn_counter = 0

        # Load system prompts for memory agents; missing files are skipped.
        self._prompts: dict[str, str] = {}
        for name in ["core_memory_agent", "episodic_memory_agent", "semantic_memory_agent"]:
            prompt_path = root / "MemoryKB" / "Long_Term_Memory" / "system" / f"{name}.txt"
            if prompt_path.exists():
                self._prompts[name] = prompt_path.read_text(encoding="utf-8").strip()

    def _get_embedding(self, text: str) -> np.ndarray:
        """Embed *text* with OpenAI's text-embedding-3-small model."""
        resp = self._client.embeddings.create(
            model="text-embedding-3-small",
            input=text,
        )
        return np.array(resp.data[0].embedding)

    def _cosine_sim(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 when either is a zero vector."""
        norm = np.linalg.norm(a) * np.linalg.norm(b)
        if norm == 0:
            return 0.0
        return float(np.dot(a, b) / norm)

    def reset(self) -> None:
        """Drop all in-memory state and replace the working directory."""
        self._memories = []
        self._conversation = []
        self._prev_snapshot_ids = set()
        self._turn_counter = 0
        if self._work_dir.exists():
            shutil.rmtree(self._work_dir, ignore_errors=True)
        self._work_dir = Path(tempfile.mkdtemp(prefix="memverse_eval_"))

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one turn: embed it and run LLM fact extraction over it.

        The extraction uses the first loaded agent prompt (dict insertion
        order: core memory agent when present). On any API failure the raw
        turn text is stored as the extraction output instead.
        """
        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        for att in turn.attachments:
            text += f"\n[{att.type}] {att.caption}"

        entry_id = f"turn_{self._turn_counter}"
        self._turn_counter += 1

        # Conversation log mirrors MemVerse's expected entry schema; the
        # caption fields are unused for pure-text turns.
        entry = {
            "id": entry_id,
            "query": text,
            "videocaption": None,
            "audiocaption": None,
            "imagecaption": None,
        }
        self._conversation.append(entry)

        # Build memory: get embedding + LLM extraction for each memory type
        embedding = self._get_embedding(text)

        # Use the first available prompt (core memory agent) for extraction
        prompt = next(iter(self._prompts.values()), "Extract key facts from this text.")
        try:
            resp = self._client.chat.completions.create(
                model=self._model,
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": text},
                ],
                temperature=0,
                max_tokens=512,
            )
            output = resp.choices[0].message.content or ""
        except Exception:
            # Best-effort: fall back to the raw turn text on API failure.
            output = text

        self._memories.append({
            "id": entry_id,
            "text": text,
            "output": output,
            "embedding": embedding,
            "session_id": turn.session_id,
        })

    def end_session(self, session_id: str) -> None:
        """Record the closing session id; no flushing is required here."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Export every stored memory as a normalized snapshot record."""
        return [
            MemorySnapshotRecord(
                memory_id=m["id"],
                text=m["output"],
                session_id=m.get("session_id", self._session_id),
                status="active",
                source="MemVerse",
                raw_backend_id=m["id"],
                raw_backend_type="memverse",
                metadata={},
            )
            for m in self._memories
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Emit "add" records for memories new since the previous export."""
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "MemVerse"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Rank stored memories by cosine similarity to the query embedding
        and return the top ``top_k`` as a RetrievalRecord."""
        if not self._memories:
            return RetrievalRecord(query=query, top_k=top_k, items=[], raw_trace={})

        query_emb = self._get_embedding(query)
        scored = []
        for m in self._memories:
            sim = self._cosine_sim(query_emb, m["embedding"])
            scored.append((sim, m))
        # NOTE(review): ties are resolved by tuple comparison, which falls
        # through to comparing the memory dicts — stable in practice because
        # equal scores are rare, but worth confirming.
        scored.sort(key=lambda x: x[0], reverse=True)

        items = [
            RetrievalItem(
                rank=i,
                memory_id=m["id"],
                text=m["output"],
                score=float(sim),
                raw_backend_id=m["id"],
            )
            for i, (sim, m) in enumerate(scored[:top_k])
        ]
        return RetrievalRecord(
            query=query, top_k=top_k, items=items,
            raw_trace={"baseline": "MemVerse"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Static capability flags: this adapter is always available once built."""
        return {
            "backend": "MemVerse",
            "baseline": "MemVerse",
            "available": True,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }
|
memory_adapters/registry.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Registry and factory for native Mem-Gallery and external placeholder adapters."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
import types
|
| 8 |
+
from contextlib import nullcontext
|
| 9 |
+
from functools import partial
|
| 10 |
+
import importlib
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any, Callable
|
| 13 |
+
|
| 14 |
+
from eval_framework.memory_adapters.amem import AMemAdapter
|
| 15 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 16 |
+
from eval_framework.memory_adapters.dummy import DummyAdapter
|
| 17 |
+
from eval_framework.memory_adapters.memgallery_native import MemGalleryNativeAdapter
|
| 18 |
+
from eval_framework.memory_adapters.memoryos import MemoryOSAdapter
|
| 19 |
+
|
| 20 |
+
# Baselines served by the in-repo Mem-Gallery/MemEngine implementation
# (MemGalleryNativeAdapter) rather than by an external service adapter.
MEMGALLERY_NATIVE_BASELINES: frozenset[str] = frozenset(
    {
        "FUMemory",
        "STMemory",
        "LTMemory",
        "GAMemory",
        "MGMemory",
        "RFMemory",
        "MMMemory",
        "MMFUMemory",
        "NGMemory",
        "AUGUSTUSMemory",
        "UniversalRAGMemory",
    }
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _word_mode_truncation(number: int = 50_000) -> dict[str, Any]:
|
| 38 |
+
return {
|
| 39 |
+
"method": "LMTruncation",
|
| 40 |
+
"mode": "word",
|
| 41 |
+
"number": number,
|
| 42 |
+
"path": "",
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _text_encoder_override() -> dict[str, Any]:
|
| 47 |
+
return {
|
| 48 |
+
"method": "STEncoder",
|
| 49 |
+
"path": "all-MiniLM-L6-v2",
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _openai_llm_override() -> dict[str, Any]:
|
| 54 |
+
return {
|
| 55 |
+
"method": "APILLM",
|
| 56 |
+
"name": os.getenv("OPENAI_MODEL") or "gpt-5.1",
|
| 57 |
+
"api_key": os.getenv("OPENAI_API_KEY") or "",
|
| 58 |
+
"base_url": os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1",
|
| 59 |
+
"temperature": float(os.getenv("OPENAI_TEMPERATURE", "0.0")),
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _default_memgallery_runtime_overrides(baseline_name: str) -> dict[str, Any]:
    """Return per-baseline MemEngine runtime config overrides.

    Unknown baseline names yield an empty dict. Every known baseline gets
    word-mode truncation on recall; encoder/LLM components are overridden
    only where the baseline actually uses them.
    """
    recall_only = {"recall": {"truncation": _word_mode_truncation()}}

    # Baselines that need nothing beyond recall truncation
    # (FU/ST text-only plus the plain multimodal baselines).
    if baseline_name in {"FUMemory", "STMemory", "MMMemory", "MMFUMemory", "NGMemory"}:
        return recall_only

    if baseline_name == "LTMemory":
        recall_only["recall"]["text_retrieval"] = {"encoder": _text_encoder_override()}
        return recall_only

    if baseline_name == "RFMemory":
        recall_only["optimize"] = {
            "reflector": {"LLM_config": _openai_llm_override()},
        }
        return recall_only

    if baseline_name == "GAMemory":
        return {
            "recall": {
                "truncation": _word_mode_truncation(),
                "text_retrieval": {"encoder": _text_encoder_override()},
                "importance_judge": {"LLM_config": _openai_llm_override()},
            },
            "reflect": {
                "reflector": {"LLM_config": _openai_llm_override()},
            },
        }

    if baseline_name == "MGMemory":
        return {
            "recall": {
                "truncation": _word_mode_truncation(),
                "recall_retrieval": {"encoder": _text_encoder_override()},
                "archival_retrieval": {"encoder": _text_encoder_override()},
                "trigger": {"LLM_config": _openai_llm_override()},
            },
            "store": {
                "flush_checker": _word_mode_truncation(),
                "summarizer": {"LLM_config": _openai_llm_override()},
            },
        }

    if baseline_name == "AUGUSTUSMemory":
        return {
            "recall": {
                "truncation": _word_mode_truncation(),
            },
            "concept_extractor": {
                "llm": _openai_llm_override(),
            },
        }

    if baseline_name == "UniversalRAGMemory":
        return {
            "recall": {
                "truncation": _word_mode_truncation(),
                "text_retrieval": {"encoder": _text_encoder_override()},
            },
            "routing": {
                "llm": _openai_llm_override(),
            },
        }

    return {}
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def _resolve_baselines_root() -> Path:
|
| 144 |
+
"""Return the ``baselines/`` directory (sibling of eval_framework/).
|
| 145 |
+
|
| 146 |
+
Layout::
|
| 147 |
+
|
| 148 |
+
nips26/
|
| 149 |
+
├── eval_framework/
|
| 150 |
+
└── baselines/
|
| 151 |
+
├── memengine/
|
| 152 |
+
└── default_config/
|
| 153 |
+
"""
|
| 154 |
+
# registry.py -> memory_adapters/ -> eval_framework/ -> nips26/
|
| 155 |
+
return Path(__file__).resolve().parents[2] / "baselines"
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _ensure_memgallery_benchmark_on_path() -> Path:
    """Add ``baselines/`` to sys.path so that ``memengine`` and
    ``default_config`` packages are importable; returns that directory.

    Raises:
        FileNotFoundError: when ``baselines/memengine`` does not exist.
    """
    root = _resolve_baselines_root()
    if not (root / "memengine").is_dir():
        raise FileNotFoundError(
            f"memengine/ not found under {root}. "
            f"Clone MemEngine into baselines/memengine."
        )
    entry = str(root)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    _bootstrap_memengine_namespace(root)
    return root
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _bootstrap_memengine_namespace(root: Path) -> None:
    """
    Pre-seed lightweight namespace packages for the co-located memengine package.

    memengine's package-level ``__init__.py`` eagerly imports all memories and function
    modules, which pulls in heavyweight optional dependencies like ``torch`` even for
    simple baselines such as ``FUMemory``. By registering package shells in ``sys.modules``
    first, we can import only the specific submodules we need.

    *root* is the directory that contains ``memengine/``.
    """
    package_paths = {
        "memengine": root / "memengine",
        "memengine.config": root / "memengine" / "config",
        "memengine.memory": root / "memengine" / "memory",
        "memengine.function": root / "memengine" / "function",
        "memengine.operation": root / "memengine" / "operation",
        "memengine.utils": root / "memengine" / "utils",
    }
    # Register a bare ModuleType shell per package — but never clobber a
    # module that was already imported for real.
    for package_name, package_path in package_paths.items():
        existing = sys.modules.get(package_name)
        if existing is not None:
            continue
        module = types.ModuleType(package_name)
        module.__path__ = [str(package_path)]  # type: ignore[attr-defined]
        module.__package__ = package_name
        sys.modules[package_name] = module

    # Link each child shell onto its parent so attribute access works
    # (e.g. ``memengine.function`` reachable as an attribute of ``memengine``).
    for package_name in package_paths:
        if "." not in package_name:
            continue
        parent_name, child_name = package_name.rsplit(".", 1)
        parent = sys.modules.get(parent_name)
        child = sys.modules.get(package_name)
        if parent is not None and child is not None and not hasattr(parent, child_name):
            setattr(parent, child_name, child)

    # Finish bootstrap: stub heavy optional deps, then eagerly populate the
    # safe subset of memengine.function exports.
    _bootstrap_optional_dependency_stubs()
    _populate_safe_memengine_function_exports()
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _bootstrap_optional_dependency_stubs() -> None:
    """Provide narrow stubs for optional imports needed only on unused code paths.

    The real ``torch``/``transformers`` packages are preferred when importable;
    the stubs exist so that importing memengine modules does not hard-fail in
    environments without them. Any stubbed call that is actually exercised
    raises a RuntimeError explaining the missing dependency.
    """
    if "torch" not in sys.modules:
        # Prefer the real package when available.
        try:
            sys.modules["torch"] = importlib.import_module("torch")
        except Exception:
            pass
    if "torch" not in sys.modules:
        torch_module = types.ModuleType("torch")

        def _torch_unavailable(*args: Any, **kwargs: Any) -> Any:
            del args, kwargs
            raise RuntimeError(
                "PyTorch is required for encoder-backed or tensor-based Mem-Gallery "
                "baselines, but `torch` is not installed in this environment."
            )

        # Benign attributes answer harmless probes; everything tensor-shaped
        # raises via _torch_unavailable.
        torch_module.cuda = types.SimpleNamespace(is_available=lambda: False)  # type: ignore[attr-defined]
        torch_module.device = lambda spec: spec  # type: ignore[attr-defined]
        torch_module.no_grad = lambda: nullcontext()  # type: ignore[attr-defined]
        torch_module.from_numpy = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.stack = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.sort = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.matmul = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.ones = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.nn = types.SimpleNamespace(  # type: ignore[attr-defined]
            functional=types.SimpleNamespace(normalize=_torch_unavailable)
        )
        sys.modules["torch"] = torch_module

    if "transformers" not in sys.modules:
        # Prefer the real package when available.
        try:
            sys.modules["transformers"] = importlib.import_module("transformers")
        except Exception:
            pass
    if "transformers" not in sys.modules:
        transformers_module = types.ModuleType("transformers")

        class _UnavailableAutoTokenizer:
            @classmethod
            def from_pretrained(cls, *args: Any, **kwargs: Any) -> Any:
                del args, kwargs
                raise RuntimeError(
                    "transformers.AutoTokenizer is required for token-mode truncation "
                    "or encoder-backed baselines, but `transformers` is not installed."
                )

        transformers_module.AutoTokenizer = _UnavailableAutoTokenizer  # type: ignore[attr-defined]
        sys.modules["transformers"] = transformers_module
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def _populate_safe_memengine_function_exports() -> None:
|
| 267 |
+
"""Expose all function symbols for complete baseline deployment without running package __init__."""
|
| 268 |
+
function_pkg = sys.modules.get("memengine.function")
|
| 269 |
+
if function_pkg is None:
|
| 270 |
+
return
|
| 271 |
+
|
| 272 |
+
# Complete list — covers every module referenced by any of the 11 baselines:
|
| 273 |
+
# FU/ST/LT/GA/MG/RF (text-only) + MM/MMFU/NG/AUGUSTUS/UniversalRAG (multimodal)
|
| 274 |
+
for module_name in (
|
| 275 |
+
# --- text-only baselines ---
|
| 276 |
+
"memengine.function.Encoder",
|
| 277 |
+
"memengine.function.Retrieval",
|
| 278 |
+
"memengine.function.LLM",
|
| 279 |
+
"memengine.function.Judge",
|
| 280 |
+
"memengine.function.Reflector",
|
| 281 |
+
"memengine.function.Summarizer",
|
| 282 |
+
"memengine.function.Truncation",
|
| 283 |
+
"memengine.function.Trigger",
|
| 284 |
+
"memengine.function.Utilization",
|
| 285 |
+
"memengine.function.Forget",
|
| 286 |
+
# --- multimodal / graph / concept baselines ---
|
| 287 |
+
"memengine.function.MultiModalEncoder",
|
| 288 |
+
"memengine.function.MultiModalRetrieval",
|
| 289 |
+
"memengine.function.ConceptExtractor",
|
| 290 |
+
"memengine.function.ConceptBasedRetrieval",
|
| 291 |
+
"memengine.function.EntityExtractor",
|
| 292 |
+
"memengine.function.FactExtractor",
|
| 293 |
+
"memengine.function.UniversalRAGRouting",
|
| 294 |
+
"memengine.function.UniversalRAGRetrieval",
|
| 295 |
+
):
|
| 296 |
+
try:
|
| 297 |
+
module = importlib.import_module(module_name)
|
| 298 |
+
except Exception:
|
| 299 |
+
# Some modules may depend on optional heavy deps (torch, transformers).
|
| 300 |
+
# Skip gracefully — they will fail loudly if the baseline actually needs them.
|
| 301 |
+
continue
|
| 302 |
+
for attr_name, value in vars(module).items():
|
| 303 |
+
if attr_name.startswith("_"):
|
| 304 |
+
continue
|
| 305 |
+
if not hasattr(function_pkg, attr_name):
|
| 306 |
+
setattr(function_pkg, attr_name, value)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def create_memgallery_adapter(
    baseline_name: str,
    *,
    config_overrides: dict[str, Any] | None = None,
) -> MemGalleryNativeAdapter:
    """
    Instantiate a native Mem-Gallery adapter for a known baseline name.

    Loads default_config + memengine from the Mem-Gallery benchmark tree.

    Raises:
        KeyError: if ``baseline_name`` is not a known native baseline.
    """
    if baseline_name not in MEMGALLERY_NATIVE_BASELINES:
        raise KeyError(f"unknown Mem-Gallery baseline: {baseline_name!r}")
    # Make sure the benchmark checkout is importable before the adapter loads it.
    _ensure_memgallery_benchmark_on_path()
    overrides = _default_memgallery_runtime_overrides(baseline_name)
    if config_overrides:
        # Caller-supplied overrides take precedence over per-baseline defaults.
        overrides = {**overrides, **config_overrides}
    # Pass None rather than an empty dict so "no overrides" stays distinguishable.
    return MemGalleryNativeAdapter.from_baseline(baseline_name, config=overrides or None)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# Factory table for native Mem-Gallery baselines: each entry is a zero-argument
# (plus optional overrides) callable bound to its baseline name via ``partial``.
MEMGALLERY_NATIVE_REGISTRY: dict[str, Callable[..., MemGalleryNativeAdapter]] = {
    name: partial(create_memgallery_adapter, name) for name in MEMGALLERY_NATIVE_BASELINES
}
|
| 336 |
+
|
| 337 |
+
# Canonical names of adapters implemented outside the Mem-Gallery native tree.
# Kept as a frozenset so membership checks in ``create_external_adapter`` are
# O(1) and the key set cannot be mutated at runtime. Must stay in sync with
# EXTERNAL_ADAPTER_REGISTRY below.
EXTERNAL_ADAPTER_KEYS: frozenset[str] = frozenset({
    "A-Mem", "MemoryOS", "Dummy",
    "Mem0", "Mem0-Graph",
    "SimpleMem", "Omni-SimpleMem",
    "MemVerse",
    "Zep",
})
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def create_amem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build the A-Mem baseline adapter (v2 implementation).

    The import is deferred so the dependency stays optional at import time.
    """
    from eval_framework.memory_adapters.amem_v2 import AMemV2Adapter
    return AMemV2Adapter(**kwargs)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def create_memoryos_adapter(**kwargs: Any) -> MemoryOSAdapter:
    """Build the MemoryOS baseline adapter, forwarding all config overrides."""
    return MemoryOSAdapter(**kwargs)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def create_dummy_adapter(**kwargs: Any) -> DummyAdapter:
    """Build the no-op Dummy adapter.

    ``kwargs`` is accepted for registry-signature uniformity but deliberately
    ignored: DummyAdapter takes no configuration.
    """
    return DummyAdapter()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def create_mem0_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build the Mem0 baseline (vector-store variant, ``use_graph=False``).

    The import is deferred so the dependency stays optional at import time.
    """
    from eval_framework.memory_adapters.mem0_adapter import Mem0Adapter
    return Mem0Adapter(use_graph=False, **kwargs)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def create_mem0_graph_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build the Mem0 baseline with its graph memory enabled (``use_graph=True``)."""
    from eval_framework.memory_adapters.mem0_adapter import Mem0Adapter
    return Mem0Adapter(use_graph=True, **kwargs)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def create_simplemem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build SimpleMem in text-only mode (deferred import keeps it optional)."""
    from eval_framework.memory_adapters.simplemem_adapter import SimpleMemAdapter
    return SimpleMemAdapter(mode="text", **kwargs)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def create_omni_simplemem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build SimpleMem in multimodal ("omni") mode (deferred import)."""
    from eval_framework.memory_adapters.simplemem_adapter import SimpleMemAdapter
    return SimpleMemAdapter(mode="omni", **kwargs)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def create_memverse_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build the MemVerse baseline adapter (deferred import)."""
    from eval_framework.memory_adapters.memverse_adapter import MemVerseAdapter
    return MemVerseAdapter(**kwargs)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def create_zep_adapter(**kwargs: Any) -> MemoryAdapter:
    """Build the Zep (self-hosted) baseline adapter (deferred import)."""
    from eval_framework.memory_adapters.zep_adapter import ZepAdapter
    return ZepAdapter(**kwargs)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# Factory table for external (non Mem-Gallery-native) adapters. Keys must stay
# in sync with EXTERNAL_ADAPTER_KEYS; ``create_external_adapter`` validates the
# name against that frozenset before dispatching here.
EXTERNAL_ADAPTER_REGISTRY: dict[str, Callable[..., MemoryAdapter]] = {
    "A-Mem": create_amem_adapter,
    "MemoryOS": create_memoryos_adapter,
    "Dummy": create_dummy_adapter,
    "Mem0": create_mem0_adapter,
    "Mem0-Graph": create_mem0_graph_adapter,
    "SimpleMem": create_simplemem_adapter,
    "Omni-SimpleMem": create_omni_simplemem_adapter,
    "MemVerse": create_memverse_adapter,
    "Zep": create_zep_adapter,
}
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def create_external_adapter(
    name: str,
    *,
    config_overrides: dict[str, Any] | None = None,
) -> MemoryAdapter:
    """Instantiate an external adapter for a known baseline name.

    Raises:
        KeyError: if ``name`` is not in ``EXTERNAL_ADAPTER_KEYS``.
    """
    if name not in EXTERNAL_ADAPTER_KEYS:
        raise KeyError(f"unknown external adapter: {name!r}")
    factory = EXTERNAL_ADAPTER_REGISTRY[name]
    kwargs = config_overrides or {}
    return factory(**kwargs)
|
memory_adapters/simplemem_adapter.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adapter for SimpleMem and Omni-SimpleMem baselines."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from eval_framework.datasets.schemas import (
|
| 11 |
+
MemoryDeltaRecord,
|
| 12 |
+
MemorySnapshotRecord,
|
| 13 |
+
NormalizedTurn,
|
| 14 |
+
RetrievalItem,
|
| 15 |
+
RetrievalRecord,
|
| 16 |
+
)
|
| 17 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 18 |
+
|
| 19 |
+
# Default checkout location of the SimpleMem baseline source tree; override it
# per-adapter via the ``source_root`` constructor argument.
_DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/SimpleMem")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SimpleMemAdapter(MemoryAdapter):
    """Adapter for SimpleMem (text mode) or Omni-SimpleMem (omni mode).

    The backend is imported from a source checkout (``source_root``) rather
    than an installed package, so the checkout path is pushed onto ``sys.path``
    at construction time. A local mirror of every ingested turn is kept in
    ``_stored_texts`` to back snapshots/deltas and a lexical retrieval
    fallback, since the backend does not expose its store directly.
    """

    def __init__(
        self,
        *,
        mode: str = "text",
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,  # accepted for registry-signature uniformity; unused here
    ) -> None:
        self._mode = mode  # "text" or "omni"
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        if str(root) not in sys.path:
            # Prepend so the checkout wins over any similarly named installed package.
            sys.path.insert(0, str(root))

        import simplemem_router as simplemem
        self._simplemem = simplemem
        self._mem: Any = None  # backend handle, created by _init_mem()
        self._session_id = ""  # most recently seen session id
        self._prev_snapshot_ids: set[str] = set()  # ids seen at last delta export
        self._stored_texts: list[dict[str, str]] = []  # local mirror of ingested turns
        self._init_mem()

    def _init_mem(self) -> None:
        # Fresh backend instance with a cleared DB; the mirror is reset alongside it.
        self._mem = self._simplemem.create(mode=self._mode, clear_db=True)
        self._stored_texts = []

    def reset(self) -> None:
        """Discard all memory state and rebuild a clean backend instance."""
        if self._mem is not None:
            try:
                self._mem.close()
            except Exception:
                # Best-effort close; a failed close must not block re-init.
                pass
        self._init_mem()
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Feed one dialogue turn (plus attachment captions) into the backend."""
        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        for att in turn.attachments:
            # Attachments are flattened into the text stream via their captions.
            text += f"\n[{att.type}] {att.caption}"

        # Memory ids are sequential indices into the local mirror.
        mid = str(len(self._stored_texts))
        if self._mode == "omni":
            self._mem.add_text(text, tags=[f"session:{turn.session_id}"])
        else:
            speaker = "User" if turn.role == "user" else "Assistant"
            ts = turn.timestamp or ""
            self._mem.add_dialogue(speaker, text, ts)
        self._stored_texts.append({"id": mid, "text": text, "session_id": turn.session_id})

    def end_session(self, session_id: str) -> None:
        """Mark the session boundary; text mode lets the backend consolidate."""
        self._session_id = session_id
        if self._mode == "text":
            try:
                self._mem.finalize()
            except Exception:
                # Consolidation is best-effort; the turns were already ingested.
                pass

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return the full memory state from the local mirror, one record per turn."""
        return [
            MemorySnapshotRecord(
                memory_id=t["id"], text=t["text"],
                session_id=t["session_id"], status="active",
                source=f"SimpleMem-{self._mode}",
                raw_backend_id=t["id"], raw_backend_type="simplemem",
                metadata={},
            )
            for t in self._stored_texts
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Return 'add' deltas for memories that appeared since the previous export."""
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": f"SimpleMem-{self._mode}"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Query the backend; fall back to lexical overlap over the mirror on failure."""
        items: list[RetrievalItem] = []
        try:
            if self._mode == "omni":
                result = self._mem.query(query, top_k=top_k)
                if isinstance(result, list):
                    for i, r in enumerate(result[:top_k]):
                        # Backend rows may be dicts or arbitrary objects; degrade to str().
                        text = r.get("text", str(r)) if isinstance(r, dict) else str(r)
                        items.append(RetrievalItem(
                            rank=i, memory_id=str(i), text=text,
                            # Rank-based pseudo-score: the backend exposes no scores here.
                            score=1.0 / (i + 1), raw_backend_id=None,
                        ))
            else:
                # Text mode exposes QA rather than retrieval: surface the answer as one item.
                answer = self._mem.ask(query)
                if answer:
                    items.append(RetrievalItem(
                        rank=0, memory_id="answer", text=str(answer),
                        score=1.0, raw_backend_id=None,
                    ))
        except Exception:
            # Backend errors fall through to the lexical fallback below.
            pass

        if not items:
            # Fallback: simple text search over stored memories
            query_lower = query.lower()
            scored = []
            for t in self._stored_texts:
                # Score by word-set overlap between the query and the stored text.
                overlap = len(set(query_lower.split()) & set(t["text"].lower().split()))
                scored.append((overlap, t))
            scored.sort(key=lambda x: x[0], reverse=True)
            for i, (sc, t) in enumerate(scored[:top_k]):
                items.append(RetrievalItem(
                    rank=i, memory_id=t["id"], text=t["text"],
                    # Overlap normalised by query length, guarded against empty queries.
                    score=float(sc) / max(len(query.split()), 1),
                    raw_backend_id=t["id"],
                ))

        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:top_k],
            raw_trace={"baseline": f"SimpleMem-{self._mode}"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter for the pipeline's availability/capability checks."""
        name = "Omni-SimpleMem" if self._mode == "omni" else "SimpleMem"
        return {
            "backend": name, "baseline": name,
            "available": self._mem is not None,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }
|
memory_adapters/zep_adapter.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Adapter for Zep memory system (community/self-hosted edition)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import uuid as _uuid
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from eval_framework.datasets.schemas import (
|
| 10 |
+
MemoryDeltaRecord,
|
| 11 |
+
MemorySnapshotRecord,
|
| 12 |
+
NormalizedTurn,
|
| 13 |
+
RetrievalItem,
|
| 14 |
+
RetrievalRecord,
|
| 15 |
+
)
|
| 16 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ZepAdapter(MemoryAdapter):
    """Adapter for Zep community edition (self-hosted).

    Talks to a Zep server over HTTP via ``zep_python``. One adapter lifetime
    maps to one Zep session (``_thread_id``); deltas are computed by diffing
    consecutive snapshots of that session's message list.
    """

    def __init__(self, *, base_url: str | None = None, **kwargs: Any) -> None:
        from zep_python import ZepClient

        # Default to a local server unless overridden via argument or env var.
        self._base_url = base_url or os.getenv("ZEP_BASE_URL", "http://localhost:8000")
        self._client = ZepClient(base_url=self._base_url)
        self._session_id = ""  # most recently seen eval session id
        # Random thread id isolates this run's memory on a shared server.
        self._thread_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._prev_snapshot_ids: set[str] = set()  # ids seen at last delta export

    def reset(self) -> None:
        """Delete server-side memory for the current thread and start a fresh one."""
        try:
            self._client.memory.delete_memory(self._thread_id)
        except Exception:
            # Best-effort: the new thread id below isolates us regardless.
            pass
        self._thread_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Append one dialogue turn (with attachment captions) to Zep memory."""
        from zep_python.memory import Memory
        from zep_python.message import Message

        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        for att in turn.attachments:
            # Attachments are flattened into the message text via their captions.
            text += f"\n[{att.type}] {att.caption}"

        # Zep distinguishes "user" from "ai"; map every non-user role to "ai".
        role_type = "user" if turn.role == "user" else "ai"
        msg = Message(role=turn.role, role_type=role_type, content=text)
        memory = Memory(messages=[msg])
        self._client.memory.add_memory(self._thread_id, memory)

    def end_session(self, session_id: str) -> None:
        """Record the session id; Zep needs no explicit consolidation step."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one record per message held for the thread (empty on server error)."""
        try:
            memory = self._client.memory.get_memory(self._thread_id)
        except Exception:
            return []

        rows: list[MemorySnapshotRecord] = []
        if memory and memory.messages:
            for i, msg in enumerate(memory.messages):
                # Prefer the server-side uuid; fall back to the positional index.
                mid = str(getattr(msg, "uuid", i))
                rows.append(MemorySnapshotRecord(
                    memory_id=mid,
                    text=msg.content or "",
                    session_id=self._session_id,
                    status="active",
                    source="Zep",
                    raw_backend_id=mid,
                    raw_backend_type="zep_message",
                    metadata={},
                ))
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Return 'add' deltas for messages that appeared since the previous export."""
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "Zep"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search the thread's memory; errors degrade to an empty result set."""
        try:
            results = self._client.memory.search_memory(
                self._thread_id, query, limit=top_k,
            )
        except Exception:
            results = []

        # NOTE(review): assumes each result exposes a ``message`` attribute (possibly
        # None) and an optional ``score`` — confirm against the zep_python version used.
        items = [
            RetrievalItem(
                rank=i,
                memory_id=str(getattr(r.message, "uuid", i)) if r.message else str(i),
                text=r.message.content if r.message else str(r),
                # Fall back to a rank-based pseudo-score when none is provided.
                score=float(getattr(r, "score", 1.0 / (i + 1))),
                raw_backend_id=str(getattr(r.message, "uuid", "")) if r.message else None,
            )
            for i, r in enumerate(results[:top_k])
        ]
        return RetrievalRecord(
            query=query, top_k=top_k, items=items,
            raw_trace={"baseline": "Zep"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter for the pipeline's availability/capability checks."""
        return {
            "backend": "Zep",
            "baseline": "Zep",
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
|
openai_compat.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compatibility helpers for OpenAI chat completions across model families."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def rewrite_chat_completion_kwargs(payload: dict[str, Any]) -> dict[str, Any]:
    """Translate deprecated chat completion parameters for reasoning models.

    GPT-5 family models reject ``max_tokens``; rename it to
    ``max_completion_tokens`` unless the caller already supplied the new key.
    The input dict is never mutated — a shallow copy is returned.
    """
    out = dict(payload)
    model_name = str(out.get("model") or "")
    is_reasoning_model = model_name.startswith("gpt-5")
    needs_rename = "max_tokens" in out and "max_completion_tokens" not in out
    if is_reasoning_model and needs_rename:
        out["max_completion_tokens"] = out.pop("max_tokens")
    return out
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def patch_openai_chat_completions() -> bool:
    """Monkeypatch the OpenAI SDK so GPT-5 chat calls accept legacy max_tokens.

    Returns True when the patch is (already) installed, False when the SDK
    cannot be imported. Idempotent: re-invoking leaves the patch in place.
    """
    try:
        from openai.resources.chat.completions.completions import Completions
    except Exception:
        # SDK absent or layout changed — nothing to patch.
        return False

    existing = Completions.create
    if getattr(existing, "_eval_framework_patched", False):
        return True  # already installed

    original_create = existing

    def _patched_create(self: Any, *args: Any, **kwargs: Any) -> Any:
        translated = rewrite_chat_completion_kwargs(kwargs)
        try:
            return original_create(self, *args, **translated)
        except Exception as exc:
            # Defensive retry: if the server still rejects max_tokens, rewrite
            # from the caller's original kwargs and try once more.
            legacy_rejected = (
                "Unsupported parameter: 'max_tokens'" in str(exc)
                and "max_tokens" in kwargs
            )
            if legacy_rejected:
                return original_create(self, *args, **rewrite_chat_completion_kwargs(kwargs))
            raise

    _patched_create._eval_framework_patched = True  # type: ignore[attr-defined]
    Completions.create = _patched_create  # type: ignore[assignment]
    return True
|
pipeline/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Session and checkpoint orchestration."""
|
pipeline/gold_state.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Cumulative gold memory state from staged memory-point annotations."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Any, Mapping, Sequence
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@dataclass(frozen=True)
class GoldMemoryPoint:
    """One annotated gold memory point (immutable)."""

    memory_id: str  # annotation id; "" when the raw row had none
    memory_content: str  # canonical memory text; "" when missing
    memory_type: str  # taxonomy label from the annotation
    memory_source: str  # provenance; "interference" points are never accumulated
    is_update: bool  # True when this point revises earlier memories
    original_memories: tuple[str, ...]  # ids of the memories this point updates
    importance: float  # annotator importance score; 0.0 when missing/unparsable
    timestamp: str | None = None  # optional timestamp string from the raw row
    update_type: str = ""  # kind of update; "" for new (non-update) memories
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass(frozen=True)
class SessionGoldState:
    """Gold memory state at the end of one session (immutable)."""

    session_id: str  # session this state describes
    cumulative_gold_memories: tuple[GoldMemoryPoint, ...]  # all non-interference points up to and including this session
    session_new_memories: tuple[GoldMemoryPoint, ...]  # points introduced this session (is_update=False)
    session_update_memories: tuple[GoldMemoryPoint, ...]  # points revising earlier memories (is_update=True)
    session_interference_memories: tuple[GoldMemoryPoint, ...]  # distractor points; excluded from the cumulative state
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _as_str_tuple(val: Any) -> tuple[str, ...]:
|
| 32 |
+
if val is None:
|
| 33 |
+
return ()
|
| 34 |
+
if isinstance(val, str):
|
| 35 |
+
return (val,)
|
| 36 |
+
if isinstance(val, Sequence) and not isinstance(val, (str, bytes)):
|
| 37 |
+
out: list[str] = []
|
| 38 |
+
for x in val:
|
| 39 |
+
out.append(x if isinstance(x, str) else str(x))
|
| 40 |
+
return tuple(out)
|
| 41 |
+
return (str(val),)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _parse_is_update(raw: Mapping[str, Any]) -> bool:
|
| 45 |
+
v = raw.get("is_update")
|
| 46 |
+
if v is True:
|
| 47 |
+
return True
|
| 48 |
+
if v is False or v is None:
|
| 49 |
+
return False
|
| 50 |
+
if isinstance(v, str):
|
| 51 |
+
return v.strip().lower() == "true"
|
| 52 |
+
return bool(v)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _parse_importance(raw: Mapping[str, Any]) -> float:
|
| 56 |
+
value = raw.get("importance", 0.0)
|
| 57 |
+
try:
|
| 58 |
+
return float(value)
|
| 59 |
+
except (TypeError, ValueError):
|
| 60 |
+
return 0.0
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def gold_point_from_raw(raw: Mapping[str, Any]) -> GoldMemoryPoint:
    """Convert one raw memory-point mapping into a ``GoldMemoryPoint``,
    coercing every field defensively (missing values become ""/()/0.0/None)."""
    raw_id = raw.get("memory_id")
    raw_content = raw.get("memory_content")
    raw_ts = raw.get("timestamp")
    return GoldMemoryPoint(
        memory_id="" if raw_id is None else str(raw_id),
        memory_content="" if raw_content is None else str(raw_content),
        memory_type=str(raw.get("memory_type", "")),
        memory_source=str(raw.get("memory_source", "")),
        is_update=_parse_is_update(raw),
        original_memories=_as_str_tuple(raw.get("original_memories")),
        importance=_parse_importance(raw),
        timestamp=None if raw_ts is None else str(raw_ts),
        update_type=str(raw.get("update_type", "") or ""),
    )
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def build_session_gold_states(
    ordered_session_ids: Sequence[str],
    *,
    s00_memory_points: Sequence[Mapping[str, Any]],
    stage4_by_session_id: Mapping[str, Sequence[Mapping[str, Any]]],
) -> tuple[SessionGoldState, ...]:
    """Accumulate non-interference gold memories through sessions in order.

    S00 is taken from the domain JSON session; later sessions prefer ``stage4``
    rows when present, since those drive staged evaluation labels.

    Args:
        ordered_session_ids: session ids in chronological order.
        s00_memory_points: raw memory-point rows for the seed session "S00".
        stage4_by_session_id: raw stage-4 rows keyed by session id; sessions
            missing from the mapping contribute no points.

    Returns:
        One ``SessionGoldState`` per session id, in the same order. Each
        state's ``cumulative_gold_memories`` covers every non-interference
        point from all sessions up to and including it.
    """
    cumulative: list[GoldMemoryPoint] = []  # grows across the whole loop
    states: list[SessionGoldState] = []

    for sid in ordered_session_ids:
        # Pick the raw annotation source for this session.
        if sid == "S00":
            raw_points: Sequence[Mapping[str, Any]] = s00_memory_points
        else:
            raw_points = stage4_by_session_id.get(sid)
            if raw_points is None:
                raw_points = ()

        news: list[GoldMemoryPoint] = []
        updates: list[GoldMemoryPoint] = []
        interference: list[GoldMemoryPoint] = []

        for raw in raw_points:
            gp = gold_point_from_raw(raw)
            if gp.memory_source == "interference":
                # Distractor memories are tracked per-session, never accumulated.
                interference.append(gp)
                continue
            cumulative.append(gp)
            if gp.is_update:
                updates.append(gp)
            else:
                news.append(gp)

        states.append(
            SessionGoldState(
                session_id=sid,
                # tuple() snapshots the accumulator so later sessions don't alias it.
                cumulative_gold_memories=tuple(cumulative),
                session_new_memories=tuple(news),
                session_update_memories=tuple(updates),
                session_interference_memories=tuple(interference),
            )
        )

    return tuple(states)
|
pipeline/qa_runner.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared checkpoint QA: retrieval via adapter + answer from an injected callable.
|
| 2 |
+
|
| 3 |
+
``AnswerFn`` may return either a plain ``str`` (legacy) or a
|
| 4 |
+
``(str, list[str])`` tuple of ``(answer, cited_memories)``.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from collections.abc import Callable
|
| 10 |
+
from typing import Union
|
| 11 |
+
|
| 12 |
+
from eval_framework.datasets.domain_a_v2 import NormalizedCheckpoint, NormalizedCheckpointQuestion
|
| 13 |
+
from eval_framework.datasets.schemas import RetrievalRecord
|
| 14 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 15 |
+
from eval_framework.pipeline.records import PipelineCheckpointQARecord
|
| 16 |
+
|
| 17 |
+
# answer_fn may return str (legacy) or (str, list[str])
|
| 18 |
+
AnswerResult = Union[str, tuple[str, list[str]]]
|
| 19 |
+
AnswerFn = Callable[[NormalizedCheckpointQuestion, RetrievalRecord], AnswerResult]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def run_checkpoint_qa_records(
    adapter: MemoryAdapter,
    *,
    sample_id: str,
    sample_uuid: str,
    checkpoint: NormalizedCheckpoint,
    top_k: int,
    answer_fn: AnswerFn,
) -> tuple[PipelineCheckpointQARecord, ...]:
    """For each question, call ``retrieve`` then ``answer_fn`` (not ``adapter.answer``).

    Args:
        adapter: memory backend; used here only for retrieval.
        sample_id: human-readable sample identifier, copied into every record.
        sample_uuid: stable sample uuid, copied into every record.
        checkpoint: checkpoint whose questions are answered in order.
        top_k: retrieval depth passed through to the adapter.
        answer_fn: produces the answer; may return ``str`` (legacy) or a
            ``(answer, cited_memories)`` tuple.

    Returns:
        One ``PipelineCheckpointQARecord`` per checkpoint question.
    """
    out: list[PipelineCheckpointQARecord] = []
    for q in checkpoint.questions:
        retrieval = adapter.retrieve(q.question, top_k)
        result = answer_fn(q, retrieval)

        # Normalize the two supported return shapes of answer_fn.
        if isinstance(result, tuple):
            generated, cited = result
        else:
            generated, cited = result, []

        out.append(
            PipelineCheckpointQARecord(
                sample_id=sample_id,
                sample_uuid=sample_uuid,
                checkpoint_id=checkpoint.checkpoint_id,
                question=q.question,
                gold_answer=q.gold_answer,
                gold_evidence_memory_ids=q.gold_evidence_memory_ids,
                gold_evidence_contents=q.gold_evidence_contents,
                question_type=q.question_type,
                question_type_abbrev=q.question_type_abbrev,
                difficulty=q.difficulty,
                retrieval=retrieval,
                generated_answer=generated,
                cited_memories=tuple(cited),
            )
        )
    return tuple(out)
|
pipeline/records.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pipeline-facing aliases and runtime record types emitted by the eval runner."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
from eval_framework.datasets.schemas import (
|
| 8 |
+
Attachment,
|
| 9 |
+
MemoryDeltaRecord,
|
| 10 |
+
MemorySnapshotRecord,
|
| 11 |
+
NormalizedTurn,
|
| 12 |
+
RetrievalItem,
|
| 13 |
+
RetrievalRecord,
|
| 14 |
+
normalize_turn,
|
| 15 |
+
)
|
| 16 |
+
from eval_framework.pipeline.gold_state import SessionGoldState
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"Attachment",
|
| 20 |
+
"MemoryDeltaRecord",
|
| 21 |
+
"MemorySnapshotRecord",
|
| 22 |
+
"NormalizedTurn",
|
| 23 |
+
"PipelineCheckpointQARecord",
|
| 24 |
+
"PipelineSessionRecord",
|
| 25 |
+
"RetrievalItem",
|
| 26 |
+
"RetrievalRecord",
|
| 27 |
+
"SessionGoldState",
|
| 28 |
+
"normalize_turn",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass(frozen=True)
class PipelineSessionRecord:
    """Normalized outputs after one dialogue session (one row per session)."""

    sample_id: str  # human-readable sample identifier
    sample_uuid: str  # stable unique id for the sample
    session_id: str  # session this row summarizes
    memory_snapshot: tuple[MemorySnapshotRecord, ...]  # full memory state after the session
    memory_delta: tuple[MemoryDeltaRecord, ...]  # changes recorded during the session
    gold_state: SessionGoldState  # cumulative gold annotations at this session
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass(frozen=True)
class PipelineCheckpointQARecord:
    """One checkpoint question: retrieval trace plus model-generated answer."""

    sample_id: str  # human-readable sample identifier
    sample_uuid: str  # stable unique id for the sample
    checkpoint_id: str  # checkpoint this question belongs to
    question: str  # question text posed to the system
    gold_answer: str  # reference answer from the annotations
    gold_evidence_memory_ids: tuple[str, ...]  # ids of memories supporting the gold answer
    gold_evidence_contents: tuple[str, ...]  # texts of the supporting memories
    question_type: str  # full question-type label
    question_type_abbrev: str  # short form of the question type
    difficulty: str  # annotated difficulty bucket
    retrieval: RetrievalRecord  # what the adapter retrieved for this question
    generated_answer: str  # answer produced by answer_fn
    cited_memories: tuple[str, ...] = ()  # memories cited by the answerer; empty for legacy str returns
|
pipeline/runner.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Session-by-session ingest, memory export, and checkpoint QA orchestration."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from collections.abc import Callable
|
| 6 |
+
|
| 7 |
+
from eval_framework.datasets.domain_a_v2 import (
|
| 8 |
+
DomainAV2AcademicSample,
|
| 9 |
+
NormalizedCheckpointQuestion,
|
| 10 |
+
)
|
| 11 |
+
from eval_framework.memory_adapters.base import MemoryAdapter
|
| 12 |
+
from eval_framework.pipeline.qa_runner import run_checkpoint_qa_records
|
| 13 |
+
from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord
|
| 14 |
+
from eval_framework.datasets.schemas import RetrievalRecord
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def ensure_adapter_available(adapter: MemoryAdapter) -> None:
    """Raise ``RuntimeError`` if *adapter* reports itself unavailable.

    An adapter is treated as unavailable only when its capability dict
    carries an explicit ``available: False``; a missing key counts as
    available. The error message includes the backend name (falling back
    to the adapter class name) and the most specific failure detail the
    adapter exposed.
    """
    capabilities = adapter.get_capabilities()
    if capabilities.get("available") is not False:
        return

    backend_name = capabilities.get("backend", type(adapter).__name__)
    # Prefer a concrete integration error; fall back to the status string,
    # and finally to a generic marker when neither is present.
    reason = capabilities.get("integration_error") or capabilities.get(
        "integration_status", "available=False"
    )
    raise RuntimeError(
        f"Memory adapter {backend_name!r} is not available for pipeline runs: {reason}"
    )
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def run_domain_a_v2_sample(
    adapter: MemoryAdapter,
    sample: DomainAV2AcademicSample,
    *,
    top_k: int = 5,
    answer_fn: Callable | None = None,
) -> tuple[tuple[PipelineSessionRecord, ...], tuple[PipelineCheckpointQARecord, ...]]:
    """Run all sessions in order, emit one session record per session, then checkpoint QA when due.

    Ingests every session of *sample* into *adapter* in dataset order,
    exporting a memory snapshot and per-session delta after each one, and
    fires each checkpoint's QA exactly once — right after the last (by
    session order) of its covered sessions completes.

    Raises ``RuntimeError`` when the adapter is unavailable, and
    ``ValueError`` on inconsistent sample data (missing answer_fn, gold-state
    mismatches, or checkpoints referencing unknown sessions).
    """
    ensure_adapter_available(adapter)
    # Fail fast: if any checkpoint QA will be needed, an answer generator is
    # mandatory before we spend time ingesting sessions.
    if sample.normalized_checkpoints and answer_fn is None:
        raise ValueError(
            "answer_fn is required when the sample defines normalized checkpoints"
        )

    adapter.reset()  # each sample starts from an empty memory store
    session_out: list[PipelineSessionRecord] = []
    qa_out: list[PipelineCheckpointQARecord] = []
    completed_sessions: set[str] = set()
    # session_id -> position in sample.sessions; used both to validate
    # checkpoint references and to find the latest covered session.
    session_order = {
        session.session_id: index for index, session in enumerate(sample.sessions)
    }

    if len(sample.sessions) != len(sample.session_gold_states):
        raise ValueError(
            "sample.sessions and sample.session_gold_states length mismatch"
        )

    for sess, gold in zip(sample.sessions, sample.session_gold_states):
        # Sessions and gold states are expected to be parallel lists keyed by
        # session_id; a mismatch means the sample itself is malformed.
        if sess.session_id != gold.session_id:
            raise ValueError(
                f"session / gold_state id mismatch: {sess.session_id!r} vs {gold.session_id!r}"
            )
        for turn in sess.turns:
            adapter.ingest_turn(turn)
        # Give the adapter a chance to consolidate before we export state.
        adapter.end_session(sess.session_id)

        snapshot = tuple(adapter.snapshot_memories())
        delta = tuple(adapter.export_memory_delta(sess.session_id))
        session_out.append(
            PipelineSessionRecord(
                sample_id=sample.sample_id,
                sample_uuid=sample.uuid,
                session_id=sess.session_id,
                memory_snapshot=snapshot,
                memory_delta=delta,
                gold_state=gold,
            )
        )
        completed_sessions.add(sess.session_id)

        # Checkpoint scan runs after every session: a checkpoint fires on the
        # session that completes its coverage (the covered session with the
        # highest dataset order), so each checkpoint triggers at most once.
        for cp in sample.normalized_checkpoints:
            covered = cp.covered_sessions
            if not covered:
                # A checkpoint with no coverage can never become "due".
                continue
            missing = [sid for sid in covered if sid not in session_order]
            if missing:
                raise ValueError(
                    f"checkpoint {cp.checkpoint_id!r} references unknown sessions: {missing}"
                )
            if not set(covered).issubset(completed_sessions):
                # Not all prerequisite sessions have been ingested yet.
                continue
            # The trigger is the covered session that appears last in
            # sample.sessions order, not e.g. lexicographically last.
            trigger_session_id = max(covered, key=session_order.__getitem__)
            if sess.session_id != trigger_session_id:
                continue
            qa_out.extend(
                run_checkpoint_qa_records(
                    adapter,
                    sample_id=sample.sample_id,
                    sample_uuid=sample.uuid,
                    checkpoint=cp,
                    top_k=top_k,
                    answer_fn=answer_fn,
                )
            )

    return tuple(session_out), tuple(qa_out)
|