LCZZZZ committed on
Commit
85b19cf
·
verified ·
1 Parent(s): 4aa5d62

Upload eval_framework source code

Browse files
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ results/
2
+ converted/
3
+ __pycache__/
4
+ *.pyc
5
+ *.jsonl
__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Unified memory evaluation framework (package scaffold)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ __all__: list[str] = []
cli.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CLI: dataset path, baseline, output dir, dry-run, smoke eval.
2
+
3
+ Evaluation uses batch LLM judge: 2 calls/session + 2 calls/QA.
4
+ Session and QA evaluations run in parallel via ThreadPoolExecutor.
5
+ Pipeline results are checkpointed before eval so --eval-only can resume.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import json
12
+ import os
13
+ from collections.abc import Callable
14
+ from concurrent.futures import ThreadPoolExecutor, as_completed
15
+ from dataclasses import asdict
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ try:
20
+ from openai import OpenAI
21
+ except ImportError:
22
+ OpenAI = None # type: ignore[assignment]
23
+
24
+ from eval_framework.config import EvalConfig
25
+ from eval_framework.datasets.domain_a_v2 import (
26
+ DomainAV2AcademicBundle,
27
+ NormalizedCheckpointQuestion,
28
+ load_domain_a_v2_academic,
29
+ )
30
+ from eval_framework.datasets.schemas import (
31
+ MemoryDeltaRecord,
32
+ MemorySnapshotRecord,
33
+ RetrievalItem,
34
+ RetrievalRecord,
35
+ )
36
+ from eval_framework.evaluators.aggregate import aggregate_metrics
37
+ from eval_framework.evaluators.extraction import evaluate_extraction
38
+ from eval_framework.evaluators.qa import evaluate_checkpoint_qa
39
+ from eval_framework.memory_adapters.base import MemoryAdapter
40
+ from eval_framework.openai_compat import (
41
+ patch_openai_chat_completions,
42
+ rewrite_chat_completion_kwargs,
43
+ )
44
+ from eval_framework.pipeline.gold_state import GoldMemoryPoint, SessionGoldState
45
+ from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord
46
+ from eval_framework.pipeline.runner import run_domain_a_v2_sample
47
+
48
+
49
+ _CHECKPOINT_SESSIONS = "pipeline_sessions.jsonl"
50
+ _CHECKPOINT_QA = "pipeline_qa.jsonl"
51
+
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Checkpoint deserialization: dict -> frozen dataclass
55
+ # ---------------------------------------------------------------------------
56
+
57
def _gold_point_from_dict(d: dict[str, Any]) -> GoldMemoryPoint:
    """Rehydrate a GoldMemoryPoint from its checkpoint JSON dict."""
    fields: dict[str, Any] = {
        "memory_id": d["memory_id"],
        "memory_content": d["memory_content"],
        "memory_type": d["memory_type"],
        "memory_source": d["memory_source"],
        "is_update": bool(d["is_update"]),
        # JSON stores lists / null; the frozen dataclass wants tuples.
        "original_memories": tuple(d.get("original_memories") or ()),
        "importance": float(d.get("importance", 0.0)),
        "timestamp": d.get("timestamp"),
        "update_type": d.get("update_type", ""),
    }
    return GoldMemoryPoint(**fields)
69
+
70
+
71
def _gold_state_from_dict(d: dict[str, Any]) -> SessionGoldState:
    """Rehydrate a SessionGoldState, rebuilding each nested gold point."""

    def points(key: str) -> tuple[GoldMemoryPoint, ...]:
        # Every gold-memory list in the checkpoint dict shares the same shape.
        return tuple(_gold_point_from_dict(g) for g in d[key])

    return SessionGoldState(
        session_id=d["session_id"],
        cumulative_gold_memories=points("cumulative_gold_memories"),
        session_new_memories=points("session_new_memories"),
        session_update_memories=points("session_update_memories"),
        session_interference_memories=points("session_interference_memories"),
    )
79
+
80
+
81
def _snapshot_record_from_dict(d: dict[str, Any]) -> MemorySnapshotRecord:
    """Rehydrate a MemorySnapshotRecord from its checkpoint JSON dict."""
    fields: dict[str, Any] = {
        "memory_id": d["memory_id"],
        "text": d["text"],
        "session_id": d["session_id"],
        "status": d["status"],
        # Optional backend bookkeeping fields default to None / {}.
        "source": d.get("source"),
        "raw_backend_id": d.get("raw_backend_id"),
        "raw_backend_type": d.get("raw_backend_type"),
        "metadata": d.get("metadata") or {},
    }
    return MemorySnapshotRecord(**fields)
92
+
93
+
94
def _delta_record_from_dict(d: dict[str, Any]) -> MemoryDeltaRecord:
    """Rehydrate a MemoryDeltaRecord from its checkpoint JSON dict."""
    # JSON null / missing list collapses to an empty tuple.
    linked = tuple(d.get("linked_previous") or ())
    return MemoryDeltaRecord(
        session_id=d["session_id"],
        op=d["op"],
        text=d["text"],
        linked_previous=linked,
        raw_backend_id=d.get("raw_backend_id"),
        metadata=d.get("metadata") or {},
    )
103
+
104
+
105
def _retrieval_item_from_dict(d: dict[str, Any]) -> RetrievalItem:
    """Rehydrate one ranked retrieval hit from its checkpoint JSON dict."""
    # Coerce numerics explicitly — JSON round-trips may widen/narrow types.
    rank = int(d["rank"])
    score = float(d["score"])
    return RetrievalItem(
        rank=rank,
        memory_id=d["memory_id"],
        text=d["text"],
        score=score,
        raw_backend_id=d.get("raw_backend_id"),
    )
113
+
114
+
115
def _retrieval_record_from_dict(d: dict[str, Any]) -> RetrievalRecord:
    """Rehydrate a RetrievalRecord, including its ranked item list."""
    items = [_retrieval_item_from_dict(item) for item in d["items"]]
    return RetrievalRecord(
        query=d["query"],
        top_k=int(d["top_k"]),
        items=items,
        raw_trace=d.get("raw_trace") or {},
    )
122
+
123
+
124
def _session_record_from_dict(d: dict[str, Any]) -> PipelineSessionRecord:
    """Rehydrate a full per-session pipeline record from checkpoint JSON."""
    snapshot = tuple(_snapshot_record_from_dict(s) for s in d["memory_snapshot"])
    delta = tuple(_delta_record_from_dict(item) for item in d["memory_delta"])
    return PipelineSessionRecord(
        sample_id=d["sample_id"],
        sample_uuid=d["sample_uuid"],
        session_id=d["session_id"],
        memory_snapshot=snapshot,
        memory_delta=delta,
        gold_state=_gold_state_from_dict(d["gold_state"]),
    )
133
+
134
+
135
def _qa_record_from_dict(d: dict[str, Any]) -> PipelineCheckpointQARecord:
    """Rehydrate a checkpoint-QA record (question, retrieval, answer)."""

    def tup(key: str) -> tuple:
        # JSON may store null where the dataclass expects an empty tuple.
        return tuple(d.get(key) or ())

    return PipelineCheckpointQARecord(
        sample_id=d["sample_id"],
        sample_uuid=d["sample_uuid"],
        checkpoint_id=d["checkpoint_id"],
        question=d["question"],
        gold_answer=d["gold_answer"],
        gold_evidence_memory_ids=tup("gold_evidence_memory_ids"),
        gold_evidence_contents=tup("gold_evidence_contents"),
        question_type=d["question_type"],
        question_type_abbrev=d["question_type_abbrev"],
        difficulty=d["difficulty"],
        retrieval=_retrieval_record_from_dict(d["retrieval"]),
        generated_answer=d["generated_answer"],
        cited_memories=tup("cited_memories"),
    )
151
+
152
+
153
+ def _read_jsonl(path: Path) -> list[dict[str, Any]]:
154
+ rows: list[dict[str, Any]] = []
155
+ with path.open("r", encoding="utf-8") as fh:
156
+ for line in fh:
157
+ line = line.strip()
158
+ if line:
159
+ rows.append(json.loads(line))
160
+ return rows
161
+
162
+
163
def _load_pipeline_checkpoint(
    output_dir: Path,
) -> tuple[list[PipelineSessionRecord], list[PipelineCheckpointQARecord]]:
    """Restore pipeline records from checkpoint JSONL files."""
    sess_path = output_dir / _CHECKPOINT_SESSIONS
    qa_path = output_dir / _CHECKPOINT_QA
    # Both files are written together by run_eval; either missing means
    # the pipeline stage has not been run for this output dir.
    if not (sess_path.exists() and qa_path.exists()):
        raise SystemExit(
            f"Checkpoint files not found in {output_dir}. "
            f"Run without --eval-only first to generate them."
        )
    sessions = [_session_record_from_dict(row) for row in _read_jsonl(sess_path)]
    qa = [_qa_record_from_dict(row) for row in _read_jsonl(qa_path)]
    return sessions, qa
177
+
178
+
179
def _default_create_adapter(baseline_name: str) -> MemoryAdapter:
    """Instantiate the adapter registered under *baseline_name*.

    Checks the native registry first, then the external-adapter registry;
    exits with a message listing all known baselines if neither matches.
    """
    # Imported lazily so the CLI can show help without adapter deps installed.
    from eval_framework.memory_adapters import registry as reg

    for registry in (reg.MEMGALLERY_NATIVE_REGISTRY, reg.EXTERNAL_ADAPTER_REGISTRY):
        factory = registry.get(baseline_name)
        if factory is not None:
            return factory()

    known = sorted(reg.MEMGALLERY_NATIVE_BASELINES | reg.EXTERNAL_ADAPTER_KEYS)
    raise SystemExit(
        f"Unknown baseline {baseline_name!r}. "
        f"Expected one of: {', '.join(known)}"
    )
193
+
194
+
195
+ def _gold_echo_answer(
196
+ q: NormalizedCheckpointQuestion, _retrieval: RetrievalRecord
197
+ ) -> tuple[str, list[str]]:
198
+ return q.gold_answer, []
199
+
200
+
201
+ def _parse_answer_json(raw: str) -> tuple[str, list[str]]:
202
+ """Extract answer and cited_memories from the model's JSON response."""
203
+ # Try to parse as JSON first
204
+ try:
205
+ data = json.loads(raw)
206
+ answer = str(data.get("answer", ""))
207
+ cited = data.get("cited_memories", [])
208
+ if isinstance(cited, list):
209
+ return answer, [str(c) for c in cited]
210
+ return answer, []
211
+ except (json.JSONDecodeError, TypeError):
212
+ pass
213
+ # Fallback: try to find JSON block in the response
214
+ import re
215
+ m = re.search(r"\{[\s\S]*\}", raw)
216
+ if m:
217
+ try:
218
+ data = json.loads(m.group())
219
+ answer = str(data.get("answer", ""))
220
+ cited = data.get("cited_memories", [])
221
+ if isinstance(cited, list):
222
+ return answer, [str(c) for c in cited]
223
+ except (json.JSONDecodeError, TypeError):
224
+ pass
225
+ # Final fallback: treat entire response as the answer, no citations
226
+ return raw.strip(), []
227
+
228
+
229
def build_default_answer_fn() -> Callable[
    [NormalizedCheckpointQuestion, RetrievalRecord], tuple[str, list[str]]
]:
    """Build the answer function used by the pipeline for checkpoint QA.

    Returns a closure that asks an OpenAI-compatible chat model to answer
    from retrieved memories. When no ``OPENAI_API_KEY`` is set or the
    ``openai`` SDK is not importable, falls back to ``_gold_echo_answer``
    (which echoes the gold answer — useful for offline smoke runs).
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key or OpenAI is None:
        # No credentials / SDK available: degrade to the gold-echo stub.
        return _gold_echo_answer

    client = OpenAI(
        api_key=api_key,
        base_url=os.getenv("OPENAI_BASE_URL"),  # None -> SDK default endpoint
    )
    # Model and sampling knobs are environment-driven with safe defaults.
    model = os.getenv("OPENAI_MODEL") or "gpt-4o"
    temperature = float(os.getenv("OPENAI_TEMPERATURE", "0.0"))
    max_tokens = int(os.getenv("OPENAI_MAX_TOKENS", "1024"))

    def _answer(
        q: NormalizedCheckpointQuestion, retrieval: RetrievalRecord
    ) -> tuple[str, list[str]]:
        # Number the retrieved memories so the model can cite by passage.
        # Only the first top_k items are shown even if more were recorded.
        context_lines = [
            f"[{item.rank}] {item.text}" for item in retrieval.items[: retrieval.top_k]
        ]
        context = "\n".join(context_lines) if context_lines else "No retrieved memories."
        prompt = (
            "Answer the user's question using only the retrieved memories below. "
            "If the memories are insufficient, answer exactly: Not mentioned in memory.\n\n"
            "You MUST also list the specific memory passages you relied on to produce "
            "the answer. Copy the relevant text verbatim from the retrieved memories.\n\n"
            f"Question: {q.question}\n\n"
            f"Retrieved memories:\n{context}\n\n"
            'Respond in JSON:\n'
            '{\n'
            '  "answer": "your concise answer",\n'
            '  "cited_memories": ["verbatim passage 1", "verbatim passage 2"]\n'
            '}\n'
        )
        # rewrite_chat_completion_kwargs adapts the request for the target
        # backend (see eval_framework.openai_compat).
        request_kwargs = rewrite_chat_completion_kwargs(
            {
                "model": model,
                "messages": [
                    {
                        "role": "system",
                        "content": (
                            "You answer benchmark questions using only supplied memory context. "
                            "Be concise and do not invent missing facts. "
                            "Always respond in the requested JSON format."
                        ),
                    },
                    {"role": "user", "content": prompt},
                ],
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
        )
        response = client.chat.completions.create(**request_kwargs)
        # content can be None on some responses; normalize to "".
        raw = response.choices[0].message.content or ""
        return _parse_answer_json(raw)

    return _answer
287
+
288
+
289
def config_from_namespace(ns: argparse.Namespace) -> EvalConfig:
    """Translate parsed CLI args into a resolved EvalConfig."""

    def _resolve(raw: str) -> Path:
        # Normalize ~ and relative segments so downstream code sees
        # absolute paths regardless of the caller's CWD.
        return Path(raw).expanduser().resolve()

    return EvalConfig(
        dataset_path=_resolve(ns.dataset),
        output_dir=_resolve(ns.output_dir),
        baseline=str(ns.baseline),
        smoke=bool(ns.smoke),
        dry_run=bool(ns.dry_run),
    )
297
+
298
+
299
def _record_to_json_obj(obj: Any) -> dict[str, Any]:
    """Recursively convert a dataclass record to a JSON-serializable dict."""
    return asdict(obj)
301
+
302
+
303
+ def _write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
304
+ path.parent.mkdir(parents=True, exist_ok=True)
305
+ with path.open("w", encoding="utf-8") as fh:
306
+ for row in rows:
307
+ fh.write(json.dumps(row, ensure_ascii=False) + "\n")
308
+
309
+
310
+ def _write_json(path: Path, payload: dict[str, Any]) -> None:
311
+ path.parent.mkdir(parents=True, exist_ok=True)
312
+ path.write_text(
313
+ json.dumps(payload, ensure_ascii=False, indent=2) + "\n",
314
+ encoding="utf-8",
315
+ )
316
+
317
+
318
def run_eval(
    config: EvalConfig,
    *,
    load_domain_bundle: Callable[[Path], DomainAV2AcademicBundle] = load_domain_a_v2_academic,
    create_adapter: Callable[[str], MemoryAdapter] | None = None,
    answer_fn: Callable | None = None,
    max_eval_workers: int = 5,
    eval_only: bool = False,
) -> None:
    """Load data, run pipeline (serial) + LLM eval (parallel).

    Stages:
      1. Pipeline — runs each sample through its adapter serially (the
         adapter is stateful), then checkpoints all records to JSONL.
         Skipped entirely when ``eval_only`` is True, in which case records
         are restored from the checkpoint files in ``config.output_dir``.
      2. Eval — evaluates every session and QA record concurrently via a
         ThreadPoolExecutor; per-record failures become ``{"error": ...}``.
      3. Aggregate — combines per-record evals and writes all outputs.

    The injectable ``load_domain_bundle`` / ``create_adapter`` / ``answer_fn``
    parameters exist so tests can substitute fakes.
    """
    patch_openai_chat_completions()
    # Dry-run is fully handled by main() (prints config); bail before any I/O.
    if config.dry_run:
        return

    out = config.output_dir
    out.mkdir(parents=True, exist_ok=True)

    if eval_only:
        # --- Resume from checkpoint ---
        print(f"[Eval-only] Loading pipeline checkpoint from {out}")
        session_records, qa_records = _load_pipeline_checkpoint(out)
        print(f"[Eval-only] Loaded {len(session_records)} sessions + {len(qa_records)} QA records")
    else:
        # --- Stage 1: Pipeline (serial — adapter is stateful) ---
        adapter_factory = create_adapter or _default_create_adapter
        bundle = load_domain_bundle(config.dataset_path)
        # Smoke mode: only the first sample.
        samples = bundle.samples[:1] if config.smoke else bundle.samples
        _answer = answer_fn if answer_fn is not None else build_default_answer_fn()

        session_records: list[PipelineSessionRecord] = []
        qa_records: list[PipelineCheckpointQARecord] = []

        print(f"[Pipeline] Running {len(samples)} sample(s) with baseline={config.baseline}")
        for i, sample in enumerate(samples):
            print(f"  Sample {i + 1}/{len(samples)}: {sample.sample_id}")
            # Fresh adapter per sample so memory state never leaks across samples.
            adapter = adapter_factory(config.baseline)
            sess, qa = run_domain_a_v2_sample(
                adapter,
                sample,
                answer_fn=_answer,
            )
            session_records.extend(sess)
            qa_records.extend(qa)

        # --- Save checkpoint ---
        # Checkpoint *before* eval so a failed/interrupted eval can be
        # retried with --eval-only without re-running the pipeline.
        _write_jsonl(out / _CHECKPOINT_SESSIONS,
                     [_record_to_json_obj(r) for r in session_records])
        _write_jsonl(out / _CHECKPOINT_QA,
                     [_record_to_json_obj(r) for r in qa_records])
        print(f"[Checkpoint] Saved {len(session_records)} sessions + {len(qa_records)} QA to {out}")

    # --- Stage 2: Eval (parallel — each record is self-contained) ---
    print(f"[Eval] Evaluating {len(session_records)} sessions + {len(qa_records)} QA with LLM judge (workers={max_eval_workers})...")

    # Pre-sized result slots keep eval results aligned with their records
    # even though futures complete out of order.
    session_evals: list[dict[str, object] | None] = [None] * len(session_records)
    qa_evals: list[dict[str, object] | None] = [None] * len(qa_records)

    with ThreadPoolExecutor(max_workers=max_eval_workers) as pool:
        # Submit session evals
        session_futures = {}
        for idx, srec in enumerate(session_records):
            fut = pool.submit(evaluate_extraction, srec)
            session_futures[fut] = idx

        # Submit QA evals
        qa_futures = {}
        for idx, qrec in enumerate(qa_records):
            fut = pool.submit(evaluate_checkpoint_qa, qrec)
            qa_futures[fut] = idx

        # Collect session results
        done_sessions = 0
        for fut in as_completed(session_futures):
            idx = session_futures[fut]
            try:
                session_evals[idx] = fut.result()
            except Exception as e:
                # One failed judge call must not abort the whole run.
                session_evals[idx] = {"error": str(e)}
            done_sessions += 1
            if done_sessions % 10 == 0 or done_sessions == len(session_records):
                print(f"  Sessions: {done_sessions}/{len(session_records)} done")

        # Collect QA results
        done_qa = 0
        for fut in as_completed(qa_futures):
            idx = qa_futures[fut]
            try:
                qa_evals[idx] = fut.result()
            except Exception as e:
                qa_evals[idx] = {"error": str(e)}
            done_qa += 1
            if done_qa % 20 == 0 or done_qa == len(qa_records):
                print(f"  QA: {done_qa}/{len(qa_records)} done")

    # --- Stage 3: Aggregate + write ---
    agg = aggregate_metrics(
        config.baseline,
        session_evaluations=[e for e in session_evals if e is not None],
        qa_evaluations=[e for e in qa_evals if e is not None],
    )

    # Attach each eval result to its source record for the output files.
    session_rows = []
    for srec, s_eval in zip(session_records, session_evals):
        row = _record_to_json_obj(srec)
        row["eval"] = s_eval
        session_rows.append(row)

    qa_rows = []
    for qrec, q_eval in zip(qa_records, qa_evals):
        row = _record_to_json_obj(qrec)
        row["eval"] = q_eval
        qa_rows.append(row)

    _write_jsonl(out / "session_records.jsonl", session_rows)
    _write_jsonl(out / "qa_records.jsonl", qa_rows)
    _write_json(out / "aggregate_metrics.json", agg)

    print(f"\n[Done] Results written to {out}")
    print(f"  Aggregate: {json.dumps(agg, indent=2)}")
437
+
438
+
439
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the eval_framework entry point."""
    parser = argparse.ArgumentParser(prog="eval_framework")
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--baseline", required=True)
    parser.add_argument("--output-dir", default="eval_framework/results")
    # Simple boolean toggles share identical wiring.
    for flag in ("--smoke", "--dry-run"):
        parser.add_argument(flag, action="store_true")
    parser.add_argument(
        "--eval-only",
        action="store_true",
        help="Skip pipeline, load from checkpoint in output-dir.",
    )
    parser.add_argument(
        "--max-eval-workers",
        type=int,
        default=5,
        help="Parallel threads for eval stage (default 5).",
    )
    return parser
451
+
452
+
453
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse args, honor --dry-run, then run the eval."""
    args = build_parser().parse_args(argv)
    cfg = config_from_namespace(args)

    # Dry-run prints the resolved config and exits without touching disk.
    if cfg.dry_run:
        print(json.dumps(cfg.to_display_dict(), indent=2))
        return

    eval_only = bool(args.eval_only)

    # The dataset directory is only required when the pipeline stage runs;
    # --eval-only works purely from the checkpoint in output-dir.
    if not eval_only and not cfg.dataset_path.is_dir():
        raise SystemExit(f"Dataset path is not a directory: {cfg.dataset_path}")

    run_eval(cfg, max_eval_workers=args.max_eval_workers, eval_only=eval_only)


if __name__ == "__main__":
    main()
config.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Configuration types for eval runs (CLI, dry-run, and smoke execution)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+
10
@dataclass
class EvalConfig:
    """Resolved paths and flags for one eval invocation."""

    dataset_path: Path  # absolute path to the dataset directory
    output_dir: Path  # where results and checkpoints are written
    baseline: str  # name of the memory baseline/adapter to evaluate
    smoke: bool = False  # restrict the run to a minimal sample
    dry_run: bool = False  # print the resolved config and exit

    def to_display_dict(self) -> dict[str, Any]:
        """JSON-friendly snapshot for dry-run and logging."""
        display: dict[str, Any] = {
            "dataset_path": str(self.dataset_path),
            "output_dir": str(self.output_dir),
            "baseline": self.baseline,
            "smoke": self.smoke,
            "dry_run": self.dry_run,
        }
        # Static run context so logs record which profile/judge produced it.
        display["dataset_profile"] = "domain_a_v2_academic"
        display["judge"] = "llm (OpenAI API)"
        return display
datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Dataset loaders and schema definitions."""
datasets/convert_vistrajqa.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Convert VisTrajQA sessions-*.jsonl → eval_framework domain_a_v2 format.
2
+
3
+ Reads one or more sessions-*.jsonl files and produces the three files
4
+ expected by ``load_domain_a_v2_academic``:
5
+
6
+ <output_dir>/
7
+ ├── domain_a_v2.json
8
+ ├── stage4_memory_points.jsonl
9
+ └── stage4b_qa_checkpoints.jsonl
10
+
11
+ Usage:
12
+ python -m eval_framework.datasets.convert_vistrajqa \
13
+ --input data/generated/sessions-vab.jsonl \
14
+ data/generated/sessions-eb-nav.jsonl \
15
+ data/generated/sessions-arena.jsonl \
16
+ --output eval_framework/converted/all \
17
+ --text-only
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import json
24
+ import uuid as _uuid
25
+ from pathlib import Path
26
+ from typing import Any
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # QA type abbreviation → full name
30
+ # ---------------------------------------------------------------------------
31
+
32
# Maps VisTrajQA qa_type abbreviations to the full question_type names
# emitted into the converted checkpoint files.
QA_TYPE_FULL = {
    # Text-based QA types
    "FR": "factual_recall",
    "DU": "dynamic_update",
    "MB": "memory_boundary",
    "TR": "temporal_reasoning",
    "KR": "knowledge_reasoning",
    # Visual / cross-modal QA types
    "VFR": "visual_factual_recall",
    "VS": "visual_search",
    "VU": "visual_update",
    "CMR": "cross_modal_reasoning",
}
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Turn construction
47
+ # ---------------------------------------------------------------------------
48
+
49
+ def _build_user_text(step: dict[str, Any], *, text_only: bool) -> str:
50
+ """Build user turn text from a CanonicalStep.
51
+
52
+ User turn = what the agent perceives: observation + feedback + (caption if text_only).
53
+ """
54
+ parts: list[str] = []
55
+ obs = step.get("observation") or ""
56
+ if obs:
57
+ parts.append(f"OBSERVATION: {obs}")
58
+ fb = step.get("feedback") or ""
59
+ if fb:
60
+ parts.append(f"FEEDBACK: {fb}")
61
+ if text_only:
62
+ cap = step.get("image_caption") or ""
63
+ if cap:
64
+ parts.append(f"IMAGE: {cap}")
65
+ if not parts:
66
+ parts.append("(no textual observation)")
67
+ return "\n".join(parts)
68
+
69
+
70
+ def _build_assistant_text(step: dict[str, Any]) -> str:
71
+ """Build assistant turn text: thought + action."""
72
+ parts: list[str] = []
73
+ thought = step.get("thought") or ""
74
+ if thought:
75
+ parts.append(f"THOUGHT: {thought}")
76
+ action = step.get("action") or ""
77
+ if action:
78
+ parts.append(f"ACTION: {action}")
79
+ return "\n".join(parts) or "(no action)"
80
+
81
+
82
+ def _build_attachment(step: dict[str, Any], *, text_only: bool) -> list[dict[str, Any]]:
83
+ """Build attachment list for a step (caption-only for text_only mode)."""
84
+ cap = step.get("image_caption") or ""
85
+ if not cap:
86
+ return []
87
+ if text_only:
88
+ # Caption already inlined in user text, no separate attachment needed
89
+ return []
90
+ image_id = step.get("image_id") or step.get("image_path") or ""
91
+ return [{"caption": cap, "type": "image_caption", "image_id": image_id}]
92
+
93
+
94
+ # ---------------------------------------------------------------------------
95
+ # Session segmentation
96
+ # ---------------------------------------------------------------------------
97
+
98
+ def _segment_steps_by_probes(
99
+ steps: list[dict[str, Any]],
100
+ probes: list[dict[str, Any]],
101
+ total_steps: int,
102
+ ) -> list[tuple[str, list[dict[str, Any]]]]:
103
+ """Split steps into sessions at probe boundaries.
104
+
105
+ Returns list of (session_id, steps_in_session).
106
+ Session after probe i covers steps (prev_boundary+1 .. probe_i.after_step_num].
107
+ The remainder after the last probe is the final session.
108
+ """
109
+ probe_bounds = sorted(set(p["after_step_num"] for p in probes))
110
+ boundaries = [0] + probe_bounds + [total_steps]
111
+
112
+ sessions: list[tuple[str, list[dict[str, Any]]]] = []
113
+ for i in range(len(boundaries) - 1):
114
+ lo = boundaries[i] # exclusive lower bound (step_num > lo)
115
+ hi = boundaries[i + 1] # inclusive upper bound (step_num <= hi)
116
+ sid = f"S{i:02d}"
117
+ seg = [s for s in steps if lo < s["step_num"] <= hi]
118
+ if seg:
119
+ sessions.append((sid, seg))
120
+
121
+ return sessions
122
+
123
+
124
+ def _assign_mps_to_sessions(
125
+ memory_points: list[dict[str, Any]],
126
+ sessions: list[tuple[str, list[dict[str, Any]]]],
127
+ ) -> dict[str, list[dict[str, Any]]]:
128
+ """Map memory points to sessions by step_num range."""
129
+ # Build session_id → step_num range
130
+ ranges: list[tuple[str, int, int]] = []
131
+ for sid, seg in sessions:
132
+ lo = min(s["step_num"] for s in seg)
133
+ hi = max(s["step_num"] for s in seg)
134
+ ranges.append((sid, lo, hi))
135
+
136
+ result: dict[str, list[dict[str, Any]]] = {sid: [] for sid, _ in sessions}
137
+
138
+ for mp in memory_points:
139
+ sn = mp.get("step_num") or mp.get("probe_step_num") or 0
140
+ assigned = False
141
+ for sid, lo, hi in ranges:
142
+ if lo <= sn <= hi:
143
+ result[sid].append(mp)
144
+ assigned = True
145
+ break
146
+ if not assigned:
147
+ # Fallback: assign to last session
148
+ result[ranges[-1][0]].append(mp)
149
+
150
+ return result
151
+
152
+
153
+ # ---------------------------------------------------------------------------
154
+ # Memory point conversion
155
+ # ---------------------------------------------------------------------------
156
+
157
+ def _convert_mp(mp: dict[str, Any]) -> dict[str, Any]:
158
+ """VisTrajQA memory point → eval_framework gold memory point dict."""
159
+ return {
160
+ "memory_id": mp.get("mp_id", ""),
161
+ "memory_content": mp.get("content", ""),
162
+ "memory_type": mp.get("type", ""),
163
+ "memory_source": mp.get("source", "primary"),
164
+ "is_update": bool(mp.get("is_update", False)),
165
+ "original_memories": mp.get("original_memories") or [],
166
+ "importance": float(mp.get("importance", 0.0)),
167
+ "timestamp": None,
168
+ "update_type": mp.get("update_type") or "",
169
+ }
170
+
171
+
172
+ # ---------------------------------------------------------------------------
173
+ # Question conversion
174
+ # ---------------------------------------------------------------------------
175
+
176
def _convert_question(
    q: dict[str, Any],
    mp_content_map: dict[str, str],
) -> dict[str, Any]:
    """VisTrajQA question → eval_framework checkpoint question dict.

    ``mp_content_map`` is currently unused here (evidence is emitted as
    memory-id stubs only); kept for interface stability with callers.
    """
    abbrev = q.get("qa_type", "FR")
    return {
        "question": q.get("question", ""),
        "answer": q.get("answer", ""),
        # Unknown abbreviations pass through unchanged.
        "question_type": QA_TYPE_FULL.get(abbrev, abbrev),
        "question_type_abbrev": abbrev,
        "difficulty": q.get("difficulty", "medium"),
        "evidence": [{"memory_id": mid} for mid in (q.get("evidence") or [])],
    }
191
+
192
+
193
+ # ---------------------------------------------------------------------------
194
+ # Main conversion
195
+ # ---------------------------------------------------------------------------
196
+
197
def convert_one_session(
    rec: dict[str, Any],
    *,
    text_only: bool = True,
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
    """Convert one VisTrajQA session record → (sample_json, stage4_row, qa_row).

    Returns dicts ready for serialization into the three target files:
    ``domain_a_v2.json`` (sample_json), ``stage4_memory_points.jsonl``
    (stage4_row) and ``stage4b_qa_checkpoints.jsonl`` (qa_row).
    """
    sample_id = rec["session_id"]
    # Deterministic UUID derived from the sample id, so repeated
    # conversions of the same record agree.
    sample_uuid = str(_uuid.uuid5(_uuid.NAMESPACE_DNS, sample_id))
    steps = rec["step_plan"]
    total_steps = rec.get("total_steps") or len(steps)
    probes = rec.get("probes") or []
    post_qa = rec.get("post_trajectory_qa") or []
    memory_points = rec.get("memory_points") or []

    # mp_id → content map for evidence resolution
    mp_content_map: dict[str, str] = {
        mp["mp_id"]: mp.get("content", "")
        for mp in memory_points
        if mp.get("mp_id")
    }

    # --- Session segmentation ---
    # Sessions are cut at probe boundaries; memory points are then mapped
    # onto the session whose step range contains them.
    sessions = _segment_steps_by_probes(steps, probes, total_steps)
    mp_by_session = _assign_mps_to_sessions(memory_points, sessions)

    # --- Build domain_a_v2.json sample ---
    session_objects: list[dict[str, Any]] = []
    for sid, seg_steps in sessions:
        dialogue: list[dict[str, Any]] = []
        for step in seg_steps:
            # User turn: what the agent perceives at this step.
            user_text = _build_user_text(step, text_only=text_only)
            dialogue.append({
                "role": "user",
                "content": user_text,
                "timestamp": f"step_{step['step_num']:04d}",
                "attachments": _build_attachment(step, text_only=text_only),
            })
            # Assistant turn: the agent's thought + action.
            assistant_text = _build_assistant_text(step)
            dialogue.append({
                "role": "assistant",
                "content": assistant_text,
                "timestamp": f"step_{step['step_num']:04d}",
                "attachments": [],
            })

        sess_obj: dict[str, Any] = {
            "_v2_session_id": sid,
            "dialogue": dialogue,
        }
        # S00 carries its own memory_points in the session object
        if sid == "S00":
            sess_obj["memory_points"] = [_convert_mp(mp) for mp in mp_by_session.get(sid, [])]
        session_objects.append(sess_obj)

    sample_json = {
        "uuid": sample_uuid,
        "sample_id": sample_id,
        "sessions": session_objects,
        # Metadata (not consumed by loader, but useful for debugging)
        "_source": rec.get("source", ""),
        "_env": rec.get("env", ""),
        "_traj_id": rec.get("traj_id", ""),
        "_task": rec.get("task", ""),
        "_total_steps": total_steps,
    }

    # --- Build stage4_memory_points.jsonl row ---
    stage4_sessions: list[dict[str, Any]] = []
    for sid, _ in sessions:
        if sid == "S00":
            continue  # S00 is embedded in domain_a_v2.json
        mps = mp_by_session.get(sid, [])
        if mps:
            stage4_sessions.append({
                "session_id": sid,
                "memory_points": [_convert_mp(mp) for mp in mps],
            })

    stage4_row = {
        "uuid": sample_uuid,
        "sample_id": sample_id,
        "memory_sessions": stage4_sessions,
    }

    # --- Build stage4b_qa_checkpoints.jsonl row ---
    session_ids = [sid for sid, _ in sessions]
    checkpoints: list[dict[str, Any]] = []

    # Probe checkpoints: a probe whose after_step_num matches a session's
    # final step becomes a checkpoint covering all sessions seen so far.
    probe_by_after_step = {p["after_step_num"]: p for p in probes}
    cumulative_sessions: list[str] = []
    for sid, seg_steps in sessions:
        cumulative_sessions.append(sid)
        max_step_in_session = max(s["step_num"] for s in seg_steps)
        probe = probe_by_after_step.get(max_step_in_session)
        if probe is None:
            continue
        questions = [
            _convert_question(q, mp_content_map)
            for q in probe.get("questions", [])
        ]
        if questions:
            checkpoints.append({
                "checkpoint_id": f"probe_{probe['probe_id']}",
                # Copy — cumulative_sessions keeps growing after this.
                "covered_sessions": list(cumulative_sessions),
                "questions": questions,
            })

    # Post-trajectory checkpoint (covers all sessions)
    if post_qa:
        post_questions = [
            _convert_question(q, mp_content_map)
            for q in post_qa
        ]
        if post_questions:
            checkpoints.append({
                "checkpoint_id": "post_trajectory",
                "covered_sessions": session_ids,
                "questions": post_questions,
            })

    qa_row = {
        "uuid": sample_uuid,
        "sample_id": sample_id,
        "checkpoints": checkpoints,
    }

    return sample_json, stage4_row, qa_row
330
+
331
+
332
def convert_files(
    input_paths: list[Path],
    output_dir: Path,
    *,
    text_only: bool = True,
) -> None:
    """Read VisTrajQA session files and write the three domain_a_v2 files.

    Args:
        input_paths: ``sessions-*.jsonl`` files to convert.
        output_dir: Directory that receives ``domain_a_v2.json``,
            ``stage4_memory_points.jsonl`` and ``stage4b_qa_checkpoints.jsonl``.
        text_only: When true, inline image captions into user-turn text.

    Raises:
        ValueError: If an input line is not valid JSON (with file/line context).
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    all_samples: list[dict[str, Any]] = []
    all_stage4: list[dict[str, Any]] = []
    all_qa: list[dict[str, Any]] = []

    for path in input_paths:
        print(f"Reading {path} ...")
        with path.open(encoding="utf-8") as fh:
            for line_num, line in enumerate(fh, 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError as exc:
                    # line_num was previously unused; surface the exact
                    # location of malformed input instead of a bare traceback.
                    raise ValueError(
                        f"invalid JSON at {path}:{line_num}: {exc}"
                    ) from exc
                sample_json, stage4_row, qa_row = convert_one_session(
                    rec, text_only=text_only,
                )
                all_samples.append(sample_json)
                all_stage4.append(stage4_row)
                all_qa.append(qa_row)

    # Write domain_a_v2.json (a single JSON array of samples).
    domain_path = output_dir / "domain_a_v2.json"
    domain_path.write_text(
        json.dumps(all_samples, ensure_ascii=False, indent=2) + "\n",
        encoding="utf-8",
    )
    print(f"  → {domain_path} ({len(all_samples)} samples)")

    # Write stage4_memory_points.jsonl (one row per sample).
    stage4_path = output_dir / "stage4_memory_points.jsonl"
    with stage4_path.open("w", encoding="utf-8") as fh:
        for row in all_stage4:
            fh.write(json.dumps(row, ensure_ascii=False) + "\n")
    print(f"  → {stage4_path} ({len(all_stage4)} rows)")

    # Write stage4b_qa_checkpoints.jsonl (one row per sample).
    qa_path = output_dir / "stage4b_qa_checkpoints.jsonl"
    with qa_path.open("w", encoding="utf-8") as fh:
        for row in all_qa:
            fh.write(json.dumps(row, ensure_ascii=False) + "\n")
    print(f"  → {qa_path} ({len(all_qa)} rows)")

    # --- Validation: round-trip the output through the eval_framework loader ---
    print("\nValidating ...")
    _validate(output_dir)
    print("Done.")
386
+
387
+
388
def _validate(output_dir: Path) -> None:
    """Quick validation: load the converted files through the eval_framework loader.

    Prints per-sample statistics and re-raises any loader failure so a broken
    conversion fails loudly instead of silently producing unusable files.
    """
    try:
        from eval_framework.datasets.domain_a_v2 import load_domain_a_v2_academic
        bundle = load_domain_a_v2_academic(output_dir)
        print(f"  Loaded {len(bundle.samples)} samples successfully")
        for sample in bundle.samples:
            n_sessions = len(sample.sessions)
            n_turns = sum(len(s.turns) for s in sample.sessions)
            n_checkpoints = len(sample.normalized_checkpoints)
            n_questions = sum(len(cp.questions) for cp in sample.normalized_checkpoints)
            # Only the final session's cumulative gold state matters here;
            # the [-1:] slice also tolerates samples with no gold states.
            # (The previously computed, unused `n_gold` local was removed.)
            n_gold_points = sum(
                len(g.cumulative_gold_memories)
                for g in sample.session_gold_states[-1:]
            )
            print(
                f"  {sample.sample_id}: "
                f"{n_sessions} sessions, {n_turns} turns, "
                f"{n_checkpoints} checkpoints, {n_questions} questions, "
                f"{n_gold_points} gold points"
            )
    except Exception as e:
        # Boundary handler: log for the CLI user, then re-raise.
        print(f"  Validation failed: {e}")
        raise
413
+
414
+
415
def main() -> None:
    """CLI entry point: parse arguments and run the VisTrajQA conversion."""
    ap = argparse.ArgumentParser(
        description="Convert VisTrajQA sessions → eval_framework domain_a_v2 format",
    )
    ap.add_argument(
        "--input", "-i",
        nargs="+",
        required=True,
        help="Path(s) to sessions-*.jsonl files",
    )
    ap.add_argument(
        "--output", "-o",
        required=True,
        help="Output directory for the three converted files",
    )
    ap.add_argument(
        "--text-only",
        action="store_true",
        default=True,
        help="Inline image captions into user turn text (default: true)",
    )
    ap.add_argument(
        "--multimodal",
        action="store_true",
        help="Keep image captions as attachments instead of inlining",
    )
    ns = ap.parse_args()

    # NOTE(review): --text-only is accepted but its value is never read;
    # the effective mode is decided solely by whether --multimodal is set.
    inline_captions = not ns.multimodal

    src_paths = [Path(p).expanduser().resolve() for p in ns.input]
    dst_dir = Path(ns.output).expanduser().resolve()

    convert_files(src_paths, dst_dir, text_only=inline_captions)


if __name__ == "__main__":
    main()
datasets/domain_a_v2.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Domain A v2 academic bundle: dialogue normalization + staged QA / gold state."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any, Iterator, Mapping
9
+
10
+ from eval_framework.datasets.schemas import NormalizedTurn, normalize_turn
11
+ from eval_framework.pipeline.gold_state import (
12
+ SessionGoldState,
13
+ build_session_gold_states,
14
+ )
15
+
16
+
17
@dataclass(frozen=True)
class Stage4Record:
    """Per-sample gold memory points grouped by session (one stage4 jsonl row)."""

    # Stable identifiers copied from the source row.
    uuid: str
    sample_id: str
    # Ordered (session_id, raw memory-point dicts) pairs. S00's points are
    # not here — they are embedded in domain_a_v2.json (see the loader).
    memory_sessions: tuple[tuple[str, tuple[Mapping[str, Any], ...]], ...]
22
+
23
+
24
@dataclass(frozen=True)
class QARecord:
    """Per-sample checkpoint QA payload as read from stage4b jsonl."""

    uuid: str
    sample_id: str
    # Checkpoint dicts exactly as found on disk; normalized later by
    # _normalize_checkpoints.
    raw_checkpoints: tuple[Mapping[str, Any], ...]
29
+
30
+
31
@dataclass(frozen=True)
class NormalizedCheckpointQuestion:
    """One checkpoint question with its gold answer and gold evidence."""

    question: str
    gold_answer: str
    # IDs of the gold evidence memory points cited by the question.
    gold_evidence_memory_ids: tuple[str, ...]
    # Resolved contents for evidence IDs found in the memory-content map;
    # may be shorter than the ID tuple when an ID has no known content.
    gold_evidence_contents: tuple[str, ...]
    question_type: str
    question_type_abbrev: str
    difficulty: str
40
+
41
+
42
@dataclass(frozen=True)
class NormalizedCheckpoint:
    """A QA checkpoint, triggered once every covered session has completed."""

    checkpoint_id: str
    # Session IDs whose completion gates this checkpoint.
    covered_sessions: tuple[str, ...]
    questions: tuple[NormalizedCheckpointQuestion, ...]
47
+
48
+
49
@dataclass(frozen=True)
class DomainAV2Session:
    """One dialogue session: an ordered tuple of normalized turns."""

    session_id: str
    turns: tuple[NormalizedTurn, ...]
53
+
54
+
55
@dataclass(frozen=True)
class DomainAV2AcademicSample:
    """A fully joined sample: dialogue + gold memory rows + checkpoint QA."""

    uuid: str
    sample_id: str
    sessions: tuple[DomainAV2Session, ...]
    # Raw side-file rows the sample was joined from (kept for debugging).
    stage4: Stage4Record
    qa_record: QARecord
    # Derived views consumed downstream.
    normalized_checkpoints: tuple[NormalizedCheckpoint, ...]
    session_gold_states: tuple[SessionGoldState, ...]
64
+
65
+
66
@dataclass(frozen=True)
class DomainAV2AcademicBundle:
    """Top-level container returned by load_domain_a_v2_academic."""

    samples: tuple[DomainAV2AcademicSample, ...]
69
+
70
+
71
+ def _read_jsonl(path: Path) -> Iterator[dict[str, Any]]:
72
+ with path.open(encoding="utf-8") as fh:
73
+ for line in fh:
74
+ line = line.strip()
75
+ if not line:
76
+ continue
77
+ yield json.loads(line)
78
+
79
+
80
def _stage4_from_obj(obj: Mapping[str, Any]) -> Stage4Record:
    """Parse one stage4 jsonl row into an immutable Stage4Record."""
    session_blocks: list[tuple[str, tuple[Mapping[str, Any], ...]]] = []
    for block in obj.get("memory_sessions") or []:
        points = block.get("memory_points") or []
        # Defend against malformed rows where memory_points is not a list.
        if not isinstance(points, list):
            points = []
        session_blocks.append((str(block.get("session_id", "")), tuple(points)))
    return Stage4Record(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        memory_sessions=tuple(session_blocks),
    )
93
+
94
+
95
def _qa_from_obj(obj: Mapping[str, Any]) -> QARecord:
    """Parse one stage4b jsonl row into an immutable QARecord."""
    checkpoints = obj.get("checkpoints") or []
    # Non-list payloads collapse to an empty tuple rather than raising.
    raw = tuple(checkpoints) if isinstance(checkpoints, list) else ()
    return QARecord(
        uuid=str(obj["uuid"]),
        sample_id=str(obj["sample_id"]),
        raw_checkpoints=raw,
    )
104
+
105
+
106
def _normalize_checkpoint_question(
    raw: Mapping[str, Any],
    memory_content_map: Mapping[str, str],
) -> NormalizedCheckpointQuestion:
    """Normalize one raw question dict, resolving evidence IDs to contents."""
    evidence_ids: list[str] = []
    evidence_contents: list[str] = []
    raw_evidence = raw.get("evidence") or []
    if isinstance(raw_evidence, list):
        for entry in raw_evidence:
            if not (isinstance(entry, dict) and "memory_id" in entry):
                continue
            memory_id = str(entry["memory_id"])
            evidence_ids.append(memory_id)
            # Unknown or empty contents are dropped; the ID tuple still
            # records every cited evidence point.
            resolved = memory_content_map.get(memory_id, "")
            if resolved:
                evidence_contents.append(resolved)
    return NormalizedCheckpointQuestion(
        question=str(raw.get("question", "")),
        gold_answer=str(raw.get("answer", "")),
        gold_evidence_memory_ids=tuple(evidence_ids),
        gold_evidence_contents=tuple(evidence_contents),
        question_type=str(raw.get("question_type", "")),
        question_type_abbrev=str(raw.get("question_type_abbrev", "")),
        difficulty=str(raw.get("difficulty", "")),
    )
130
+
131
+
132
def _normalize_checkpoints(
    raw_checkpoints: tuple[Mapping[str, Any], ...],
    memory_content_map: Mapping[str, str],
) -> tuple[NormalizedCheckpoint, ...]:
    """Normalize every raw checkpoint dict into a NormalizedCheckpoint."""
    normalized: list[NormalizedCheckpoint] = []
    for raw_cp in raw_checkpoints:
        raw_questions = raw_cp.get("questions") or []
        if not isinstance(raw_questions, list):
            raw_questions = []
        raw_covered = raw_cp.get("covered_sessions") or []
        if not isinstance(raw_covered, list):
            raw_covered = []
        # Non-mapping question entries are silently dropped.
        questions = tuple(
            _normalize_checkpoint_question(rq, memory_content_map)
            for rq in raw_questions
            if isinstance(rq, Mapping)
        )
        normalized.append(
            NormalizedCheckpoint(
                checkpoint_id=str(raw_cp.get("checkpoint_id", "")),
                covered_sessions=tuple(str(s) for s in raw_covered),
                questions=questions,
            )
        )
    return tuple(normalized)
156
+
157
+
158
def _dialogue_turns(sample_id: str, session_id: str, dialogue: list[Any]) -> tuple[NormalizedTurn, ...]:
    """Convert raw dialogue entries into NormalizedTurn records.

    Image-caption attachments are inlined into the turn text; the produced
    turn's own `attachments` field is intentionally left empty.

    Args:
        sample_id: Owning sample identifier, copied onto every turn.
        session_id: Owning session identifier, copied onto every turn.
        dialogue: Raw entry dicts; non-dict entries are skipped.
    """
    turns: list[NormalizedTurn] = []
    for turn_index, entry in enumerate(dialogue):
        if not isinstance(entry, dict):
            continue
        text = str(entry.get("content", ""))
        attachments_raw = entry.get("attachments") or []
        captions: list[str] = []
        if isinstance(attachments_raw, list):
            for att in attachments_raw:
                if isinstance(att, dict):
                    cap = att.get("caption", "")
                    captions.append(cap if isinstance(cap, str) else str(cap))
        if captions:
            caption_block = "\n".join(captions)
            # Fix: avoid a dangling "\n\n" prefix when the base text is empty
            # (e.g. Arena data, whose observations are empty strings).
            text = f"{text}\n\n{caption_block}" if text else caption_block
        ts = entry.get("timestamp")
        timestamp = ts if isinstance(ts, str) else (str(ts) if ts is not None else None)
        raw_turn = {
            "sample_id": sample_id,
            "session_id": session_id,
            "turn_index": turn_index,
            "role": str(entry.get("role", "user")),
            "text": text,
            "attachments": [],
            "timestamp": timestamp,
        }
        turns.append(normalize_turn(raw_turn))
    return tuple(turns)
186
+
187
+
188
def load_domain_a_v2_academic(data_dir: Path) -> DomainAV2AcademicBundle:
    """Load and join the three domain_a_v2 files into one academic bundle.

    Expects ``data_dir`` to contain:
      - ``domain_a_v2.json``           — list of samples with dialogue sessions
      - ``stage4_memory_points.jsonl`` — gold memory points per session (no S00)
      - ``stage4b_qa_checkpoints.jsonl`` — checkpoint QA questions

    Raises:
        ValueError: If ``domain_a_v2.json`` is not a JSON list.
        KeyError: If a sample lacks a matching stage4 or QA row.
    """
    data_dir = data_dir.resolve()
    main_path = data_dir / "domain_a_v2.json"
    stage4_path = data_dir / "stage4_memory_points.jsonl"
    qa_path = data_dir / "stage4b_qa_checkpoints.jsonl"

    raw_samples = json.loads(main_path.read_text(encoding="utf-8"))
    if not isinstance(raw_samples, list):
        raise ValueError("domain_a_v2.json must be a list")

    # Index both jsonl side files by sample_id for the per-sample join below.
    stage4_by_id: dict[str, Stage4Record] = {}
    for obj in _read_jsonl(stage4_path):
        rec = _stage4_from_obj(obj)
        stage4_by_id[rec.sample_id] = rec

    qa_by_id: dict[str, QARecord] = {}
    for obj in _read_jsonl(qa_path):
        rec = _qa_from_obj(obj)
        qa_by_id[rec.sample_id] = rec

    built: list[DomainAV2AcademicSample] = []
    for item in raw_samples:
        if not isinstance(item, dict):
            continue
        sample_id = str(item["sample_id"])
        uuid = str(item["uuid"])
        stage4 = stage4_by_id.get(sample_id)
        qa = qa_by_id.get(sample_id)
        if stage4 is None or qa is None:
            raise KeyError(f"missing stage4 or QA row for sample_id={sample_id}")

        # session_id -> raw memory-point dicts, from the stage4 side file.
        stage4_map = {sid: pts for sid, pts in stage4.memory_sessions}

        sessions_raw = item.get("sessions") or []
        if not isinstance(sessions_raw, list):
            sessions_raw = []

        session_blocks: list[DomainAV2Session] = []
        ordered_ids: list[str] = []
        # S00's gold memory points are embedded in its session object here,
        # unlike later sessions whose points come from the stage4 file.
        s00_points: tuple[Mapping[str, Any], ...] = ()

        for sess in sessions_raw:
            if not isinstance(sess, dict):
                continue
            sid = str(sess.get("_v2_session_id", ""))
            if not sid:
                # Sessions without a v2 id are skipped entirely.
                continue
            ordered_ids.append(sid)
            dialogue = sess.get("dialogue") or []
            if not isinstance(dialogue, list):
                dialogue = []
            session_blocks.append(
                DomainAV2Session(
                    session_id=sid,
                    turns=_dialogue_turns(sample_id, sid, dialogue),
                )
            )
            if sid == "S00":
                mps = sess.get("memory_points") or []
                if isinstance(mps, list):
                    s00_points = tuple(mps)

        # Cumulative gold memory state after each session, in session order.
        gold_states = build_session_gold_states(
            ordered_ids,
            s00_memory_points=s00_points,
            stage4_by_session_id=stage4_map,
        )

        # Build memory_id -> memory_content map from all sources
        # (S00-embedded points plus every stage4 session block).
        memory_content_map: dict[str, str] = {}
        for mp_raw in s00_points:
            if isinstance(mp_raw, Mapping):
                mid = mp_raw.get("memory_id")
                mc = mp_raw.get("memory_content")
                if mid is not None and mc is not None:
                    memory_content_map[str(mid)] = str(mc)
        for _sid, pts in stage4.memory_sessions:
            for mp_raw in pts:
                if isinstance(mp_raw, Mapping):
                    mid = mp_raw.get("memory_id")
                    mc = mp_raw.get("memory_content")
                    if mid is not None and mc is not None:
                        memory_content_map[str(mid)] = str(mc)

        built.append(
            DomainAV2AcademicSample(
                uuid=uuid,
                sample_id=sample_id,
                sessions=tuple(session_blocks),
                stage4=stage4,
                qa_record=qa,
                normalized_checkpoints=_normalize_checkpoints(
                    qa.raw_checkpoints, memory_content_map
                ),
                session_gold_states=gold_states,
            )
        )

    return DomainAV2AcademicBundle(samples=tuple(built))
datasets/schemas.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Normalized runtime schemas shared across adapters, pipeline, and evaluators."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Mapping
7
+
8
# Type alias for delta operation names; values are checked against
# _VALID_DELTA_OPS in MemoryDeltaRecord.__post_init__.
MemoryDeltaOp = str

# Closed set of operations a memory delta may carry.
_VALID_DELTA_OPS: frozenset[str] = frozenset(
    {"add", "update", "keep", "suppress", "archive"}
)
13
+
14
+
15
@dataclass(frozen=True)
class Attachment:
    """Caption-first attachment; image_id is optional for caption-only items."""

    # Caption text; may be the empty string.
    caption: str
    type: str = "image_caption"
    # Backing image identifier, or None for caption-only attachments.
    image_id: str | None = None
22
+
23
+
24
@dataclass(frozen=True)
class NormalizedTurn:
    """A single dialogue turn in normalized form (built via normalize_turn)."""

    sample_id: str
    session_id: str
    # 0-based position of the turn within its session.
    turn_index: int
    # Speaker role, e.g. "user" or "assistant".
    role: str
    text: str
    attachments: tuple[Attachment, ...] = ()
    # Timestamp string when the source carried one; otherwise None.
    timestamp: str | None = None
33
+
34
+
35
def normalize_turn(raw: Mapping[str, Any]) -> NormalizedTurn:
    """Build a turn record, keeping attachments that only carry captions."""
    parsed_attachments: list[Attachment] = []
    for entry in raw.get("attachments") or []:
        if not isinstance(entry, dict):
            continue
        raw_caption = entry.get("caption", "")
        raw_type = entry.get("type", "image_caption")
        raw_image_id = entry.get("image_id")
        # Both a missing and an empty image_id mean "caption-only".
        if raw_image_id is None or raw_image_id == "":
            image_id_value: str | None = None
        else:
            image_id_value = str(raw_image_id)
        parsed_attachments.append(
            Attachment(
                caption=raw_caption if isinstance(raw_caption, str) else str(raw_caption),
                type=raw_type if isinstance(raw_type, str) else str(raw_type),
                image_id=image_id_value,
            )
        )
    raw_ts = raw.get("timestamp")
    if isinstance(raw_ts, str):
        ts_value: str | None = raw_ts
    elif raw_ts is None:
        ts_value = None
    else:
        ts_value = str(raw_ts)
    return NormalizedTurn(
        sample_id=str(raw["sample_id"]),
        session_id=str(raw["session_id"]),
        turn_index=int(raw["turn_index"]),
        role=str(raw["role"]),
        text=str(raw["text"]),
        attachments=tuple(parsed_attachments),
        timestamp=ts_value,
    )
64
+
65
+
66
@dataclass(frozen=True)
class MemorySnapshotRecord:
    """One memory item as captured in a memory snapshot."""

    memory_id: str
    text: str
    # Session in which the memory was recorded.
    session_id: str
    status: str
    source: str | None = None
    # Identifiers from the underlying memory backend, when available.
    raw_backend_id: str | None = None
    raw_backend_type: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
76
+
77
+
78
@dataclass(frozen=True)
class MemoryDeltaRecord:
    """A per-session change to the memory store.

    Raises:
        ValueError: If `op` is not one of _VALID_DELTA_OPS
            ("add" / "update" / "keep" / "suppress" / "archive").
    """

    session_id: str
    # Operation kind; validated in __post_init__.
    op: MemoryDeltaOp
    text: str
    # IDs of earlier memories this delta is linked to.
    linked_previous: tuple[str, ...] = ()
    raw_backend_id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Fail fast on unknown ops so bad deltas never reach the evaluators.
        if self.op not in _VALID_DELTA_OPS:
            raise ValueError(f"invalid memory delta op: {self.op!r}")
+ raise ValueError(f"invalid memory delta op: {self.op!r}")
90
+
91
+
92
@dataclass(frozen=True)
class RetrievalItem:
    """One ranked retrieval hit."""

    # Position in the result list (base not specified here — set by producer).
    rank: int
    memory_id: str
    text: str
    # Raw relevance score as reported by the retrieval backend.
    score: float
    raw_backend_id: str | None = None
99
+
100
+
101
@dataclass(frozen=True)
class RetrievalRecord:
    """A retrieval call: the query, the requested top_k, and returned items.

    NOTE(review): `items` and `raw_trace` are mutable containers on a
    frozen dataclass, unlike the tuple fields elsewhere in this module —
    confirm whether in-place mutation is relied upon before tightening.
    """

    query: str
    top_k: int
    items: list[RetrievalItem]
    # Backend-specific debug payload for the retrieval call.
    raw_trace: dict[str, Any] = field(default_factory=dict)
docs/DATA_CONVERSION.md ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VisTrajQA → Eval Framework 数据适配指南
2
+
3
+ ## 概述
4
+
5
+ `convert_vistrajqa.py` 将 VisTrajQA 的 `sessions-*.jsonl` 转换为 eval_framework 所需的 domain_a_v2 三文件格式,从而可以用 Mem-Gallery / A-Mem / MemoryOS 等 baseline 进行统一评测。
6
+
7
+ ## 快速使用
8
+
9
+ ```bash
10
+ # 转换所有数据源(text-only 模式,默认)
11
+ python -m eval_framework.datasets.convert_vistrajqa \
12
+ --input data/generated/sessions-vab.jsonl \
13
+ data/generated/sessions-eb-nav.jsonl \
14
+ data/generated/sessions-arena.jsonl \
15
+ data/generated/sessions-eb-alfred.jsonl \
16
+ data/generated/sessions-infini-thor.jsonl \
17
+ --output eval_framework/converted/all
18
+
19
+ # 只转换某个数据源
20
+ python -m eval_framework.datasets.convert_vistrajqa \
21
+ --input data/generated/sessions-vab.jsonl \
22
+ --output eval_framework/converted/vab
23
+
24
+ # multimodal 模式(image caption 作为 attachment 而非内联文本)
25
+ python -m eval_framework.datasets.convert_vistrajqa \
26
+ --input data/generated/sessions-vab.jsonl \
27
+ --output eval_framework/converted/vab-mm \
28
+ --multimodal
29
+
30
+ # 转换后直接跑 eval
31
+ python -m eval_framework.cli \
32
+ --dataset eval_framework/converted/all \
33
+ --baseline FUMemory \
34
+ --output-dir eval_framework/results/FUMemory
35
+ ```
36
+
37
+ ## 转换映射
38
+
39
+ ### 数据结构映射
40
+
41
+ ```
42
+ VisTrajQA session → eval_framework sample
43
+ ├── session_id → sample_id
44
+ ├── step_plan[] → sessions[].dialogue[] (user + assistant turns)
45
+ ├── probes[] → checkpoints[] (probe checkpoints)
46
+ ├── post_trajectory_qa[] → checkpoints[-1] (post-trajectory checkpoint)
47
+ └── memory_points[] → gold memory points (S00 embedded + stage4)
48
+ ```
49
+
50
+ ### Session 切分
51
+
52
+ 一条 VisTrajQA 轨迹(如 30 步,4 个 probe 在 step 6/12/18/24)按 probe 边界切分为 5 个 session:
53
+
54
+ ```
55
+ 步骤 1-6 → S00 (probe 1 在此 session 结束后触发)
56
+ 步骤 7-12 → S01 (probe 2)
57
+ 步骤 13-18 → S02 (probe 3)
58
+ 步骤 19-24 → S03 (probe 4)
59
+ 步骤 25-30 → S04 (post-trajectory QA 在全部 session 结束后触发)
60
+ ```
61
+
62
+ 这样保证 eval_framework 的 runner 在每个 session 完成后恰好触发对应的 checkpoint。
63
+
64
+ ### Turn 构建
65
+
66
+ 每个 step 生成 2 个 dialogue turn:
67
+
68
+ | Turn | Role | 内容 |
69
+ |------|------|------|
70
+ | User turn | `user` | OBSERVATION + FEEDBACK + IMAGE caption(text-only 模式) |
71
+ | Assistant turn | `assistant` | THOUGHT + ACTION |
72
+
73
+ **text-only 模式**(默认):image caption 直接写入 user turn 文本,格式为 `IMAGE: <caption>`。适用于所有 text-only baseline。
74
+
75
+ **multimodal 模式**(`--multimodal`):image caption 作为 `attachment` 附加,不写入正文。适用于 MMMemory 等多模态 baseline。
76
+
77
+ ### Memory Point 映射
78
+
79
+ | VisTrajQA 字段 | eval_framework 字段 | 说明 |
80
+ |----------------|---------------------|------|
81
+ | `mp_id` | `memory_id` | 如 `mp_S04_1` |
82
+ | `content` | `memory_content` | 一句话事实描述 |
83
+ | `type` | `memory_type` | `event_memory` / `state_memory` / `spatial_memory` |
84
+ | `source` | `memory_source` | `primary` (文本) / `secondary` (推断) |
85
+ | `is_update` | `is_update` | 是否为更新型记忆 |
86
+ | `original_memories` | `original_memories` | 被替换的旧内容列表 |
87
+ | `importance` | `importance` | 0.4 / 0.6 / 0.8 / 1.0 |
88
+ | `update_type` | `update_type` | `status_update` / `location_change` / ... |
89
+
90
+ Memory point 按 `step_num` 分配到对应 session:
91
+ - S00 的 memory points 嵌入在 `domain_a_v2.json` 的 session 对象中
92
+ - 其他 session 的 memory points 写入 `stage4_memory_points.jsonl`
93
+
94
+ ### QA / Checkpoint 映射
95
+
96
+ **Probe checkpoint**:每个 probe 生成一个 checkpoint,`covered_sessions` 为该 probe 及之前所有 session。
97
+
98
+ **Post-trajectory checkpoint**:覆盖全部 session,包含 9 类 QA。
99
+
100
+ | VisTrajQA QA type | eval_framework question_type | 缩写 |
101
+ |----|----|-----|
102
+ | FR | factual_recall | FR |
103
+ | DU | dynamic_update | DU |
104
+ | MB | memory_boundary | MB |
105
+ | TR | temporal_reasoning | TR |
106
+ | KR | knowledge_reasoning | KR |
107
+ | VFR | visual_factual_recall | VFR |
108
+ | VS | visual_search | VS |
109
+ | VU | visual_update | VU |
110
+ | CMR | cross_modal_reasoning | CMR |
111
+
112
+ Evidence 字段从 `["mp_S04_1"]`(字符串列表)转换为 `[{"memory_id": "mp_S04_1"}]`(字典列表)以匹配 eval_framework 格式。
113
+
114
+ ## 输出文件
115
+
116
+ ```
117
+ eval_framework/converted/all/
118
+ ├── domain_a_v2.json # 主对话数据 (JSON array)
119
+ ├── stage4_memory_points.jsonl # 每 session 的 gold memory points
120
+ └── stage4b_qa_checkpoints.jsonl # checkpoint QA 题目
121
+ ```
122
+
123
+ ## 评测维度与 VisTrajQA 的对应
124
+
125
+ | eval_framework 维度 | 测量内容 | 对应 VisTrajQA 特性 |
126
+ |-----|-----|-----|
127
+ | Memory Recall | 记忆系统存储了多少 gold points | 直接对应,所有 MP 类型 |
128
+ | Memory Correctness | 存储的记忆是否正确 | 检测 hallucination |
129
+ | Update Handling | 更新型记忆是否正确替换 | 对应 `is_update=true` 的 MP |
130
+ | Interference Rejection | 干扰信息是否被过滤 | VisTrajQA 无 interference 标注,此维度为空 |
131
+ | QA Accuracy | 问答正确率 | 对应 9 类 QA (FR/DU/MB/TR/KR/VFR/VS/VU/CMR) |
132
+ | Evidence Coverage | 回答引用了多少 gold evidence | 对应 evidence memory_point_ids |
133
+
134
+ > **注意**:VisTrajQA 没有 interference(干扰信息)标注,因此 eval_framework 的 Interference Rejection 维度在评测结果中会为空值。MB(Memory Boundary)类型的题目在 QA 层面测试了类似能力。
135
+
136
+ ## 注意事项
137
+
138
+ 1. **text-only baseline(FU/ST/LT/GA/MG/RF)**:使用默认 `--text-only`,image caption 内联到用户消息文本中
139
+ 2. **multimodal baseline(MM/MMFU/NG/AUGUSTUS)**:使用 `--multimodal`,caption 作为 attachment
140
+ 3. **caption 质量**:text-only baseline 对图像的理解完全依赖 caption 质量。如果 `image_caption` 为空,用户 turn 中不会有任何视觉信息
141
+ 4. **Arena 数据**:observation 恒为空字符串,视觉信息完全来自 image_caption
142
+ 5. **转换器会自动验证**:运行后会调用 `load_domain_a_v2_academic` 检验输出是否合法
docs/EXPERIMENTS.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 实验记录
2
+
3
+ ## 实验环境
4
+
5
+ - **数据**:VisTrajQA VAB smoke(1 sample, vab_minecraft, 30 步, 5 sessions, 45 QA)
6
+ - **转换模式**:text-only(image caption 内联到 user turn 文本)
7
+ - **Judge / Answer 模型**:从 `.env` 读取 `OPENAI_MODEL`
8
+ - **日期**:2026-04-15
9
+
10
+ ---
11
+
12
+ ## 总表
13
+
14
+ ### Memory 维度
15
+
16
+ | Baseline | Recall ↑ | Update Recall ↑ | Correctness ↑ | Hallucination ↓ | Irrelevant ↓ | Update Score ↑ |
17
+ |----------|----------|-----------------|---------------|-----------------|--------------|----------------|
18
+ | Dummy | 95.0% | 90.0% | 50.0% | 5.0% | 45.0% | 71.4% |
19
+ | FUMemory | 95.0% | 90.0% | 80.0% | 10.0% | 10.0% | 69.0% |
20
+ | STMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
21
+ | LTMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 69.0% |
22
+ | GAMemory | 92.1% | 90.0% | 62.5% | 22.5% | 15.0% | 71.4% |
23
+ | MGMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
24
+ | RFMemory | 95.0% | 90.0% | 90.0% | 6.7% | 3.3% | 71.4% |
25
+ | MMMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
26
+ | MMFUMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
27
+ | NGMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
28
+ | AUGUSTUSMemory | 95.0% | 90.0% | 83.3% | 6.7% | 10.0% | 71.4% |
29
+ | UniversalRAGMemory | 95.0% | 90.0% | 80.0% | 10.0% | 10.0% | 71.4% |
30
+ | Mem0 | 47.9% | 56.7% | 57.3% | 18.0% | 24.7% | 81.0% |
31
+ | Mem0-Graph | 48.2% | 56.7% | 37.9% | 14.6% | 38.1% | 71.4% |
32
+ | SimpleMem | 95.0% | 90.0% | 48.3% | 5.0% | 46.7% | 71.4% |
33
+ | Omni-SimpleMem | — | — | — | — | — | — |
34
+ | MemVerse | 95.0% | 90.0% | 65.0% | 23.3% | 11.7% | 71.4% |
35
+ | Zep | — | — | — | — | — | — |
36
+ | A-Mem | 95.0% | 90.0% | 56.7% | 5.0% | 38.3% | 71.4% |
37
+ | MemoryOS | 47.1% | 41.0% | 33.8% | 3.5% | 62.8% | 50.0% |
38
+
39
+ ### QA 维度
40
+
41
+ | Baseline | QA Correct ↑ | QA Hallucination ↓ | QA Omission ↓ | Evidence Coverage ↑ |
42
+ |----------|--------------|---------------------|---------------|---------------------|
43
+ | Dummy | 44.4% | 24.4% | 31.1% | 47.6% |
44
+ | FUMemory | 57.8% | 40.0% | 2.2% | 65.0% |
45
+ | STMemory | 31.1% | 33.3% | 35.6% | 32.0% |
46
+ | LTMemory | 40.0% | 26.7% | 33.3% | 51.5% |
47
+ | GAMemory | 37.8% | 20.0% | 42.2% | 51.5% |
48
+ | MGMemory | 60.0% | 35.6% | 4.4% | 68.9% |
49
+ | RFMemory | 60.0% | 37.8% | 2.2% | 63.1% |
50
+ | MMMemory | 57.8% | 40.0% | 2.2% | 66.0% |
51
+ | MMFUMemory | 60.0% | 37.8% | 2.2% | 67.0% |
52
+ | NGMemory | 62.2% | 24.4% | 13.3% | 71.8% |
53
+ | AUGUSTUSMemory | 64.4% | 22.2% | 13.3% | 68.9% |
54
+ | UniversalRAGMemory | 40.0% | 28.9% | 31.1% | 49.5% |
55
+ | Mem0 | 26.7% | 28.9% | 44.4% | 31.1% |
56
+ | Mem0-Graph | 31.1% | 20.0% | 48.9% | 32.0% |
57
+ | SimpleMem | 57.8% | 22.2% | 20.0% | 35.9% |
58
+ | Omni-SimpleMem | — | — | — | — |
59
+ | MemVerse | 40.0% | 28.9% | 31.1% | 40.8% |
60
+ | Zep | — | — | — | — |
61
+ | A-Mem | 46.7% | 20.0% | 33.3% | 50.5% |
62
+ | MemoryOS | 28.9% | 24.4% | 46.7% | 37.9% |
63
+
64
+ ### QA 分类型正确率
65
+
66
+ | Baseline | FR | DU | MB | TR | KR | VFR | VS | VU | CMR |
67
+ |----------|----|----|----|----|----|----|----|----|-----|
68
+ | Dummy | 5/5 | 0/5 | 4/5 | 2/5 | 2/5 | 4/5 | 0/5 | 1/5 | 2/5 |
69
+ | FUMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 2/5 | 0/5 |
70
+ | STMemory | 0/5 | 0/5 | 5/5 | 1/5 | 3/5 | 2/5 | 0/5 | 0/5 | 3/5 |
71
+ | LTMemory | 2/5 | 1/5 | 5/5 | 1/5 | 3/5 | 4/5 | 0/5 | 0/5 | 2/5 |
72
+ | GAMemory | 4/5 | 0/5 | 5/5 | 0/5 | 2/5 | 4/5 | 0/5 | 0/5 | 2/5 |
73
+ | MGMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 1/5 | 2/5 |
74
+ | RFMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 1/5 | 2/5 |
75
+ | MMMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 2/5 | 0/5 |
76
+ | MMFUMemory | 4/5 | 3/5 | 5/5 | 4/5 | 4/5 | 4/5 | 0/5 | 2/5 | 1/5 |
77
+ | NGMemory | 5/5 | 3/5 | 5/5 | 3/5 | 3/5 | 4/5 | 0/5 | 2/5 | 3/5 |
78
+ | AUGUSTUSMemory | 5/5 | 2/5 | 5/5 | 4/5 | 3/5 | 4/5 | 1/5 | 2/5 | 3/5 |
79
+ | UniversalRAGMemory | 2/5 | 1/5 | 5/5 | 1/5 | 3/5 | 4/5 | 0/5 | 0/5 | 2/5 |
80
+ | Mem0 | 2/5 | 1/5 | 5/5 | 0/5 | 0/5 | 2/5 | 0/5 | 0/5 | 2/5 |
81
+ | Mem0-Graph | 2/5 | 0/5 | 5/5 | 0/5 | 1/5 | 3/5 | 0/5 | 0/5 | 3/5 |
82
+ | SimpleMem | 5/5 | 3/5 | 4/5 | 3/5 | 2/5 | 2/5 | 0/5 | 3/5 | 4/5 |
83
+ | Omni-SimpleMem | — | — | — | — | — | — | — | — | — |
84
+ | MemVerse | 5/5 | 0/5 | 5/5 | 0/5 | 2/5 | 3/5 | 0/5 | 0/5 | 3/5 |
85
+ | Zep | — | — | — | — | — | — | — | — | — |
86
+ | A-Mem | 3/5 | 3/5 | 5/5 | 0/5 | 2/5 | 4/5 | 0/5 | 0/5 | 4/5 |
87
+ | MemoryOS | 4/5 | 1/5 | 5/5 | 0/5 | 0/5 | 1/5 | 0/5 | 0/5 | 2/5 |
docs/GUIDE.md ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Eval Framework 使用指南
2
+
3
+ ## 1. 整体架构
4
+
5
+ ```
6
+ eval_framework/
7
+ ├── cli.py # 入口:CLI 解析 + 三阶段编排 (Pipeline → Eval → Aggregate)
8
+ ├── config.py # EvalConfig 数据类
9
+ ├── openai_compat.py # GPT-5 系列 max_tokens→max_completion_tokens 兼容补丁
10
+ ├── datasets/
11
+ │ ├── schemas.py # 运行时共享数据结构 (NormalizedTurn, MemorySnapshotRecord, RetrievalRecord 等)
12
+ │ └── domain_a_v2.py # domain_a_v2 数据集加载器
13
+ ├── memory_adapters/
14
+ │ ├── base.py # MemoryAdapter 抽象基类 (7 个接口方法)
15
+ │ ├── registry.py # Baseline 注册表 + Mem-Gallery 默认配置覆盖
16
+ │ ├── memgallery_native.py # Mem-Gallery 11 种内置 baseline 的统一适配器
17
+ │ ├── amem.py # A-Mem 外部 baseline 适配器
18
+ │ ├── memoryos.py # MemoryOS 外部 baseline 适配器
19
+ │ └── export_utils.py # 快照/检索结果归一化工具
20
+ ├── pipeline/
21
+ │ ├── runner.py # 按 session 顺序喂入对话 → 生成 snapshot/delta → 触发 QA
22
+ │ ├── qa_runner.py # 对每个 checkpoint question 做 retrieve + answer
23
+ │ ├── gold_state.py # Gold memory points 累积构建
24
+ │ └── records.py # PipelineSessionRecord / PipelineCheckpointQARecord
25
+ ├── evaluators/
26
+ │ ├── extraction.py # Session 级评估:Recall + Correctness + Update + Interference
27
+ │ ├── qa.py # Checkpoint QA 评估:Answer 正确性 + Evidence 覆盖率
28
+ │ └── aggregate.py # 聚合所有 session/QA 评估到 baseline 级汇总指标
29
+ └── judges/
30
+ ├── llm_client.py # OpenAI 兼容 LLM 调用 + JSON 解析 + 重试 + 并发控制
31
+ └── prompts.py # 6 套 LLM judge prompt 模板
32
+ ```
33
+
34
+ ## 2. 运行流程
35
+
36
+ 整个 eval 分三个阶段(`cli.py: run_eval()`):
37
+
38
+ ### Stage 1 — Pipeline(串行,适配器有状态)
39
+
40
+ ```
41
+ for each sample:
42
+ adapter = create_adapter(baseline_name)
43
+ adapter.reset()
44
+ for each session in sample.sessions:
45
+ for each turn in session.turns:
46
+ adapter.ingest_turn(turn) # 喂入一条对话
47
+ adapter.end_session(session_id) # 触发 session 后处理(如 GA 反思、RF 优化)
48
+ snapshot = adapter.snapshot_memories() # 拍快照
49
+ delta = adapter.export_memory_delta() # 导出本 session 增量
50
+ → PipelineSessionRecord
51
+
52
+ # 当某个 checkpoint 的 covered_sessions 全部完成时触发 QA
53
+ for each question in checkpoint:
54
+ retrieval = adapter.retrieve(question, top_k=5)
55
+ answer = answer_fn(question, retrieval) # 可注入外部 LLM 回答
56
+ → PipelineCheckpointQARecord
57
+ ```
58
+
59
+ Pipeline 结束后写入 checkpoint 文件 `pipeline_sessions.jsonl` + `pipeline_qa.jsonl`,支持 `--eval-only` 跳过此阶段直接从 checkpoint 恢复。
60
+
61
+ ### Stage 2 — Eval(并行,ThreadPoolExecutor)
62
+
63
+ - **Session 评估**(`evaluators/extraction.py`)— 每个 session 4+ 次 LLM 调用:
64
+ 1. **Recall**:本 session 的 gold points 中有多少被 delta 覆盖?
65
+ 2. **Correctness**:每条 delta 记忆是 correct / hallucination / irrelevant?
66
+ 3. **Update handling**:每个 update gold point → updated / both / outdated
67
+ 4. **Interference rejection**:每个 interference gold point → rejected / memorized
68
+
69
+ - **QA 评估**(`evaluators/qa.py`)— 每个 question 2 次 LLM 调用:
70
+ 1. **Answer 正确性**:Correct / Hallucination / Omission
71
+ 2. **Evidence 覆盖率**:cited memories 覆盖了多少 gold evidence points
72
+
73
+ ### Stage 3 — Aggregate
74
+
75
+ 将所有 session 和 QA 级别的评估结果聚合为 6 个维度的 baseline 级指标:
76
+
77
+ | 维度 | 聚合方式 | 关键指标 |
78
+ |------|---------|---------|
79
+ | Memory Recall | 按 session 平均 | `avg_recall`, `avg_update_recall` |
80
+ | Memory Correctness | 按 session 平均 | `avg_correctness`, `avg_hallucination` |
81
+ | Update Handling | 跨 session 池化 | `score` (updated=1.0, both=0.5, outdated=0.0) |
82
+ | Interference Rejection | 跨 session 池化 | `score` (rejected/total) |
83
+ | Question Answering | 跨 question 池化 | `correct_ratio`, `hallucination_ratio`, `omission_ratio` |
84
+ | Evidence Coverage | 跨 question 池化 | `hit_rate` |
85
+
86
+ 输出文件:
87
+ - `session_records.jsonl` — 每条含 pipeline 数据 + eval 结果
88
+ - `qa_records.jsonl` — 同上
89
+ - `aggregate_metrics.json` — baseline 级汇总
90
+
91
+ ## 3. 支持的 Baselines
92
+
93
+ ### 3.1 Mem-Gallery 内置(11 种)
94
+
95
+ 通过 `MemGalleryNativeAdapter` 统一包装,需要在 `eval_framework/` 同级目录放置 `memengine/` 和 `default_config/`(从 Mem-Gallery 的 `benchmark/` 目录复制)。
96
+
97
+ | Baseline | 类型 | 特性 | 额外依赖 |
98
+ |----------|------|------|---------|
99
+ | `FUMemory` | text-only | 全量存储(FIFO 截断) | — |
100
+ | `STMemory` | text-only | 短期记忆 | — |
101
+ | `LTMemory` | text-only | 长期记忆,embedding 检索 | sentence-transformers |
102
+ | `GAMemory` | text-only | 带 importance judge + 自反思 | LLM API |
103
+ | `MGMemory` | text-only | 多层存储(working/FIFO/recall/archival) | LLM API, sentence-transformers |
104
+ | `RFMemory` | text-only | 带 reflection optimizer | LLM API |
105
+ | `MMMemory` | multimodal | 多模态记忆 | torch |
106
+ | `MMFUMemory` | multimodal | 多模态全量存储 | torch |
107
+ | `NGMemory` | multimodal | 知识图谱节点存储 | torch |
108
+ | `AUGUSTUSMemory` | multimodal | 概念抽取 + 图谱 | LLM API, torch |
109
+ | `UniversalRAGMemory` | multimodal | RAG routing + 存储 | LLM API |
110
+
111
+ ### 3.2 外部适配器
112
+
113
+ | Baseline | 来源 | 安装方式 | 需要外部服务 |
114
+ |----------|------|---------|-------------|
115
+ | `Mem0` | [mem0ai/mem0](https://github.com/mem0ai/mem0) | `pip install mem0ai` | 否(内置 Qdrant + SQLite) |
116
+ | `Mem0-Graph` | 同上(graph 模式) | `pip install "mem0ai[graph]"` | 需要 Neo4j |
117
+ | `SimpleMem` | [aiming-lab/SimpleMem](https://github.com/aiming-lab/SimpleMem) | clone + requirements | 否 |
118
+ | `Omni-SimpleMem` | 同上(omni 模式) | 同上 | 否 |
119
+ | `Zep` | [getzep/zep](https://github.com/getzep/zep) | `pip install zep-python` | 需要 Zep server |
120
+ | `A-Mem` | [A-Mem](https://arxiv.org/abs/2502.12110) | clone 源码 | 否 |
121
+ | `MemoryOS` | [MemoryOS](https://github.com/BAI-LAB/MemoryOS) | clone 源码 | 否 |
122
+
123
+ **论文来源:**
124
+
125
+ | Baseline | 论文 | GitHub |
126
+ |----------|------|--------|
127
+ | Mem0 / Mem0-Graph | [arXiv:2504.19413](https://arxiv.org/abs/2504.19413) | https://github.com/mem0ai/mem0 |
128
+ | SimpleMem | [arXiv:2601.02553](https://arxiv.org/abs/2601.02553) | https://github.com/aiming-lab/SimpleMem |
129
+ | Omni-SimpleMem | [arXiv:2604.01007](https://arxiv.org/abs/2604.01007) | https://github.com/aiming-lab/SimpleMem |
130
+ | MemVerse | [arXiv:2512.03627](https://arxiv.org/abs/2512.03627) | https://github.com/KnowledgeXLab/MemVerse |
131
+ | Memobase | — | https://github.com/memodb-io/memobase |
132
+ | Supermemory | — | https://github.com/supermemoryai/supermemory |
133
+ | Zep | [arXiv:2501.13956](https://arxiv.org/abs/2501.13956) | https://github.com/getzep/zep |
134
+
135
+ ### 3.3 添加新 Baseline
136
+
137
+ 实现 `MemoryAdapter` 的 7 个抽象方法:
138
+
139
+ ```python
140
+ class MyAdapter(MemoryAdapter):
141
+ def reset(self) -> None: ...
142
+ def ingest_turn(self, turn: NormalizedTurn) -> None: ...
143
+ def end_session(self, session_id: str) -> None: ...
144
+ def snapshot_memories(self) -> list[MemorySnapshotRecord]: ...
145
+ def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]: ...
146
+ def retrieve(self, query: str, top_k: int) -> RetrievalRecord: ...
147
+ def get_capabilities(self) -> dict[str, Any]: ...
148
+ ```
149
+
150
+ 然后在 `registry.py` 的 `EXTERNAL_ADAPTER_REGISTRY` 中注册。
151
+
152
+ ## 4. 数据适配
153
+
154
+ ### 4.1 数据集格式(domain_a_v2)
155
+
156
+ 加载器 `load_domain_a_v2_academic(data_dir)` 要求 `data_dir` 下有三个文件:
157
+
158
+ ```
159
+ data_dir/
160
+ ├── domain_a_v2.json # 主对话数据(JSON array)
161
+ ├── stage4_memory_points.jsonl # 每 session 的 gold memory points
162
+ └── stage4b_qa_checkpoints.jsonl # checkpoint QA 题目
163
+ ```
164
+
165
+ **`domain_a_v2.json`** 中每个 sample 结构:
166
+
167
+ ```json
168
+ {
169
+ "uuid": "unique-id",
170
+ "sample_id": "sample_001",
171
+ "sessions": [
172
+ {
173
+ "_v2_session_id": "S00",
174
+ "dialogue": [
175
+ {
176
+ "role": "user",
177
+ "content": "Hello...",
178
+ "timestamp": "2025-01-01T10:00:00",
179
+ "attachments": [{"caption": "photo of...", "type": "image_caption"}]
180
+ },
181
+ {"role": "assistant", "content": "Hi..."}
182
+ ],
183
+ "memory_points": [...] // 仅 S00 需要
184
+ },
185
+ {"_v2_session_id": "S01", "dialogue": [...]}
186
+ ]
187
+ }
188
+ ```
189
+
190
+ **`stage4_memory_points.jsonl`** 每行一个 sample:
191
+
192
+ ```json
193
+ {
194
+ "uuid": "...", "sample_id": "sample_001",
195
+ "memory_sessions": [
196
+ {
197
+ "session_id": "S01",
198
+ "memory_points": [
199
+ {
200
+ "memory_id": "m001",
201
+ "memory_content": "User prefers dark mode",
202
+ "memory_type": "preference",
203
+ "memory_source": "normal",
204
+ "is_update": false,
205
+ "original_memories": [],
206
+ "importance": 0.8
207
+ }
208
+ ]
209
+ }
210
+ ]
211
+ }
212
+ ```
213
+
214
+ **`stage4b_qa_checkpoints.jsonl`** 每行一个 sample:
215
+
216
+ ```json
217
+ {
218
+ "uuid": "...", "sample_id": "sample_001",
219
+ "checkpoints": [
220
+ {
221
+ "checkpoint_id": "cp01",
222
+ "covered_sessions": ["S00", "S01"],
223
+ "questions": [
224
+ {
225
+ "question": "What theme does the user prefer?",
226
+ "answer": "Dark mode",
227
+ "question_type": "preference_recall",
228
+ "question_type_abbrev": "pref",
229
+ "difficulty": "easy",
230
+ "evidence": [{"memory_id": "m001"}]
231
+ }
232
+ ]
233
+ }
234
+ ]
235
+ }
236
+ ```
237
+
238
+ ### 4.2 适配自有数据
239
+
240
+ 若要接入新数据源,有两条路径:
241
+
242
+ **路径 A:转换为 domain_a_v2 格式**(推荐)
243
+ - 将原始对话整理为上述三文件格式
244
+ - 直接使用现有 CLI 运行
245
+
246
+ **路径 B:编写新的 dataset loader**
247
+ - 在 `datasets/` 下新建加载器,返回 `DomainAV2AcademicBundle`(或等价结构)
248
+ - 在 `cli.py` 的 `run_eval()` 中通过 `load_domain_bundle` 参数注入
249
+
250
+ ### 4.3 关键数据结构
251
+
252
+ 每条对话 turn 会被归一化为 `NormalizedTurn`:
253
+
254
+ ```python
255
+ NormalizedTurn(
256
+ sample_id="sample_001",
257
+ session_id="S01",
258
+ turn_index=0,
259
+ role="user", # "user" | "assistant"
260
+ text="Hello...",
261
+ attachments=(Attachment(caption="...", type="image_caption"),),
262
+ timestamp="2025-01-01T10:00:00",
263
+ )
264
+ ```
265
+
266
+ Memory 的 gold 标注支持三种来源标记:
267
+ - `normal` — 正常记忆点
268
+ - `interference` — 干扰信息(不应被记忆)
269
+ - `is_update=True` — 更新型记忆(应替换旧记忆)
270
+
271
+ ## 5. 环境配置(uv)
272
+
273
+ ### 5.1 安装 uv
274
+
275
+ ```bash
276
+ curl -LsSf https://astral.sh/uv/install.sh | sh
277
+ ```
278
+
279
+ ### 5.2 初始化项目环境
280
+
281
+ ```bash
282
+ cd /data1/toby/nips26
283
+
284
+ # 创建虚拟环境
285
+ uv venv .venv --python 3.11
286
+ source .venv/bin/activate
287
+ ```
288
+
289
+ ### 5.3 安装核心依赖
290
+
291
+ ```bash
292
+ # 最小依赖(可跑 FUMemory/STMemory 等纯文本 baseline)
293
+ uv pip install openai tenacity
294
+
295
+ # embedding 检索类 baseline(LTMemory, GAMemory, MGMemory 等)
296
+ uv pip install sentence-transformers
297
+
298
+ # 多模态 baseline(MMMemory, NGMemory, AUGUSTUSMemory 等)
299
+ uv pip install torch torchvision transformers
300
+
301
+ # 外部 baseline(A-Mem, MemoryOS)— 按各自文档安装额外依赖
302
+ # A-Mem 需要其源码目录下的 requirements
303
+ # MemoryOS 需要 memoryos 包
304
+ ```
305
+
306
+ ### 5.4 环境变量(.env 文件)
307
+
308
+ 在项目根目录 (`nips26/`) 创建 `.env` 文件,框架会自动加载:
309
+
310
+ ```bash
311
+ # .env
312
+ # 必需 — LLM API(pipeline 答题 + judge 评估统一使用)
313
+ OPENAI_API_KEY=sk-...
314
+ OPENAI_BASE_URL=https://api.openai.com/v1 # 或兼容端点
315
+ OPENAI_MODEL=gpt-4o
316
+
317
+ # 可选
318
+ OPENAI_TEMPERATURE=0.0
319
+ OPENAI_MAX_TOKENS=1024
320
+ OPENAI_TIMEOUT=120
321
+ JUDGE_TEMPERATURE=0.0 # judge 专用温度
322
+ LLM_MAX_CONCURRENT=5 # LLM 并发上限
323
+ ```
324
+
325
+ ### 5.5 Mem-Gallery 本地依赖
326
+
327
+ Mem-Gallery 内置 baseline 需要将其源码放到 `eval_framework/` 的同级目录:
328
+
329
+ ```bash
330
+ # 假设 Mem-Gallery repo 在 /path/to/Mem-Gallery
331
+ cp -r /path/to/Mem-Gallery/benchmark/memengine /data1/toby/nips26/
332
+ cp -r /path/to/Mem-Gallery/benchmark/default_config /data1/toby/nips26/
333
+ ```
334
+
335
+ 最终目录结构应为:
336
+
337
+ ```
338
+ nips26/
339
+ ├── eval_framework/
340
+ ├── memengine/ # Mem-Gallery 记忆引擎
341
+ └── default_config/ # Mem-Gallery 默认配置
342
+ ```
343
+
344
+ ## 6. 运行示例
345
+
346
+ ### 基本运行
347
+
348
+ ```bash
349
+ # 运行单个 baseline
350
+ python -m eval_framework.cli \
351
+ --dataset /path/to/domain_a_v2_data/ \
352
+ --baseline FUMemory \
353
+ --output-dir eval_framework/results/FUMemory
354
+
355
+ # smoke 模式(只跑第 1 个 sample,快速验证)
356
+ python -m eval_framework.cli \
357
+ --dataset /path/to/domain_a_v2_data/ \
358
+ --baseline FUMemory \
359
+ --output-dir eval_framework/results/FUMemory_smoke \
360
+ --smoke
361
+
362
+ # dry-run(不实际运行,打印配置)
363
+ python -m eval_framework.cli \
364
+ --dataset /path/to/domain_a_v2_data/ \
365
+ --baseline FUMemory \
366
+ --dry-run
367
+
368
+ # 仅重跑 eval 阶段(从 checkpoint 恢复,pipeline 不重跑)
369
+ python -m eval_framework.cli \
370
+ --dataset /path/to/domain_a_v2_data/ \
371
+ --baseline FUMemory \
372
+ --output-dir eval_framework/results/FUMemory \
373
+ --eval-only
374
+
375
+ # 调整 eval 并发数
376
+ python -m eval_framework.cli \
377
+ --dataset /path/to/domain_a_v2_data/ \
378
+ --baseline MGMemory \
379
+ --output-dir eval_framework/results/MGMemory \
380
+ --max-eval-workers 10
381
+ ```
382
+
383
+ ### 批量跑所有 baseline
384
+
385
+ ```bash
386
+ DATASET="/path/to/domain_a_v2_data"
387
+ for baseline in FUMemory STMemory LTMemory GAMemory MGMemory RFMemory A-Mem MemoryOS; do
388
+ echo "=== Running $baseline ==="
389
+ python -m eval_framework.cli \
390
+ --dataset "$DATASET" \
391
+ --baseline "$baseline" \
392
+ --output-dir "eval_framework/results/$baseline"
393
+ done
394
+ ```
395
+
396
+ ### 输出文件说明
397
+
398
+ 运行完成后 `output-dir` 下包含:
399
+
400
+ ```
401
+ results/FUMemory/
402
+ ├── pipeline_sessions.jsonl # Stage 1 checkpoint — session 级 pipeline 结果
403
+ ├── pipeline_qa.jsonl # Stage 1 checkpoint — QA 级 pipeline 结果
404
+ ├── session_records.jsonl # 最终 session 结果(含 eval)
405
+ ├── qa_records.jsonl # 最终 QA 结果(含 eval)
406
+ └── aggregate_metrics.json # baseline 级汇总指标
407
+ ```
408
+
409
+ ## 7. LLM API 开销估算
410
+
411
+ 每个 sample 的 LLM 调用量:
412
+
413
+ | 来源 | 调用次数 |
414
+ |------|---------|
415
+ | Pipeline answer(每个 QA question) | N_questions |
416
+ | Session Recall judge | N_sessions |
417
+ | Session Correctness judge | N_sessions |
418
+ | Update judge | N_update_points(逐条) |
419
+ | Interference judge | N_interference_points(逐条) |
420
+ | QA Answer judge | N_questions |
421
+ | QA Evidence judge | N_questions |
422
+
423
+ 典型场景下一个 sample 约 20-50 次 LLM 调用。通过 `LLM_MAX_CONCURRENT` 控制并发避免 rate limit。
docs/OUTPUT_FORMAT.md ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Eval Framework 输出格式
2
+
3
+ ## 输出目录结构
4
+
5
+ 运行完成后 `--output-dir` 下包含 5 个文件:
6
+
7
+ ```
8
+ output-dir/
9
+ ├── pipeline_sessions.jsonl # Stage 1 checkpoint — pipeline 中间结果(session 级)
10
+ ├── pipeline_qa.jsonl # Stage 1 checkpoint — pipeline 中间结果(QA 级)
11
+ ├── session_records.jsonl # 最终结果:session pipeline 数据 + eval 评判
12
+ ├── qa_records.jsonl # 最终结果:QA pipeline 数据 + eval 评判
13
+ └── aggregate_metrics.json # 最终结果:baseline 级别汇总指标
14
+ ```
15
+
16
+ ## 文件详解
17
+
18
+ ### 1. `session_records.jsonl`
19
+
20
+ 每行一个 session,包含 pipeline 原始数据和 `eval` 评判结果:
21
+
22
+ ```json
23
+ {
24
+ "sample_id": "vab_minecraft_...",
25
+ "sample_uuid": "uuid-...",
26
+ "session_id": "S01",
27
+ "memory_snapshot": [
28
+ {
29
+ "memory_id": "3",
30
+ "text": "user: OBSERVATION: ...\nassistant: THOUGHT: ...",
31
+ "session_id": "S01",
32
+ "status": "active",
33
+ "source": "FUMemory",
34
+ "raw_backend_id": "3",
35
+ "raw_backend_type": "linear",
36
+ "metadata": {}
37
+ }
38
+ ],
39
+ "memory_delta": [
40
+ {
41
+ "session_id": "S01",
42
+ "op": "add",
43
+ "text": "user: OBSERVATION: ...",
44
+ "linked_previous": [],
45
+ "raw_backend_id": "3",
46
+ "metadata": {"baseline": "FUMemory"}
47
+ }
48
+ ],
49
+ "gold_state": {
50
+ "session_id": "S01",
51
+ "cumulative_gold_memories": [...],
52
+ "session_new_memories": [...],
53
+ "session_update_memories": [...],
54
+ "session_interference_memories": []
55
+ },
56
+ "eval": {
57
+ "session_id": "S01",
58
+ "recall": 0.8,
59
+ "covered_count": 4,
60
+ "num_gold": 5,
61
+ "update_recall": 1.0,
62
+ "update_covered_count": 2,
63
+ "update_total": 2,
64
+ "recall_reasoning": "4 of 5 gold points are covered...",
65
+ "correctness_rate": 0.75,
66
+ "num_memories": 8,
67
+ "num_correct": 6,
68
+ "num_hallucination": 1,
69
+ "num_irrelevant": 1,
70
+ "correctness_reasoning": "...",
71
+ "correctness_records": [
72
+ {"id": 1, "label": "correct"},
73
+ {"id": 2, "label": "hallucination"}
74
+ ],
75
+ "update_score": 1.0,
76
+ "update_num_updated": 2,
77
+ "update_num_both": 0,
78
+ "update_num_outdated": 0,
79
+ "update_total_items": 2,
80
+ "update_records": [
81
+ {"memory_id": "mp_S08_3", "label": "updated", "reasoning": "..."}
82
+ ],
83
+ "interference_score": null,
84
+ "interference_num_rejected": 0,
85
+ "interference_num_memorized": 0,
86
+ "interference_total_items": 0,
87
+ "interference_records": []
88
+ }
89
+ }
90
+ ```
91
+
92
+ **eval 字段说明:**
93
+
94
+ | 字段 | 含义 |
95
+ |------|------|
96
+ | `recall` | 本 session gold points 被 delta 覆盖的比例 (0-1) |
97
+ | `update_recall` | update 类型 gold points 的覆盖比例 |
98
+ | `correctness_rate` | delta 中正确记忆的比例 |
99
+ | `num_hallucination` | delta 中幻觉记忆数量 |
100
+ | `num_irrelevant` | delta 中无关记忆数量 |
101
+ | `update_score` | 更新处理得分 (updated=1.0, both=0.5, outdated=0.0) |
102
+ | `interference_score` | 干扰拒绝得分 (rejected=1.0, memorized=0.0) |
103
+
104
+ ### 2. `qa_records.jsonl`
105
+
106
+ 每行一个 QA question,包含检索结果、模型回答和评判:
107
+
108
+ ```json
109
+ {
110
+ "sample_id": "vab_minecraft_...",
111
+ "sample_uuid": "uuid-...",
112
+ "checkpoint_id": "probe_e980c238",
113
+ "question": "What was in the agent's inventory at step 1?",
114
+ "gold_answer": "At step 1, the agent's inventory was empty.",
115
+ "gold_evidence_memory_ids": ["mp_S04_1"],
116
+ "gold_evidence_contents": ["The agent started with empty inventory"],
117
+ "question_type": "factual_recall",
118
+ "question_type_abbrev": "FR",
119
+ "difficulty": "easy",
120
+ "retrieval": {
121
+ "query": "What was in the agent's inventory at step 1?",
122
+ "top_k": 5,
123
+ "items": [
124
+ {
125
+ "rank": 0,
126
+ "memory_id": "memgallery:string_bundle",
127
+ "text": "user: OBSERVATION: Your Inventory: ...",
128
+ "score": 1.0,
129
+ "raw_backend_id": null
130
+ }
131
+ ],
132
+ "raw_trace": {"baseline": "FUMemory"}
133
+ },
134
+ "generated_answer": "The agent's inventory was empty at step 1.",
135
+ "cited_memories": ["user: OBSERVATION: Inventory: nothing"],
136
+ "eval": {
137
+ "answer_label": "Correct",
138
+ "answer_reasoning": "The response matches the reference answer...",
139
+ "answer_is_valid": true,
140
+ "evidence_hit_rate": 1.0,
141
+ "evidence_covered_count": 1,
142
+ "num_evidence": 1,
143
+ "evidence_reasoning": "The cited memory covers the gold evidence...",
144
+ "num_cited_memories": 1
145
+ }
146
+ }
147
+ ```
148
+
149
+ **eval 字段说明:**
150
+
151
+ | 字段 | 含义 |
152
+ |------|------|
153
+ | `answer_label` | `Correct` / `Hallucination` / `Omission` |
154
+ | `answer_is_valid` | 评判是否成功(非 LLM 错误) |
155
+ | `evidence_hit_rate` | cited memories 覆盖了多少 gold evidence (0-1) |
156
+ | `evidence_covered_count` | 被覆盖的 gold evidence 数量 |
157
+ | `num_cited_memories` | 模型回答时引用的记忆条数 |
158
+
159
+ ### 3. `aggregate_metrics.json`
160
+
161
+ baseline 级别的 6 维汇总指标:
162
+
163
+ ```json
164
+ {
165
+ "baseline_id": "FUMemory",
166
+ "memory_recall": {
167
+ "avg_recall": 0.72,
168
+ "avg_update_recall": 0.65,
169
+ "num_sessions_with_recall": 110,
170
+ "num_sessions_with_update": 85,
171
+ "total_covered": 320,
172
+ "total_gold": 445
173
+ },
174
+ "memory_correctness": {
175
+ "avg_correctness": 0.81,
176
+ "avg_hallucination": 0.08,
177
+ "avg_irrelevant": 0.11,
178
+ "num_sessions": 110,
179
+ "total_memories": 1200,
180
+ "total_correct": 972,
181
+ "total_hallucination": 96,
182
+ "total_irrelevant": 132
183
+ },
184
+ "update_handling": {
185
+ "score": 0.65,
186
+ "num_updated": 52,
187
+ "num_both": 18,
188
+ "num_outdated": 15,
189
+ "num_total": 85
190
+ },
191
+ "interference_rejection": {
192
+ "score": 0.0,
193
+ "num_rejected": 0,
194
+ "num_memorized": 0,
195
+ "num_total": 0
196
+ },
197
+ "question_answering": {
198
+ "correct_ratio": 0.58,
199
+ "hallucination_ratio": 0.22,
200
+ "omission_ratio": 0.20,
201
+ "num_total": 990,
202
+ "num_valid": 990
203
+ },
204
+ "evidence_coverage": {
205
+ "hit_rate": 0.43,
206
+ "num_covered": 425,
207
+ "num_total": 990
208
+ }
209
+ }
210
+ ```
211
+
212
+ **6 个维度:**
213
+
214
+ | 维度 | 聚合方式 | 核心指标 | 方向 |
215
+ |------|---------|---------|------|
216
+ | Memory Recall | 按 session 平均 | `avg_recall` | ↑ |
217
+ | Memory Correctness | 按 session 平均 | `avg_correctness`, `avg_hallucination` | ↑, ↓ |
218
+ | Update Handling | 跨 session 池化 | `score` | ↑ |
219
+ | Interference Rejection | 跨 session 池化 | `score` | ↑ |
220
+ | Question Answering | 跨 question 池化 | `correct_ratio`, `hallucination_ratio` | ↑, ↓ |
221
+ | Evidence Coverage | 跨 question 池化 | `hit_rate` | ↑ |
222
+
223
+ ### 4. `pipeline_sessions.jsonl` / `pipeline_qa.jsonl`
224
+
225
+ Stage 1 的 checkpoint 文件,结构与 `session_records.jsonl` / `qa_records.jsonl` 相同但**不含 `eval` 字段**。
226
+
227
+ 用途:`--eval-only` 模式跳过 pipeline 直接从 checkpoint 恢复,只重跑 eval 阶段。典型场景:
228
+
229
+ ```bash
230
+ # 首次完整运行
231
+ python -m eval_framework.cli --dataset ... --baseline FUMemory --output-dir results/FU
232
+
233
+ # 换 judge 模型重评(不重跑 pipeline)
234
+ OPENAI_MODEL=gpt-4o-mini python -m eval_framework.cli \
235
+ --dataset ... --baseline FUMemory --output-dir results/FU --eval-only
236
+ ```
237
+
238
+ ## 结果分析示例
239
+
240
+ ```python
241
+ import json
242
+
243
+ # 读取汇总
244
+ with open("results/FUMemory/aggregate_metrics.json") as f:
245
+ agg = json.load(f)
246
+ print(f"Recall: {agg['memory_recall']['avg_recall']:.2%}")
247
+ print(f"QA Correct: {agg['question_answering']['correct_ratio']:.2%}")
248
+
249
+ # 按 QA type 分析正确率
250
+ qa_by_type = {}
251
+ with open("results/FUMemory/qa_records.jsonl") as f:
252
+ for line in f:
253
+ rec = json.loads(line)
254
+ qt = rec["question_type_abbrev"]
255
+ label = rec["eval"]["answer_label"]
256
+ qa_by_type.setdefault(qt, []).append(label)
257
+
258
+ for qt, labels in sorted(qa_by_type.items()):
259
+ correct = sum(1 for l in labels if l == "Correct")
260
+ print(f" {qt}: {correct}/{len(labels)} = {correct/len(labels):.0%}")
261
+ ```
evaluators/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session- and checkpoint-level evaluators using batch LLM judge."""
2
+
3
+ from eval_framework.evaluators.aggregate import aggregate_metrics
4
+ from eval_framework.evaluators.extraction import evaluate_extraction
5
+ from eval_framework.evaluators.qa import evaluate_checkpoint_qa
6
+
7
+ __all__ = [
8
+ "aggregate_metrics",
9
+ "evaluate_checkpoint_qa",
10
+ "evaluate_extraction",
11
+ ]
evaluators/aggregate.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Roll up per-session and per-QA evaluations into baseline-level summaries.
2
+
3
+ Recall & correctness: per-session average (not pooled cumulative).
4
+ Interference: pooled across sessions.
5
+ QA & evidence: pooled across questions.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from collections.abc import Mapping, Sequence
11
+
12
+
13
+ def _safe_div(a: float, b: float) -> float:
14
+ return a / b if b else 0.0
15
+
16
+
17
def aggregate_metrics(
    baseline_id: str,
    *,
    session_evaluations: Sequence[Mapping[str, object]] = (),
    qa_evaluations: Sequence[Mapping[str, object]] = (),
) -> dict[str, object]:
    """Aggregate all per-session and per-QA evaluations.

    Parameters
    ----------
    baseline_id:
        Identifier echoed back into the result (e.g. ``"FUMemory"``).
    session_evaluations:
        Per-session dicts as produced by ``evaluate_extraction``. Fields that
        are missing or ``None`` are treated as "not evaluated" and skipped.
    qa_evaluations:
        Per-question dicts as produced by ``evaluate_checkpoint_qa``.

    Returns
    -------
    dict
        Six metric groups: ``memory_recall``, ``memory_correctness``,
        ``update_handling``, ``interference_rejection``,
        ``question_answering`` and ``evidence_coverage``.
    """

    # --- Per-session recall (average of per-session scores, not pooled) ---
    recall_scores: list[float] = []
    update_recall_scores: list[float] = []

    # --- Per-session correctness (average) ---
    correctness_scores: list[float] = []
    hallucination_scores: list[float] = []
    irrelevant_scores: list[float] = []

    # --- Update handling (pooled across all sessions) ---
    upd_num_updated = 0
    upd_num_both = 0
    upd_num_outdated = 0
    upd_total_items = 0

    # --- Interference rejection (pooled across all sessions) ---
    interf_num_rejected = 0
    interf_num_memorized = 0
    interf_total_items = 0

    # --- Per-session detail counters (reported for reference only) ---
    total_gold_points = 0
    total_covered = 0
    total_memories = 0
    total_correct = 0
    total_hallucination = 0
    total_irrelevant = 0

    for s in session_evaluations:
        # Recall: per-session score; None means "no gold points", so the
        # session is excluded from the average rather than counted as 0.
        r = s.get("recall")
        if r is not None:
            recall_scores.append(float(r))

        ur = s.get("update_recall")
        if ur is not None:
            update_recall_scores.append(float(ur))

        # Correctness: per-session score
        cr = s.get("correctness_rate")
        if cr is not None:
            correctness_scores.append(float(cr))

        nm = int(s.get("num_memories", 0))
        if nm > 0:
            # Per-session hallucination / irrelevant ratios over this
            # session's delta memories.
            hallucination_scores.append(
                float(s.get("num_hallucination", 0)) / nm
            )
            irrelevant_scores.append(
                float(s.get("num_irrelevant", 0)) / nm
            )

        # Detail counters
        # NOTE(review): covered/gold are pooled together only when the judge
        # produced a covered_count, keeping the pair consistent — confirm
        # against the original indentation (lost in the diff rendering).
        c = s.get("covered_count")
        if c is not None:
            total_covered += int(c)
            total_gold_points += int(s.get("num_gold", 0))
        total_memories += nm
        total_correct += int(s.get("num_correct", 0))
        total_hallucination += int(s.get("num_hallucination", 0))
        total_irrelevant += int(s.get("num_irrelevant", 0))

        # Update handling (pooled)
        upd_num_updated += int(s.get("update_num_updated", 0))
        upd_num_both += int(s.get("update_num_both", 0))
        upd_num_outdated += int(s.get("update_num_outdated", 0))
        upd_total_items += int(s.get("update_total_items", 0))

        # Interference rejection (pooled)
        interf_num_rejected += int(s.get("interference_num_rejected", 0))
        interf_num_memorized += int(s.get("interference_num_memorized", 0))
        interf_total_items += int(s.get("interference_total_items", 0))

    # --- QA (pooled across questions) ---
    qa_total = 0
    qa_valid = 0
    qa_correct = 0
    qa_hallucination = 0
    qa_omission = 0
    evidence_covered = 0
    evidence_total = 0

    for q in qa_evaluations:
        qa_total += 1
        # Only the three recognized labels count as "valid" judgements;
        # anything else (e.g. a failed judge call) is excluded from ratios.
        label = q.get("answer_label")
        if label in ("Correct", "Hallucination", "Omission"):
            qa_valid += 1
            if label == "Correct":
                qa_correct += 1
            elif label == "Hallucination":
                qa_hallucination += 1
            elif label == "Omission":
                qa_omission += 1

        ec = q.get("evidence_covered_count")
        if ec is not None:
            evidence_covered += int(ec)
            evidence_total += int(q.get("num_evidence", 0))

    n_recall = len(recall_scores)
    n_update = len(update_recall_scores)
    n_correct = len(correctness_scores)
    n_hallu = len(hallucination_scores)
    n_irrel = len(irrelevant_scores)

    return {
        "baseline_id": baseline_id,
        "memory_recall": {
            "avg_recall": _safe_div(sum(recall_scores), n_recall),
            "avg_update_recall": _safe_div(sum(update_recall_scores), n_update),
            "num_sessions_with_recall": n_recall,
            "num_sessions_with_update": n_update,
            "total_covered": total_covered,
            "total_gold": total_gold_points,
        },
        "memory_correctness": {
            "avg_correctness": _safe_div(sum(correctness_scores), n_correct),
            "avg_hallucination": _safe_div(sum(hallucination_scores), n_hallu),
            "avg_irrelevant": _safe_div(sum(irrelevant_scores), n_irrel),
            "num_sessions": n_correct,
            "total_memories": total_memories,
            "total_correct": total_correct,
            "total_hallucination": total_hallucination,
            "total_irrelevant": total_irrelevant,
        },
        "update_handling": {
            # updated=1.0, both=0.5, outdated=0.0
            "score": _safe_div(upd_num_updated * 1.0 + upd_num_both * 0.5, upd_total_items),
            "num_updated": upd_num_updated,
            "num_both": upd_num_both,
            "num_outdated": upd_num_outdated,
            "num_total": upd_total_items,
        },
        "interference_rejection": {
            # rejected=1.0, memorized=0.0
            "score": _safe_div(interf_num_rejected, interf_total_items),
            "num_rejected": interf_num_rejected,
            "num_memorized": interf_num_memorized,
            "num_total": interf_total_items,
        },
        "question_answering": {
            "correct_ratio": _safe_div(qa_correct, qa_valid),
            "hallucination_ratio": _safe_div(qa_hallucination, qa_valid),
            "omission_ratio": _safe_div(qa_omission, qa_valid),
            "num_total": qa_total,
            "num_valid": qa_valid,
        },
        "evidence_coverage": {
            "hit_rate": _safe_div(evidence_covered, evidence_total),
            "num_covered": evidence_covered,
            "num_total": evidence_total,
        },
    }
evaluators/extraction.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified session evaluation: recall + correctness (includes update & interference).
2
+
3
+ Per session, 2 LLM calls — both scoped to THIS SESSION's memory delta only:
4
+ Call 1 — Recall: how many of this session's gold points are covered by the
5
+ session's memory delta (add/update ops)?
6
+ Call 2 — Correctness: is each delta memory correct, hallucinated, or irrelevant?
7
+ (reference = this session's gold points + interference)
8
+
9
+ Aggregate: per-session recall/correctness averaged across sessions.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from eval_framework.judges import (
15
+ evaluate_correctness_batch,
16
+ evaluate_interference_single,
17
+ evaluate_recall_batch,
18
+ evaluate_update_single,
19
+ )
20
+ from eval_framework.pipeline.records import PipelineSessionRecord
21
+
22
+
23
+ def _delta_to_text(session: PipelineSessionRecord) -> str:
24
+ """Only the memories added or updated in THIS session (not the full snapshot)."""
25
+ lines: list[str] = []
26
+ idx = 0
27
+ for d in session.memory_delta:
28
+ if d.op in ("add", "update"):
29
+ idx += 1
30
+ lines.append(f"[{idx}] {d.text}")
31
+ return "\n".join(lines)
32
+
33
+
34
+ def _delta_texts(session: PipelineSessionRecord) -> list[str]:
35
+ """Text list of memories added or updated in THIS session."""
36
+ return [d.text for d in session.memory_delta if d.op in ("add", "update")]
37
+
38
+
39
+ def _build_recall_gold_points(session: PipelineSessionRecord) -> list[str]:
40
+ """Current session's new + update gold points only (NOT cumulative)."""
41
+ out: list[str] = []
42
+ for g in session.gold_state.session_new_memories:
43
+ out.append(f"[normal] {g.memory_content}")
44
+ for g in session.gold_state.session_update_memories:
45
+ out.append(f"[update] {g.memory_content}")
46
+ return out
47
+
48
+
49
+ def _build_correctness_gold_points(session: PipelineSessionRecord) -> list[str]:
50
+ """Current session's new + update + interference gold points as reference."""
51
+ out: list[str] = []
52
+ for g in session.gold_state.session_new_memories:
53
+ out.append(f"[normal] {g.memory_content}")
54
+ for g in session.gold_state.session_update_memories:
55
+ out.append(f"[update] {g.memory_content}")
56
+ for g in session.gold_state.session_interference_memories:
57
+ out.append(f"[interference] {g.memory_content}")
58
+ return out
59
+
60
+
61
def evaluate_extraction(
    session: PipelineSessionRecord,
    **_kwargs: object,
) -> dict[str, object]:
    """Unified session evaluation: recall + correctness in 2 LLM calls.

    Uses only THIS session's new gold points for recall and correctness,
    not the cumulative history. Aggregate averages per-session scores.

    Additional LLM calls are issued per update gold point (update handling)
    and per interference gold point (interference rejection). Score fields
    are ``None`` when the corresponding gold set is empty, so aggregation
    can skip rather than zero them.
    """

    delta_str = _delta_to_text(session)
    delta_texts = _delta_texts(session)
    interference_total = len(session.gold_state.session_interference_memories)

    # --- Call 1: Recall (this session's gold points vs this session's delta) ---
    recall_gold = _build_recall_gold_points(session)

    if not recall_gold:
        # Nothing to recall in this session: scores stay None so the
        # aggregator excludes the session from recall averages.
        recall = None
        update_recall = None
        recall_result: dict[str, object] = {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0,
            "reasoning": "No new gold points in this session.",
        }
    elif not delta_str.strip():
        # Gold points exist but the adapter emitted no add/update delta:
        # recall is definitionally 0, no LLM call needed.
        recall = 0.0
        update_recall = 0.0
        update_total = sum(1 for p in recall_gold if p.startswith("[update]"))
        recall_result = {
            "covered_count": 0, "update_covered_count": 0,
            "total": len(recall_gold), "update_total": update_total,
            "reasoning": "No add/update memories in this session's delta.",
        }
    else:
        recall_result = evaluate_recall_batch(delta_str, recall_gold)

    covered = recall_result.get("covered_count")
    upd_covered = recall_result.get("update_covered_count")
    total_gold = recall_result.get("total", len(recall_gold))
    upd_total = recall_result.get("update_total", 0)

    if recall_gold:
        # Re-derive ratios from the judge's counts; None when the judge did
        # not report a usable count.
        # NOTE(review): in the empty-delta branch above, update_recall is
        # re-assigned here and becomes None whenever upd_total == 0 (no
        # update gold points), overwriting the earlier 0.0 — confirm this
        # overwrite is intended.
        recall = float(covered) / float(total_gold) if covered is not None and total_gold else None
        update_recall = float(upd_covered) / float(upd_total) if upd_covered is not None and upd_total else None

    # --- Call 2: Correctness (this session's delta memories, reference = this session's golds) ---
    correctness_gold = _build_correctness_gold_points(session)
    correctness_result = evaluate_correctness_batch(delta_texts, correctness_gold, interference_total)
    correctness_records = correctness_result.get("results", [])

    num_correct = sum(1 for r in correctness_records if r.get("label") == "correct")
    num_hallucination = sum(1 for r in correctness_records if r.get("label") == "hallucination")
    num_irrelevant = sum(1 for r in correctness_records if r.get("label") == "irrelevant")
    num_memories = len(delta_texts)
    # An empty delta yields correctness_rate 0.0 (not None), unlike recall.
    correctness_rate = float(num_correct) / float(num_memories) if num_memories else 0.0

    # --- Call 3+: Update handling (one LLM call per update gold point) ---
    update_records: list[dict[str, object]] = []
    for g in session.gold_state.session_update_memories:
        res = evaluate_update_single(
            delta_str,
            new_content=g.memory_content,
            old_contents=list(g.original_memories),
        )
        update_records.append({
            "memory_id": g.memory_id,
            "label": res["label"],
            "reasoning": res["reasoning"],
        })

    num_updated = sum(1 for r in update_records if r["label"] == "updated")
    num_both = sum(1 for r in update_records if r["label"] == "both")
    num_outdated = sum(1 for r in update_records if r["label"] == "outdated")
    update_total_items = len(update_records)
    # Score: updated=1.0, both=0.5, outdated=0.0
    update_score = (
        (num_updated * 1.0 + num_both * 0.5) / update_total_items
        if update_total_items else None
    )

    # --- Call 4+: Interference rejection (one LLM call per interference gold point) ---
    interference_records: list[dict[str, object]] = []
    for g in session.gold_state.session_interference_memories:
        res = evaluate_interference_single(
            delta_str,
            interference_content=g.memory_content,
        )
        interference_records.append({
            "memory_id": g.memory_id,
            "label": res["label"],
            "reasoning": res["reasoning"],
        })

    num_rejected = sum(1 for r in interference_records if r["label"] == "rejected")
    num_memorized = sum(1 for r in interference_records if r["label"] == "memorized")
    interference_total_items = len(interference_records)
    # Score: rejected=1.0, memorized=0.0
    interference_score = (
        float(num_rejected) / interference_total_items
        if interference_total_items else None
    )

    # Flat record consumed by evaluators/aggregate.py:aggregate_metrics.
    return {
        "session_id": session.session_id,
        "recall": recall,
        "covered_count": covered,
        "num_gold": total_gold,
        "update_recall": update_recall,
        "update_covered_count": upd_covered,
        "update_total": upd_total,
        "recall_reasoning": recall_result.get("reasoning", ""),
        "correctness_rate": correctness_rate,
        "num_memories": num_memories,
        "num_correct": num_correct,
        "num_hallucination": num_hallucination,
        "num_irrelevant": num_irrelevant,
        "correctness_reasoning": correctness_result.get("reasoning", ""),
        "correctness_records": correctness_records,
        # Update handling
        "update_score": update_score,
        "update_num_updated": num_updated,
        "update_num_both": num_both,
        "update_num_outdated": num_outdated,
        "update_total_items": update_total_items,
        "update_records": update_records,
        # Interference rejection
        "interference_score": interference_score,
        "interference_num_rejected": num_rejected,
        "interference_num_memorized": num_memorized,
        "interference_total_items": interference_total_items,
        "interference_records": interference_records,
    }
evaluators/qa.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Checkpoint QA evaluation: answer quality + batch evidence coverage.
2
+
3
+ Two dimensions:
4
+ 1. Answer evaluation: Correct / Hallucination / Omission (1 LLM call)
5
+ 2. Evidence coverage: how many gold evidence points are covered by the
6
+ memories the model actually *cited* when answering? (1 LLM call)
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from eval_framework.judges import evaluate_evidence_batch, evaluate_qa_llm
12
+ from eval_framework.pipeline.records import PipelineCheckpointQARecord
13
+
14
+
15
def evaluate_checkpoint_qa(
    record: PipelineCheckpointQARecord,
    **_kwargs: object,
) -> dict[str, object]:
    """LLM-judged QA evaluation: answer correctness + evidence coverage.

    Two judge calls per record:
      1. Answer quality: Correct / Hallucination / Omission.
      2. Evidence coverage: fraction of gold evidence points covered by the
         memories the model actually *cited* (falls back to the full
         retrieval for legacy records without ``cited_memories``).

    Returns a flat dict of labels, counts, rates and judge reasoning.
    """
    # Normalize once: legacy records may carry None instead of a list.
    # (Fix: the final len(record.cited_memories) previously raised TypeError
    # for such records.)
    cited = record.cited_memories or []

    # --- Build cited-memories text (what the model actually used) ---
    if cited:
        cited_str = "\n".join(f"[{i + 1}] {m}" for i, m in enumerate(cited))
    else:
        # Fallback: use full retrieval (legacy records without cited_memories)
        cited_lines = [f"[{item.rank}] {item.text}" for item in record.retrieval.items]
        cited_str = "\n".join(cited_lines) if cited_lines else ""

    # --- Answer evaluation (1 LLM call) ---
    gold_evidence_str = (
        "\n".join(record.gold_evidence_contents)
        if record.gold_evidence_contents
        else "No evidence available."
    )
    answer_result = evaluate_qa_llm(
        question=record.question,
        reference_answer=record.gold_answer,
        key_memory_points=gold_evidence_str,
        system_response=record.generated_answer,
    )
    answer_label = answer_result.get("evaluation_result")

    # --- Evidence coverage (1 LLM call, batch) ---
    # Only judged against cited memories, not the full retrieval.
    gold_contents = list(record.gold_evidence_contents)
    evidence_result: dict[str, object] = {
        "covered_count": 0, "total": len(gold_contents), "reasoning": ""
    }
    if gold_contents and cited_str.strip():
        evidence_result = evaluate_evidence_batch(cited_str, gold_contents)

    covered = evidence_result.get("covered_count")
    total_ev = evidence_result.get("total", len(gold_contents))
    # covered is None when the judge call failed; report a 0.0 hit rate then.
    if covered is not None and total_ev:
        evidence_hit_rate = float(covered) / float(total_ev)
    else:
        evidence_hit_rate = 0.0

    return {
        "answer_label": answer_label,
        "answer_reasoning": answer_result.get("reasoning", ""),
        "answer_is_valid": answer_label in ("Correct", "Hallucination", "Omission"),
        "evidence_hit_rate": evidence_hit_rate,
        "evidence_covered_count": covered,
        "num_evidence": total_ev,
        "evidence_reasoning": evidence_result.get("reasoning", ""),
        "num_cited_memories": len(cited),
    }
judges/__init__.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Judge stack: batch LLM evaluation.
2
+
3
+ Session: 2 calls (recall + correctness) + per-item calls for update/interference.
4
+ QA: 2 calls (answer + evidence).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from eval_framework.judges.llm_client import llm_request_for_json
10
+ from eval_framework.judges.prompts import (
11
+ CORRECTNESS_BATCH_PROMPT,
12
+ EVIDENCE_BATCH_PROMPT,
13
+ INTERFERENCE_EVAL_PROMPT,
14
+ QA_EVALUATION_PROMPT,
15
+ RECALL_BATCH_PROMPT,
16
+ UPDATE_EVAL_PROMPT,
17
+ )
18
+
19
+ __all__ = [
20
+ "evaluate_recall_batch",
21
+ "evaluate_correctness_batch",
22
+ "evaluate_update_single",
23
+ "evaluate_interference_single",
24
+ "evaluate_evidence_batch",
25
+ "evaluate_qa_llm",
26
+ "llm_request_for_json",
27
+ ]
28
+
29
+
30
def evaluate_recall_batch(
    extracted_memories_str: str,
    gold_points_tagged: list[str],
) -> dict[str, object]:
    """One LLM call: how many gold points are covered? Distinguishes update sub-score.

    Args:
        extracted_memories_str: newline-joined memories the system stored.
        gold_points_tagged: "[normal] content" or "[update] content" strings.

    Returns:
        {covered_count, update_covered_count, total, update_total, reasoning};
        the two counts are None when the judge call fails (distinct from a
        genuine zero).
    """
    total = len(gold_points_tagged)
    update_total = sum(1 for p in gold_points_tagged if p.startswith("[update]"))

    # Degenerate inputs short-circuit without spending an LLM call.
    if not extracted_memories_str.strip():
        return {
            "covered_count": 0, "update_covered_count": 0,
            "total": total, "update_total": update_total,
            "reasoning": "No extracted memories.",
        }
    if not gold_points_tagged:
        return {
            "covered_count": 0, "update_covered_count": 0,
            "total": 0, "update_total": 0, "reasoning": "No gold points.",
        }

    numbered = "\n".join(f"[{i + 1}] {p}" for i, p in enumerate(gold_points_tagged))
    prompt = RECALL_BATCH_PROMPT.format(memories=extracted_memories_str, gold_points=numbered)
    try:
        result = llm_request_for_json(prompt)
        # Clamp judge output into [0, total]: LLMs occasionally over- or
        # under-report counts (fix: negatives previously passed through).
        covered = min(max(int(result.get("covered_count", 0)), 0), total)
        upd_covered = min(max(int(result.get("update_covered_count", 0)), 0), update_total)
        return {
            "covered_count": covered,
            "update_covered_count": upd_covered,
            "total": total,
            "update_total": update_total,
            "reasoning": result.get("reasoning", ""),
        }
    except Exception as e:
        # None signals "judge failed" so aggregation can skip this session.
        return {
            "covered_count": None, "update_covered_count": None,
            "total": total, "update_total": update_total,
            "reasoning": f"LLM error: {e}",
        }
73
+
74
+
75
def evaluate_correctness_batch(
    snapshot_memories: list[str],
    gold_points_tagged: list[str],
    interference_total: int,
) -> dict[str, object]:
    """One LLM call: label every stored memory, with interference detection.

    gold_points_tagged entries look like "[normal] ...", "[update] ...",
    or "[interference] ...".

    Returns {results: [{id, label}], interference_memorized_count,
    interference_total, reasoning}; the memorized count is None on judge
    failure.
    """
    # Nothing stored -> nothing to judge; skip the LLM call entirely.
    if not snapshot_memories:
        return {
            "results": [],
            "interference_memorized_count": 0,
            "interference_total": interference_total,
            "reasoning": "No snapshot memories.",
        }

    memory_block = "\n".join(
        f"[{idx + 1}] {mem}" for idx, mem in enumerate(snapshot_memories)
    )
    if gold_points_tagged:
        gold_block = "\n".join(f"- {point}" for point in gold_points_tagged)
    else:
        gold_block = "(no ground-truth)"

    prompt = CORRECTNESS_BATCH_PROMPT.format(memories=memory_block, gold_points=gold_block)
    try:
        payload = llm_request_for_json(prompt)
        allowed = {"correct", "hallucination", "irrelevant"}
        labeled: list[dict[str, object]] = []
        for entry in payload.get("results", []):
            tag = str(entry.get("label", "irrelevant")).lower().strip()
            # Any off-vocabulary label collapses to "irrelevant".
            labeled.append({"id": entry.get("id"), "label": tag if tag in allowed else "irrelevant"})
        memorized = int(payload.get("interference_memorized_count", 0))
        return {
            "results": labeled,
            # Judge can over-count; never exceed the known interference total.
            "interference_memorized_count": min(memorized, interference_total),
            "interference_total": interference_total,
            "reasoning": payload.get("reasoning", ""),
        }
    except Exception as exc:
        return {
            "results": [],
            "interference_memorized_count": None,
            "interference_total": interference_total,
            "reasoning": f"LLM error: {exc}",
        }
121
+
122
+
123
def evaluate_update_single(
    delta_memories_str: str,
    new_content: str,
    old_contents: list[str],
) -> dict[str, object]:
    """One LLM call: classify how the system handled a single memory update.

    Returns {label: "updated"|"both"|"outdated"|None, reasoning}; label is
    None when the judge call fails.
    """
    if old_contents:
        old_block = "\n".join(f"- {item}" for item in old_contents)
    else:
        old_block = "(none)"
    prompt = UPDATE_EVAL_PROMPT.format(
        memories=delta_memories_str,
        new_content=new_content,
        old_contents=old_block,
    )
    try:
        verdict = llm_request_for_json(prompt)
    except Exception as exc:
        return {"label": None, "reasoning": f"LLM error: {exc}"}
    label = str(verdict.get("label", "outdated")).lower().strip()
    # Off-vocabulary answers are treated as the worst case (missed update).
    if label not in ("updated", "both", "outdated"):
        label = "outdated"
    return {"label": label, "reasoning": verdict.get("reasoning", "")}
146
+
147
+
148
def evaluate_interference_single(
    delta_memories_str: str,
    interference_content: str,
) -> dict[str, object]:
    """One LLM call: did the system wrongly store an interference point?

    Returns {label: "rejected"|"memorized"|None, reasoning}; label is None
    when the judge call fails.
    """
    prompt = INTERFERENCE_EVAL_PROMPT.format(
        memories=delta_memories_str,
        interference_content=interference_content,
    )
    try:
        verdict = llm_request_for_json(prompt)
    except Exception as exc:
        return {"label": None, "reasoning": f"LLM error: {exc}"}
    label = str(verdict.get("label", "memorized")).lower().strip()
    # Off-vocabulary answers default to the failure case.
    if label not in ("rejected", "memorized"):
        label = "memorized"
    return {"label": label, "reasoning": verdict.get("reasoning", "")}
168
+
169
+
170
def evaluate_evidence_batch(
    retrieved_memories_str: str,
    evidence_points: list[str],
) -> dict[str, object]:
    """One LLM call: count gold evidence points covered by the retrieval.

    Returns {covered_count, total, reasoning}; covered_count is None when
    the judge call fails.
    """
    total = len(evidence_points)
    # Degenerate inputs short-circuit without an LLM call.
    if not retrieved_memories_str.strip():
        return {"covered_count": 0, "total": total, "reasoning": "No retrieved memories."}
    if not evidence_points:
        return {"covered_count": 0, "total": 0, "reasoning": "No evidence points."}

    numbered = "\n".join(f"[{i + 1}] {p}" for i, p in enumerate(evidence_points))
    prompt = EVIDENCE_BATCH_PROMPT.format(
        retrieved_memories=retrieved_memories_str,
        gold_evidence_points=numbered,
    )
    try:
        payload = llm_request_for_json(prompt)
        # Judge can over-count; never report more than the known total.
        covered = min(int(payload.get("covered_count", 0)), total)
        return {
            "covered_count": covered,
            "total": total,
            "reasoning": payload.get("reasoning", ""),
        }
    except Exception as exc:
        return {"covered_count": None, "total": total, "reasoning": f"LLM error: {exc}"}
192
+
193
+
194
def evaluate_qa_llm(
    question: str,
    reference_answer: str,
    key_memory_points: str,
    system_response: str,
) -> dict[str, object]:
    """LLM judge: classify a QA response as Correct/Hallucination/Omission.

    Returns {evaluation_result, reasoning}; evaluation_result is None when
    the judge call fails.
    """
    # A blank response is an Omission by definition — no judge call needed.
    if not system_response.strip():
        return {"evaluation_result": "Omission", "reasoning": "Empty system response."}

    prompt = QA_EVALUATION_PROMPT.format(
        question=question,
        reference_answer=reference_answer,
        key_memory_points=key_memory_points,
        response=system_response,
    )
    try:
        verdict = llm_request_for_json(prompt)
    except Exception as exc:
        return {"evaluation_result": None, "reasoning": f"LLM judge error: {exc}"}
    label = verdict.get("evaluation_result", "Omission")
    # Off-vocabulary answers are counted as Omission, never as Correct.
    if label not in ("Correct", "Hallucination", "Omission"):
        label = "Omission"
    return {"evaluation_result": label, "reasoning": verdict.get("reasoning", "")}
judges/llm_client.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenAI LLM client for judge calls with retry logic and concurrency control."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import re
8
+ import logging
9
+ import threading
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ from dotenv import load_dotenv
14
+ from tenacity import retry, stop_after_attempt, wait_random_exponential, before_sleep_log
15
+
16
+ # Load .env from project root (walk up from this file to find it)
17
+ _PROJECT_ROOT = Path(__file__).resolve().parents[2]
18
+ load_dotenv(_PROJECT_ROOT / ".env")
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ _JSON_FENCE_RE = re.compile(r"```(?:json)?\s*\n?(.*?)\n?\s*```", re.DOTALL)
23
+ _JSON_FENCE_OPEN_RE = re.compile(r"```(?:json)?\s*\n?(.*)", re.DOTALL)
24
+
25
+ _client: Any = None
26
+ _client_lock = threading.Lock()
27
+ _semaphore: threading.Semaphore | None = None
28
+
29
+
30
+ def _build_client() -> Any:
31
+ from openai import OpenAI
32
+ return OpenAI(
33
+ api_key=os.getenv("OPENAI_API_KEY", ""),
34
+ base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"),
35
+ )
36
+
37
+
38
+ def _get_client() -> Any:
39
+ global _client
40
+ if _client is None:
41
+ with _client_lock:
42
+ if _client is None:
43
+ _client = _build_client()
44
+ return _client
45
+
46
+
47
+ def _get_semaphore() -> threading.Semaphore:
48
+ global _semaphore
49
+ if _semaphore is None:
50
+ max_concurrent = int(os.getenv("LLM_MAX_CONCURRENT", "5"))
51
+ _semaphore = threading.Semaphore(max_concurrent)
52
+ return _semaphore
53
+
54
+
55
+ def _common_params() -> dict[str, Any]:
56
+ params: dict[str, Any] = {}
57
+ model = os.getenv("OPENAI_MODEL") or ""
58
+ max_tok = os.getenv("OPENAI_MAX_TOKENS")
59
+ if max_tok:
60
+ if model.startswith("gpt-5") or model.startswith("o"):
61
+ params["max_completion_tokens"] = int(max_tok)
62
+ else:
63
+ params["max_tokens"] = int(max_tok)
64
+ params["temperature"] = float(os.getenv("JUDGE_TEMPERATURE", "0.0"))
65
+ if os.getenv("OPENAI_TIMEOUT"):
66
+ params["timeout"] = int(os.getenv("OPENAI_TIMEOUT"))
67
+ return params
68
+
69
+
70
@retry(
    wait=wait_random_exponential(min=2, max=60),
    stop=stop_after_attempt(8),
    reraise=True,
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def llm_request_for_json(prompt: str) -> dict[str, Any]:
    """Send a prompt to the LLM and parse the JSON block from the response.

    Respects the global concurrency limit (LLM_MAX_CONCURRENT env var,
    default 5). Retries with exponential backoff up to 8 attempts.

    Raises:
        ValueError: when no JSON object can be recovered from the output
            (this also triggers a retry via tenacity).
    """
    # Hold the semaphore only for the network call, not for JSON parsing.
    # Idiom fix: context manager replaces manual acquire/try/finally release.
    with _get_semaphore():
        client = _get_client()
        model = os.getenv("OPENAI_MODEL") or "gpt-4o"
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            **_common_params(),
        )
        content = response.choices[0].message.content or ""

    parsed = _extract_json(content)
    if parsed is not None:
        return parsed

    raise ValueError(f"No JSON block found in model output: {content[:500]}")
101
+
102
+
103
def _extract_json(content: str) -> dict[str, Any] | None:
    """Extract a JSON object from model output, repairing truncation if needed.

    Tries, in order: a properly closed ```json fence, an unterminated fence
    (generation cut off mid-stream), and finally raw unfenced JSON.
    """
    # 1. Properly closed fence: ```json ... ```
    for fence in _JSON_FENCE_RE.finditer(content):
        body = fence.group(1).strip()
        if not body.startswith("{"):
            continue
        try:
            return json.loads(body)
        except json.JSONDecodeError:
            continue

    # 2. Fence opened but never closed (output truncated before ```).
    open_fence = _JSON_FENCE_OPEN_RE.search(content)
    if open_fence:
        body = open_fence.group(1).strip()
        if body.startswith("{"):
            fixed = _repair_truncated_json(body)
            if fixed is not None:
                return fixed

    # 3. Bare JSON with no fences at all.
    bare = content.strip()
    if bare.startswith("{"):
        try:
            return json.loads(bare)
        except json.JSONDecodeError:
            fixed = _repair_truncated_json(bare)
            if fixed is not None:
                return fixed

    return None
134
+
135
+
136
def _repair_truncated_json(text: str) -> dict[str, Any] | None:
    """Best-effort repair of truncated JSON by closing open brackets/braces.

    NOTE(review): bracket counting ignores braces inside string values, so a
    payload containing "{" in a string may be mis-repaired — acceptable for a
    best-effort fallback that returns None on failure.
    """
    # Strip whatever partial token the truncation left behind.
    text = re.sub(r',\s*"[^"]*$', "", text)   # dangling half-written key
    text = re.sub(r',\s*\{[^}]*$', "", text)  # dangling half-written object
    text = re.sub(r',\s*$', "", text)         # dangling comma

    # Close whatever is still open: arrays first, then objects.
    missing_brackets = max(text.count("[") - text.count("]"), 0)
    missing_braces = max(text.count("{") - text.count("}"), 0)
    suffix = "]" * missing_brackets + "}" * missing_braces

    try:
        obj = json.loads(text + suffix)
    except json.JSONDecodeError:
        return None
    if isinstance(obj, dict):
        logger.warning("Repaired truncated JSON (appended %r)", suffix)
        return obj
    return None
judges/prompts.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LLM judge prompt templates — batch mode.
2
+
3
+ Each session: 2 LLM calls (recall + correctness).
4
+ Each QA question: 2 LLM calls (answer + evidence).
5
+ """
6
+
7
+ # ---------------------------------------------------------------------------
8
+ # Session Call 1: Recall (Gold -> Snapshot)
9
+ # ---------------------------------------------------------------------------
10
+ RECALL_BATCH_PROMPT = """You are a **Memory Recall Evaluator**.
11
+ Determine how many of the **Expected Memory Points** are covered by the system's **Extracted Memories**.
12
+
13
+ # Inputs
14
+
15
+ 1. **Extracted Memories** (what the system actually stored):
16
+ {memories}
17
+
18
+ 2. **Expected Memory Points** (numbered, each tagged [normal] or [update]):
19
+ {gold_points}
20
+
21
+ # Instructions
22
+
23
+ - Go through each Expected Memory Point and check whether the Extracted Memories contain information that covers it.
24
+ - Semantic matching is acceptable; exact wording is NOT required.
25
+ - Count **total** covered points AND separately count how many **[update]** tagged points are covered.
26
+
27
+ # Output
28
+
29
+ ```json
30
+ {{
31
+ "covered_count": <int>,
32
+ "update_covered_count": <int>,
33
+ "total": <int>,
34
+ "update_total": <int>,
35
+ "reasoning": "Brief summary"
36
+ }}
37
+ ```
38
+ """
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Session Call 2: Correctness (Snapshot -> Gold)
42
+ # ---------------------------------------------------------------------------
43
+ CORRECTNESS_BATCH_PROMPT = """You are a **Memory Correctness Evaluator**.
44
+ Evaluate whether each memory stored by the system is factually correct.
45
+
46
+ # Inputs
47
+
48
+ 1. **System Memories** (numbered, what the system actually stored):
49
+ {memories}
50
+
51
+ 2. **Ground-Truth Reference Points** (tagged [normal], [update], or [interference]):
52
+ {gold_points}
53
+
54
+ # Instructions
55
+
56
+ For **each** System Memory, classify it as one of:
57
+ - **correct**: The memory is factually accurate and consistent with the [normal] or [update] ground-truth points.
58
+ - **hallucination**: The memory contains fabricated or incorrect information, OR it contains content from [interference] points (information that should NOT have been memorized).
59
+ - **irrelevant**: The memory is not wrong per se, but is trivial filler or not related to any ground-truth point.
60
+
61
+ **IMPORTANT**: If a System Memory matches or contains information from an [interference] tagged point, it MUST be classified as **hallucination**, because the system should have ignored that information.
62
+
63
+ Also count how many [interference] ground-truth points appear in the System Memories.
64
+
65
+ # Output
66
+
67
+ ```json
68
+ {{
69
+ "results": [
70
+ {{"id": 1, "label": "correct|hallucination|irrelevant"}},
71
+ {{"id": 2, "label": "correct|hallucination|irrelevant"}}
72
+ ],
73
+ "interference_memorized_count": <int>,
74
+ "interference_total": <int>,
75
+ "reasoning": "Brief justification"
76
+ }}
77
+ ```
78
+ """
79
+
80
+ # ---------------------------------------------------------------------------
81
+ # Session: Update handling (per update gold point)
82
+ # ---------------------------------------------------------------------------
83
+ UPDATE_EVAL_PROMPT = """You are a **Memory Update Evaluator**.
84
+ Determine how a memory system handled an information update.
85
+
86
+ # Inputs
87
+
88
+ 1. **System Memories** (what the system currently stores after this session):
89
+ {memories}
90
+
91
+ 2. **Updated Fact** (the NEW correct information):
92
+ {new_content}
93
+
94
+ 3. **Outdated Fact(s)** (the OLD information that should have been replaced):
95
+ {old_contents}
96
+
97
+ # Instructions
98
+
99
+ Check the System Memories and classify the update handling as one of:
100
+
101
+ - **updated**: The system stores ONLY the new/updated information. The outdated fact is no longer present. This is the ideal outcome.
102
+ - **both**: The system stores BOTH the new and the old information. The update was partially handled — the new fact was added but the old was not removed.
103
+ - **outdated**: The system stores ONLY the old/outdated information. The update was missed entirely — the new fact is absent.
104
+
105
+ Use semantic matching, not exact wording.
106
+
107
+ # Output
108
+
109
+ ```json
110
+ {{
111
+ "label": "updated|both|outdated",
112
+ "reasoning": "Brief justification"
113
+ }}
114
+ ```
115
+ """
116
+
117
+ # ---------------------------------------------------------------------------
118
+ # Session: Interference rejection (per interference gold point)
119
+ # ---------------------------------------------------------------------------
120
+ INTERFERENCE_EVAL_PROMPT = """You are a **Memory Interference Evaluator**.
121
+ Determine whether a memory system incorrectly stored information that should have been ignored.
122
+
123
+ # Inputs
124
+
125
+ 1. **System Memories** (what the system currently stores after this session):
126
+ {memories}
127
+
128
+ 2. **Interference Content** (information that should NOT have been memorized):
129
+ {interference_content}
130
+
131
+ # Instructions
132
+
133
+ Check whether the System Memories contain the interference content (or its semantic equivalent). Classify as:
134
+
135
+ - **rejected**: The interference content is NOT present in the system memories. The system correctly ignored it.
136
+ - **memorized**: The interference content IS present (or semantically equivalent) in the system memories. The system incorrectly stored it.
137
+
138
+ Use semantic matching, not exact wording.
139
+
140
+ # Output
141
+
142
+ ```json
143
+ {{
144
+ "label": "rejected|memorized",
145
+ "reasoning": "Brief justification"
146
+ }}
147
+ ```
148
+ """
149
+
150
+ # ---------------------------------------------------------------------------
151
+ # QA Evidence Coverage (batch per question)
152
+ # ---------------------------------------------------------------------------
153
+ EVIDENCE_BATCH_PROMPT = """You are an **Evidence Retrieval Evaluator**.
154
+ Determine how many of the **Gold Evidence Points** are covered by the system's **Retrieved Memories** when answering a question.
155
+
156
+ # Inputs
157
+
158
+ 1. **Retrieved Memories** (what the system retrieved to answer the question):
159
+ {retrieved_memories}
160
+
161
+ 2. **Gold Evidence Points** (key facts needed for the correct answer, numbered):
162
+ {gold_evidence_points}
163
+
164
+ # Instructions
165
+
166
+ Go through each Gold Evidence Point and check whether the Retrieved Memories contain information that covers it. Semantic matching is acceptable; exact wording is NOT required.
167
+
168
+ Count how many Gold Evidence Points are **fully covered or logically implied** by the Retrieved Memories.
169
+
170
+ # Output
171
+
172
+ ```json
173
+ {{
174
+ "covered_count": <int>,
175
+ "total": <int>,
176
+ "reasoning": "Brief summary"
177
+ }}
178
+ ```
179
+ """
180
+
181
+ # ---------------------------------------------------------------------------
182
+ # QA Answer evaluation
183
+ # ---------------------------------------------------------------------------
184
+ QA_EVALUATION_PROMPT = """You are an **evaluation expert for AI memory system question answering**.
185
+ Based **only** on the provided **"Question"**, **"Reference Answer"**, and **"Key Memory Points"**, strictly evaluate the **accuracy** of the **"Memory System Response."** Classify it as one of **"Correct"**, **"Hallucination"**, or **"Omission."** Do **not** use any external knowledge or subjective inference.
186
+
187
+ # Evaluation Criteria
188
+
189
+ ### 1. Correct
190
+ * The response accurately answers the question and is **semantically equivalent** to the Reference Answer.
191
+ * No contradictions with Key Memory Points or Reference Answer.
192
+ * Synonyms, paraphrasing, and reasonable summarization are acceptable.
193
+
194
+ ### 2. Hallucination
195
+ * The response includes information that **contradicts** the Reference Answer or Key Memory Points.
196
+ * When the Reference Answer is *unknown/uncertain*, yet the response provides a specific fact.
197
+
198
+ ### 3. Omission
199
+ * The response is **incomplete** compared to the Reference Answer.
200
+ * It states "don't know" or "no related memory" even though relevant information exists.
201
+ * For multi-element questions, missing **any** element counts as Omission.
202
+
203
+ ## Priority Rules
204
+ * Both missing info AND fabricated info -> **Hallucination**.
205
+ * No fabrication but missing info -> **Omission**.
206
+ * Fully equivalent -> **Correct**.
207
+
208
+ # Information
209
+
210
+ * **Question:** {question}
211
+ * **Reference Answer:** {reference_answer}
212
+ * **Key Memory Points:** {key_memory_points}
213
+ * **Memory System Response:** {response}
214
+
215
+ # Output
216
+
217
+ ```json
218
+ {{
219
+ "reasoning": "Concise evaluation rationale",
220
+ "evaluation_result": "Correct | Hallucination | Omission"
221
+ }}
222
+ ```
223
+ """
memory_adapters/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Memory system adapters for the eval framework."""
2
+
3
+ from eval_framework.memory_adapters.base import MemoryAdapter
4
+ from eval_framework.memory_adapters.memgallery_native import (
5
+ MemGalleryNativeAdapter,
6
+ instantiate_memgallery_memory,
7
+ )
8
+ from eval_framework.memory_adapters.registry import (
9
+ EXTERNAL_ADAPTER_KEYS,
10
+ EXTERNAL_ADAPTER_REGISTRY,
11
+ MEMGALLERY_NATIVE_BASELINES,
12
+ MEMGALLERY_NATIVE_REGISTRY,
13
+ create_external_adapter,
14
+ create_memgallery_adapter,
15
+ )
16
+
17
+ __all__ = [
18
+ "EXTERNAL_ADAPTER_KEYS",
19
+ "EXTERNAL_ADAPTER_REGISTRY",
20
+ "MEMGALLERY_NATIVE_BASELINES",
21
+ "MEMGALLERY_NATIVE_REGISTRY",
22
+ "MemoryAdapter",
23
+ "MemGalleryNativeAdapter",
24
+ "create_external_adapter",
25
+ "create_memgallery_adapter",
26
+ "instantiate_memgallery_memory",
27
+ ]
memory_adapters/amem.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapter for the external A-Mem baseline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib
6
+ import os
7
+ import sys
8
+ from pathlib import Path
9
+ from typing import Any, Callable
10
+
11
+ from eval_framework.datasets.schemas import (
12
+ MemoryDeltaRecord,
13
+ MemorySnapshotRecord,
14
+ NormalizedTurn,
15
+ RetrievalItem,
16
+ RetrievalRecord,
17
+ )
18
+ from eval_framework.memory_adapters.base import MemoryAdapter
19
+
20
+ _BACKEND_ID = "A-Mem"
21
+
22
+ INTEGRATION_ERROR = (
23
+ f"{_BACKEND_ID} backend unavailable."
24
+ )
25
+
26
+
27
+ class AMemAdapter(MemoryAdapter):
28
+ """Thin wrapper around A-Mem's robust memory system."""
29
+
30
    def __init__(
        self,
        *,
        backend: Any | None = None,
        backend_factory: Callable[[], Any] | None = None,
        source_root: str | os.PathLike[str] | None = None,
        model_name: str = "all-MiniLM-L6-v2",
        llm_backend: str = "openai",
        llm_model: str | None = None,
        api_key: str | None = None,
        api_base: str | None = None,
        sglang_host: str = "http://localhost",
        sglang_port: int = 30000,
    ) -> None:
        """Wire up the A-Mem backend, deferring failures to first use.

        Args:
            backend: pre-built backend instance (skips the factory entirely).
            backend_factory: zero-arg callable producing a backend; built
                lazily from the remaining args when not supplied.
            source_root: A-Mem checkout location; defaults to a path derived
                from this file's location.
            model_name: sentence-embedding model name passed to A-Mem.
            llm_backend: LLM provider key; llm_model falls back to the
                OPENAI_MODEL env var, then "gpt-5.1".
            api_key / api_base: OpenAI credentials; env fallbacks are applied
                inside the factory.
            sglang_host / sglang_port: sglang endpoint — presumably only used
                when llm_backend selects sglang (TODO confirm in A-Mem).

        Construction errors are captured in self._integration_error instead
        of raising, so the adapter can report them lazily via _runtime_error.
        """
        self._source_root = Path(source_root).resolve() if source_root else self._default_source_root()
        resolved_llm_model = llm_model or os.getenv("OPENAI_MODEL") or "gpt-5.1"
        self._backend: Any | None = None
        self._backend_factory = backend_factory
        self._integration_error: str | None = None  # why the backend failed, if it did
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()
        self._note_session_map: dict[str, str] = {}  # note_id -> session_id

        if backend is not None:
            self._backend = backend
        else:
            try:
                if self._backend_factory is None:
                    self._backend_factory = self._build_backend_factory(
                        model_name=model_name,
                        llm_backend=llm_backend,
                        llm_model=resolved_llm_model,
                        api_key=api_key,
                        api_base=api_base,
                        sglang_host=sglang_host,
                        sglang_port=sglang_port,
                    )
                self._backend = self._backend_factory()
            except Exception as exc:
                # Defer: record the failure; raise only when the adapter is used.
                self._integration_error = str(exc)
70
+
71
+ @staticmethod
72
+ def _default_source_root() -> Path:
73
+ here = Path(__file__).resolve()
74
+ # memory_adapters/ -> eval_framework/ -> our/ -> Benchmark/
75
+ return (here.parents[2].parent / "data_pipline" / "A-mem").resolve()
76
+
77
    def _build_backend_factory(
        self,
        *,
        model_name: str,
        llm_backend: str,
        llm_model: str,
        api_key: str | None,
        api_base: str | None,
        sglang_host: str,
        sglang_port: int,
    ) -> Callable[[], Any]:
        """Return a zero-arg factory building a RobustAgenticMemorySystem.

        Imports A-Mem's `memory_layer_robust` module from the configured
        source root (prepended to sys.path once).

        Raises:
            RuntimeError: when the A-Mem source tree is missing.
        """
        if not self._source_root.is_dir():
            raise RuntimeError(
                f"{_BACKEND_ID}: source root not found at {self._source_root}"
            )
        # Make the external A-Mem checkout importable; only prepend once.
        src = str(self._source_root)
        if src not in sys.path:
            sys.path.insert(0, src)
        mod = importlib.import_module("memory_layer_robust")
        backend_cls = getattr(mod, "RobustAgenticMemorySystem")
        # Return a factory (not an instance) so reset() can rebuild fresh
        # backends; credentials fall back to env vars at construction time.
        return lambda: backend_cls(
            model_name=model_name,
            llm_backend=llm_backend,
            llm_model=llm_model,
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            api_base=api_base or os.getenv("OPENAI_BASE_URL"),
            sglang_host=sglang_host,
            sglang_port=sglang_port,
        )
106
+
107
+ def _runtime_error(self) -> RuntimeError:
108
+ detail = self._integration_error or INTEGRATION_ERROR
109
+ return RuntimeError(
110
+ f"{_BACKEND_ID}: backend unavailable — {detail}"
111
+ )
112
+
113
+ def reset(self) -> None:
114
+ if self._backend_factory is None and self._backend is None:
115
+ raise self._runtime_error()
116
+ if self._backend_factory is not None:
117
+ self._backend = self._backend_factory()
118
+ self._prev_snapshot_ids = set()
119
+ self._note_session_map = {}
120
+ self._session_id = ""
121
+
122
+ def ingest_turn(self, turn: NormalizedTurn) -> None:
123
+ backend = self._require_backend()
124
+ self._session_id = turn.session_id
125
+ text = self._turn_text(turn)
126
+ note_id = backend.add_note(text, time=turn.timestamp)
127
+ self._note_session_map[str(note_id)] = turn.session_id
128
+
129
+ def end_session(self, session_id: str) -> None:
130
+ self._require_backend()
131
+ self._session_id = session_id
132
+
133
+ def snapshot_memories(self) -> list[MemorySnapshotRecord]:
134
+ backend = self._require_backend()
135
+ rows: list[MemorySnapshotRecord] = []
136
+ for note_id, note in getattr(backend, "memories", {}).items():
137
+ sid = self._note_session_map.get(str(note_id), self._session_id)
138
+ content = str(getattr(note, "content", ""))
139
+ context = getattr(note, "context", "")
140
+ keywords = list(getattr(note, "keywords", []) or [])
141
+ tags = list(getattr(note, "tags", []) or [])
142
+ # Include A-Mem enrichments in the snapshot text so that the
143
+ # eval captures what the system actually processed, not just
144
+ # the raw input.
145
+ enriched_parts = [content]
146
+ if context:
147
+ enriched_parts.append(f"[context] {context}")
148
+ if keywords:
149
+ enriched_parts.append(f"[keywords] {', '.join(keywords)}")
150
+ if tags:
151
+ enriched_parts.append(f"[tags] {', '.join(tags)}")
152
+ rows.append(
153
+ MemorySnapshotRecord(
154
+ memory_id=str(getattr(note, "id", note_id)),
155
+ text="\n".join(enriched_parts),
156
+ session_id=sid,
157
+ status="active",
158
+ source=_BACKEND_ID,
159
+ raw_backend_id=str(getattr(note, "id", note_id)),
160
+ raw_backend_type="a_mem_note",
161
+ metadata={
162
+ "timestamp": getattr(note, "timestamp", None),
163
+ "context": context,
164
+ "keywords": keywords,
165
+ "tags": tags,
166
+ "links": list(getattr(note, "links", []) or []),
167
+ },
168
+ )
169
+ )
170
+ return rows
171
+
172
+ def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
173
+ """Export delta by diffing current snapshot against previous snapshot."""
174
+ self._require_backend()
175
+ current_snapshot = self.snapshot_memories()
176
+ deltas: list[MemoryDeltaRecord] = []
177
+ current_ids: set[str] = set()
178
+
179
+ for snap in current_snapshot:
180
+ current_ids.add(snap.memory_id)
181
+ if snap.memory_id not in self._prev_snapshot_ids:
182
+ deltas.append(
183
+ MemoryDeltaRecord(
184
+ session_id=session_id,
185
+ op="add",
186
+ text=snap.text,
187
+ linked_previous=(),
188
+ raw_backend_id=snap.raw_backend_id,
189
+ metadata={
190
+ "baseline": _BACKEND_ID,
191
+ "backend_type": snap.raw_backend_type,
192
+ },
193
+ )
194
+ )
195
+
196
+ self._prev_snapshot_ids = current_ids
197
+ return deltas
198
+
199
+ def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
200
+ backend = self._require_backend()
201
+ items: list[RetrievalItem] = []
202
+ memories = list(getattr(backend, "memories", {}).values())
203
+ retriever = getattr(backend, "retriever", None)
204
+ if retriever is not None and hasattr(retriever, "search"):
205
+ for rank, idx in enumerate(retriever.search(query, top_k)):
206
+ if 0 <= int(idx) < len(memories):
207
+ note = memories[int(idx)]
208
+ items.append(
209
+ RetrievalItem(
210
+ rank=rank,
211
+ memory_id=str(getattr(note, "id", idx)),
212
+ text=str(getattr(note, "content", "")),
213
+ score=1.0 / float(rank + 1),
214
+ raw_backend_id=str(getattr(note, "id", idx)),
215
+ )
216
+ )
217
+ if not items and hasattr(backend, "find_related_memories_raw"):
218
+ raw = backend.find_related_memories_raw(query, k=top_k)
219
+ if raw:
220
+ items.append(
221
+ RetrievalItem(
222
+ rank=0,
223
+ memory_id="a_mem:bundle",
224
+ text=str(raw),
225
+ score=1.0,
226
+ raw_backend_id=None,
227
+ )
228
+ )
229
+ return RetrievalRecord(
230
+ query=query,
231
+ top_k=top_k,
232
+ items=items[:top_k],
233
+ raw_trace={"baseline": _BACKEND_ID},
234
+ )
235
+
236
+ def get_capabilities(self) -> dict[str, Any]:
237
+ available = self._backend is not None or self._backend_factory is not None
238
+ return {
239
+ "backend": _BACKEND_ID,
240
+ "baseline": _BACKEND_ID,
241
+ "available": available and self._integration_error is None,
242
+ "integration_status": "integrated" if available and self._integration_error is None else "unavailable",
243
+ "integration_error": self._integration_error or INTEGRATION_ERROR,
244
+ "delta_granularity": "ingest_turn_only",
245
+ "snapshot_mode": "full_store",
246
+ }
247
+
248
+ def _require_backend(self) -> Any:
249
+ if self._backend is None:
250
+ raise self._runtime_error()
251
+ return self._backend
252
+
253
+ @staticmethod
254
+ def _turn_text(turn: NormalizedTurn) -> str:
255
+ parts = [f"{turn.role}: {turn.text}"]
256
+ for att in turn.attachments:
257
+ parts.append(f"[{att.type}] {att.caption}")
258
+ return "\n".join(parts)
memory_adapters/amem_v2.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapter for A-Mem (new API: agentic_memory.AgenticMemorySystem)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv(Path(__file__).resolve().parents[2] / ".env")
13
+
14
+ from eval_framework.datasets.schemas import (
15
+ MemoryDeltaRecord,
16
+ MemorySnapshotRecord,
17
+ NormalizedTurn,
18
+ RetrievalItem,
19
+ RetrievalRecord,
20
+ )
21
+ from eval_framework.memory_adapters.base import MemoryAdapter
22
+
23
+ _DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/A-Mem")
24
+
25
+
26
class AMemV2Adapter(MemoryAdapter):
    """Adapter for A-Mem (new agentic_memory API)."""

    def __init__(
        self,
        *,
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,
    ) -> None:
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        # Make the vendored A-Mem package importable.
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))

        from agentic_memory.memory_system import AgenticMemorySystem

        self._cls = AgenticMemorySystem
        self._backend: Any = None
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()
        self._init_backend()

    def _init_backend(self) -> None:
        """Construct a fresh AgenticMemorySystem with env-based credentials."""
        self._backend = self._cls(
            model_name="all-MiniLM-L6-v2",
            llm_backend="openai",
            llm_model=os.getenv("OPENAI_MODEL") or "gpt-4o",
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    def reset(self) -> None:
        """Drop all stored notes by rebuilding the backend."""
        self._init_backend()
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one turn (text plus attachment captions) as a note."""
        self._session_id = turn.session_id
        pieces = [f"{turn.role}: {turn.text}"]
        pieces.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)
        self._backend.add_note("\n".join(pieces), time=turn.timestamp)

    def end_session(self, session_id: str) -> None:
        """Record the session boundary; nothing to flush."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return every note, folding in context/keyword enrichments."""
        records: list[MemorySnapshotRecord] = []
        for note_id, note in self._backend.memories.items():
            lines = [str(getattr(note, "content", ""))]
            ctx = getattr(note, "context", "")
            if ctx:
                lines.append(f"[context] {ctx}")
            kws = list(getattr(note, "keywords", []) or [])
            if kws:
                lines.append(f"[keywords] {', '.join(kws)}")
            records.append(
                MemorySnapshotRecord(
                    memory_id=str(note_id),
                    text="\n".join(lines),
                    session_id=self._session_id,
                    status="active",
                    source="A-Mem",
                    raw_backend_id=str(note_id),
                    raw_backend_type="a_mem_note",
                    metadata={},
                )
            )
        return records

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Diff the current snapshot against the previous call's ids."""
        snapshot = self.snapshot_memories()
        previous = self._prev_snapshot_ids
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=snap.text,
                linked_previous=(), raw_backend_id=snap.raw_backend_id,
                metadata={"baseline": "A-Mem"},
            )
            for snap in snapshot
            if snap.memory_id not in previous
        ]
        self._prev_snapshot_ids = {snap.memory_id for snap in snapshot}
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search the backend; fall back to the raw bundle API on failure."""
        hits: list[RetrievalItem] = []
        try:
            for rank, hit in enumerate(self._backend.search(query, k=top_k)[:top_k]):
                if isinstance(hit, dict):
                    text = hit.get("content", str(hit))
                    mem_id = str(hit.get("id", str(rank)))
                    score = float(hit.get("score", 1.0 / (rank + 1)))
                else:
                    text = str(hit)
                    mem_id = str(rank)
                    score = 1.0 / (rank + 1)
                hits.append(RetrievalItem(
                    rank=rank, memory_id=mem_id, text=text,
                    score=score, raw_backend_id=mem_id,
                ))
        except Exception:
            # Fallback to raw search
            try:
                bundle = self._backend.find_related_memories_raw(query, k=top_k)
                if bundle:
                    hits.append(RetrievalItem(
                        rank=0, memory_id="bundle", text=str(bundle),
                        score=1.0, raw_backend_id=None,
                    ))
            except Exception:
                pass

        return RetrievalRecord(
            query=query, top_k=top_k, items=hits[:top_k],
            raw_trace={"baseline": "A-Mem"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe adapter behavior limits."""
        return {
            "backend": "A-Mem",
            "baseline": "A-Mem",
            "available": self._backend is not None,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
+ }
memory_adapters/base.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Abstract memory adapter API for eval baselines."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any
7
+
8
+ from eval_framework.datasets.schemas import (
9
+ MemoryDeltaRecord,
10
+ MemorySnapshotRecord,
11
+ NormalizedTurn,
12
+ RetrievalRecord,
13
+ )
14
+
15
+
16
class MemoryAdapter(ABC):
    """Abstract interface every baseline memory backend implements.

    The eval pipeline drives adapters exclusively through this surface,
    keeping concrete backends interchangeable.
    """

    @abstractmethod
    def reset(self) -> None:
        """Drop all backend state plus any adapter-side bookkeeping."""

    @abstractmethod
    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Push a single conversation turn into the memory system."""

    @abstractmethod
    def end_session(self, session_id: str) -> None:
        """Signal a session boundary (a no-op for many backends)."""

    @abstractmethod
    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Produce a normalized view of all memories visible in the backend."""

    @abstractmethod
    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Report memory changes for *session_id* since the previous call."""

    @abstractmethod
    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Execute retrieval for *query* and normalize the results."""

    @abstractmethod
    def get_capabilities(self) -> dict[str, Any]:
        """Describe adapter limits (delta granularity, snapshots, backend id)."""
memory_adapters/dummy.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dummy memory adapter for end-to-end pipeline testing.
2
+
3
+ Stores all ingested turns as raw text and retrieves by simple substring match.
4
+ No external dependencies required.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any
10
+
11
+ from eval_framework.datasets.schemas import (
12
+ MemoryDeltaRecord,
13
+ MemorySnapshotRecord,
14
+ NormalizedTurn,
15
+ RetrievalItem,
16
+ RetrievalRecord,
17
+ )
18
+ from eval_framework.memory_adapters.base import MemoryAdapter
19
+
20
+
21
class DummyAdapter(MemoryAdapter):
    """Minimal adapter that stores turns verbatim — for pipeline testing."""

    def __init__(self) -> None:
        self._memories: list[dict[str, str]] = []
        self._session_id = ""
        self._prev_ids: set[str] = set()

    def reset(self) -> None:
        """Forget everything."""
        self._memories = []
        self._session_id = ""
        self._prev_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Append the turn text (plus attachment captions) verbatim."""
        self._session_id = turn.session_id
        pieces = [f"{turn.role}: {turn.text}"]
        pieces.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)
        self._memories.append({
            "id": str(len(self._memories)),
            "text": "\n".join(pieces),
            "session_id": turn.session_id,
        })

    def end_session(self, session_id: str) -> None:
        """Record the boundary; nothing to flush."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Every stored turn becomes one active snapshot row."""
        return [
            MemorySnapshotRecord(
                memory_id=entry["id"],
                text=entry["text"],
                session_id=entry["session_id"],
                status="active",
                source="Dummy",
                raw_backend_id=entry["id"],
                raw_backend_type="dummy",
                metadata={},
            )
            for entry in self._memories
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Emit an 'add' delta for every memory created since the last call."""
        current_ids = {entry["id"] for entry in self._memories}
        fresh = current_ids - self._prev_ids
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=entry["text"],
                linked_previous=(),
                raw_backend_id=entry["id"],
                metadata={"baseline": "Dummy"},
            )
            for entry in self._memories
            if entry["id"] in fresh
        ]
        self._prev_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Rank memories by simple word overlap with the query (no embeddings)."""
        query_words = set(query.lower().split())
        ranked = sorted(
            (
                (len(query_words & set(entry["text"].lower().split())), entry)
                for entry in self._memories
            ),
            key=lambda pair: pair[0],
            reverse=True,
        )
        denom = max(len(query.split()), 1)
        items = [
            RetrievalItem(
                rank=position,
                memory_id=entry["id"],
                text=entry["text"],
                score=float(overlap) / denom,
                raw_backend_id=entry["id"],
            )
            for position, (overlap, entry) in enumerate(ranked[:top_k])
        ]
        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items,
            raw_trace={"baseline": "Dummy"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe adapter behavior limits."""
        return {
            "backend": "Dummy",
            "baseline": "Dummy",
            "available": True,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }
memory_adapters/export_utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helpers to map turns, backend memory dicts, and recall outputs into shared schemas."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Mapping
6
+
7
+ from eval_framework.datasets.schemas import (
8
+ MemorySnapshotRecord,
9
+ NormalizedTurn,
10
+ RetrievalItem,
11
+ RetrievalRecord,
12
+ )
13
+
14
+
15
def turn_to_observation_dict(turn: NormalizedTurn) -> dict[str, Any]:
    """Build a Mem-Gallery store observation from a normalized turn."""
    pieces: list[str] = [turn.text]
    pieces.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)
    observation: dict[str, Any] = {"text": "\n".join(pieces)}
    if turn.timestamp:
        observation["timestamp"] = turn.timestamp
    observation["dialogue_id"] = f"{turn.session_id}:{turn.turn_index}"
    return observation
26
+
27
+
28
def memory_element_text(element: Mapping[str, Any]) -> str:
    """Best-effort text extraction from a Mem-Gallery memory dict.

    Joins a list-valued ``text`` field with spaces, treats ``None`` as
    empty, and appends the image caption (if any) as an ``[image]`` line.

    Fix: a list-valued ``text`` previously returned early and silently
    dropped the image caption; the caption is now appended in every case.
    """
    raw = element.get("text", "")
    if isinstance(raw, list):
        base = " ".join(str(x) for x in raw)
    elif raw is None:
        base = ""
    else:
        base = str(raw)
    image = element.get("image")
    if isinstance(image, dict):
        cap = image.get("caption")
        if cap:
            # strip() removes the leading newline when base is empty.
            base = f"{base}\n[image] {cap}".strip()
    return base
43
+
44
+
45
def linear_element_to_snapshot(
    element: Mapping[str, Any],
    *,
    memory_id: str,
    session_id: str,
    source: str,
    status: str = "active",
) -> MemorySnapshotRecord:
    """Map a linear-storage memory dict into MemorySnapshotRecord."""
    counter_id = element.get("counter_id")
    # Prefer the backend's counter id when present; otherwise reuse ours.
    backend_id = memory_id if counter_id is None else str(counter_id)
    return MemorySnapshotRecord(
        memory_id=memory_id,
        text=memory_element_text(element),
        session_id=session_id,
        status=status,
        source=source,
        raw_backend_id=backend_id,
        raw_backend_type="linear",
        metadata={},
    )
66
+
67
+
68
def normalize_recall_to_retrieval(
    query: str,
    top_k: int,
    raw: Any,
    *,
    raw_trace: dict[str, Any] | None = None,
) -> RetrievalRecord:
    """Normalize Mem-Gallery recall outputs into RetrievalRecord."""
    trace = dict(raw_trace or {})
    items: list[RetrievalItem] = []

    if isinstance(raw, str):
        # A single pre-rendered text bundle.
        items.append(RetrievalItem(
            rank=0,
            memory_id="memgallery:string_bundle",
            text=raw,
            score=1.0,
            raw_backend_id=None,
        ))
    elif isinstance(raw, list):
        for i, row in enumerate(raw[: max(0, top_k)]):
            if isinstance(row, dict):
                mid = row.get("counter_id")
                items.append(RetrievalItem(
                    rank=i,
                    memory_id=str(i if mid is None else mid),
                    text=memory_element_text(row),
                    score=float(row.get("score", 1.0)),
                    raw_backend_id=None if mid is None else str(mid),
                ))
            else:
                # Non-dict rows are stringified verbatim.
                items.append(RetrievalItem(
                    rank=i,
                    memory_id=str(i),
                    text=str(row),
                    score=1.0,
                    raw_backend_id=None,
                ))
    else:
        # Unknown object shape -> stringify as one bundle.
        items.append(RetrievalItem(
            rank=0,
            memory_id="memgallery:object_bundle",
            text=str(raw),
            score=1.0,
            raw_backend_id=None,
        ))

    return RetrievalRecord(query=query, top_k=top_k, items=items[:top_k], raw_trace=trace)
memory_adapters/mem0_adapter.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapters for Mem0 and Mem0-Graph baselines."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import uuid as _uuid
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from dotenv import load_dotenv
11
+
12
+ load_dotenv(Path(__file__).resolve().parents[2] / ".env")
13
+
14
+ from eval_framework.datasets.schemas import (
15
+ MemoryDeltaRecord,
16
+ MemorySnapshotRecord,
17
+ NormalizedTurn,
18
+ RetrievalItem,
19
+ RetrievalRecord,
20
+ )
21
+ from eval_framework.memory_adapters.base import MemoryAdapter
22
+
23
+
24
class Mem0Adapter(MemoryAdapter):
    """Adapter for Mem0 (vector mode)."""

    def __init__(self, *, use_graph: bool = False, **kwargs: Any) -> None:
        from mem0 import Memory

        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()

        api_key = os.getenv("OPENAI_API_KEY") or ""
        config: dict[str, Any] = {
            "llm": {
                "provider": "openai",
                "config": {
                    "model": os.getenv("OPENAI_MODEL") or "gpt-4o",
                    "api_key": api_key,
                },
            },
            "embedder": {
                "provider": "openai",
                "config": {
                    "model": "text-embedding-3-small",
                    "api_key": api_key,
                    "embedding_dims": 1536,
                },
            },
        }

        base_url = os.getenv("OPENAI_BASE_URL")
        if base_url:
            # Both LLM and embedder share the custom endpoint.
            for section in ("llm", "embedder"):
                config[section]["config"]["openai_base_url"] = base_url

        if use_graph:
            config["graph_store"] = {
                "provider": "kuzu",
                "config": {"url": "/tmp/mem0_kuzu_eval"},
            }

        self._memory = Memory.from_config(config)
        self._use_graph = use_graph

    def reset(self) -> None:
        """Wipe the current user's memories and continue under a fresh user id."""
        self._memory.delete_all(user_id=self._user_id)
        self._user_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Send one (truncated) turn to Mem0; graph failures are swallowed."""
        self._session_id = turn.session_id
        pieces = [f"{turn.role}: {turn.text}"]
        pieces.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)
        # Truncate to avoid excessively long inputs that break graph entity extraction
        payload = "\n".join(pieces)[:2000]
        try:
            self._memory.add(
                messages=[{"role": turn.role, "content": payload}],
                user_id=self._user_id,
            )
        except Exception:
            # Graph mode can fail on entity embedding; fall back silently
            pass

    def end_session(self, session_id: str) -> None:
        """Record the boundary; Mem0 needs no explicit flush."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Collect vector memories plus (in graph mode) relation triples."""
        payload = self._memory.get_all(user_id=self._user_id)
        rows: list[MemorySnapshotRecord] = []

        # Vector results (standard mode)
        vector_rows = payload.get("results", []) if isinstance(payload, dict) else payload
        for mem in vector_rows:
            mid = str(mem.get("id", ""))
            rows.append(MemorySnapshotRecord(
                memory_id=mid, text=str(mem.get("memory", "")),
                session_id=self._session_id, status="active",
                source="Mem0", raw_backend_id=mid,
                raw_backend_type="mem0_vector", metadata={},
            ))

        # Graph relations (graph mode)
        relations = payload.get("relations", []) if isinstance(payload, dict) else []
        for i, rel in enumerate(relations):
            if not isinstance(rel, dict):
                continue
            triple = "{} → {} → {}".format(
                rel.get("source", ""),
                rel.get("relationship", ""),
                rel.get("target") or rel.get("destination", ""),
            )
            rel_id = f"rel_{i}"
            rows.append(MemorySnapshotRecord(
                memory_id=rel_id, text=triple,
                session_id=self._session_id, status="active",
                source="Mem0-Graph", raw_backend_id=rel_id,
                raw_backend_type="mem0_graph_relation", metadata=rel,
            ))

        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Diff snapshot ids against the previous call; new ids become adds."""
        snapshot = self.snapshot_memories()
        previous = self._prev_snapshot_ids
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=snap.text,
                linked_previous=(),
                raw_backend_id=snap.raw_backend_id,
                metadata={"baseline": "Mem0"},
            )
            for snap in snapshot
            if snap.memory_id not in previous
        ]
        self._prev_snapshot_ids = {snap.memory_id for snap in snapshot}
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search Mem0: vector hits first, then graph relations up to top_k."""
        payload = self._memory.search(query=query, user_id=self._user_id, limit=top_k)
        items: list[RetrievalItem] = []

        # Vector results
        vector_hits = payload.get("results", []) if isinstance(payload, dict) else payload
        for i, hit in enumerate(vector_hits[:top_k]):
            items.append(RetrievalItem(
                rank=len(items),
                memory_id=str(hit.get("id", i)),
                text=str(hit.get("memory", "")),
                score=float(hit.get("score", 1.0 / (i + 1))),
                raw_backend_id=str(hit.get("id", "")),
            ))

        # Graph relations
        relations = payload.get("relations", []) if isinstance(payload, dict) else []
        for rel in relations:
            if not isinstance(rel, dict) or len(items) >= top_k:
                continue
            triple = "{} → {} → {}".format(
                rel.get("source", ""),
                rel.get("relationship", ""),
                rel.get("target") or rel.get("destination", ""),
            )
            items.append(RetrievalItem(
                rank=len(items),
                memory_id=f"rel_{len(items)}",
                text=triple,
                score=0.9,
                raw_backend_id=None,
            ))

        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:top_k],
            raw_trace={"baseline": "Mem0-Graph" if self._use_graph else "Mem0"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe adapter behavior limits."""
        name = "Mem0-Graph" if self._use_graph else "Mem0"
        return {
            "backend": name,
            "baseline": name,
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
memory_adapters/memgallery_native.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Mem-Gallery native baseline wrappers with conservative schema normalization."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import copy
6
+ from typing import Any, Callable
7
+
8
+ from eval_framework.datasets.schemas import (
9
+ MemoryDeltaRecord,
10
+ MemorySnapshotRecord,
11
+ NormalizedTurn,
12
+ RetrievalRecord,
13
+ )
14
+ from eval_framework.memory_adapters.base import MemoryAdapter
15
+ from eval_framework.memory_adapters.export_utils import (
16
+ linear_element_to_snapshot,
17
+ memory_element_text,
18
+ normalize_recall_to_retrieval,
19
+ turn_to_observation_dict,
20
+ )
21
+
22
+
23
+ def _deep_merge_dict(base: dict[str, Any], overrides: dict[str, Any]) -> dict[str, Any]:
24
+ out = copy.deepcopy(base)
25
+ for key, val in overrides.items():
26
+ if (
27
+ key in out
28
+ and isinstance(out[key], dict)
29
+ and isinstance(val, dict)
30
+ ):
31
+ out[key] = _deep_merge_dict(out[key], val)
32
+ else:
33
+ out[key] = copy.deepcopy(val)
34
+ return out
35
+
36
+
37
def _default_config_for_baseline(name: str) -> dict[str, Any]:
    """Load Mem-Gallery's default config dict for the named baseline.

    Raises KeyError for unknown baseline names.
    """
    import default_config.DefaultMemoryConfig as dmc  # type: ignore[import-not-found]

    attr_by_name = {
        "FUMemory": "DEFAULT_FUMEMORY",
        "STMemory": "DEFAULT_STMEMORY",
        "LTMemory": "DEFAULT_LTMEMORY",
        "GAMemory": "DEFAULT_GAMEMORY",
        "MGMemory": "DEFAULT_MGMEMORY",
        "RFMemory": "DEFAULT_RFMEMORY",
        "MMMemory": "DEFAULT_MMMEMORY",
        "MMFUMemory": "DEFAULT_MMFUMEMORY",
        "NGMemory": "DEFAULT_NGMEMORY",
        "AUGUSTUSMemory": "DEFAULT_AUGUSTUSMEMORY",
        "UniversalRAGMemory": "DEFAULT_UNIVERSALRAGMEMORY",
    }
    # Deep copy so callers can freely mutate their config.
    return copy.deepcopy(getattr(dmc, attr_by_name[name]))
55
+
56
+
57
def _import_memory_class(name: str) -> Callable[..., Any]:
    """Resolve a Mem-Gallery memory class by baseline name.

    Every supported class lives at ``memengine.memory.<Name>.<Name>``;
    unknown names raise KeyError.
    """
    modmap = {
        key: (f"memengine.memory.{key}", key)
        for key in (
            "FUMemory", "STMemory", "LTMemory", "GAMemory", "MGMemory",
            "RFMemory", "MMMemory", "MMFUMemory", "NGMemory",
            "AUGUSTUSMemory", "UniversalRAGMemory",
        )
    }
    module_path, cls_name = modmap[name]
    import importlib

    return getattr(importlib.import_module(module_path), cls_name)
76
+
77
+
78
def instantiate_memgallery_memory(
    baseline_name: str,
    config: dict[str, Any] | None = None,
) -> Any:
    """Construct a Mem-Gallery memory object with optional config overrides."""
    merged = _deep_merge_dict(
        _default_config_for_baseline(baseline_name), config or {}
    )
    from memengine.config.Config import MemoryConfig  # type: ignore[import-not-found]

    memory_cls = _import_memory_class(baseline_name)
    return memory_cls(MemoryConfig(merged))
89
+
90
+
91
def _graph_nodes_to_snapshots(
    storage: Any,
    *,
    session_id: str,
    source: str,
    include_concepts: bool = False,
) -> list[MemorySnapshotRecord]:
    """Convert graph-storage nodes into snapshot records (insertion order)."""
    records: list[MemorySnapshotRecord] = []
    order = getattr(storage, "memory_order_map", []) or []
    node_concepts = getattr(storage, "node_concepts", {})
    for position, node_id in enumerate(order):
        node = storage.node[node_id]
        counter = node.get("counter_id", position)
        text = memory_element_text(node)
        # For AUGUSTUS: append concept tags extracted by the system
        if include_concepts:
            tags = node_concepts.get(node_id, set())
            if tags:
                text = f"{text}\n[concepts] {', '.join(sorted(tags))}"
        records.append(
            MemorySnapshotRecord(
                memory_id=f"n{node_id}",
                text=text,
                session_id=session_id,
                status="active",
                source=source,
                raw_backend_id=str(counter),
                raw_backend_type="graph_node",
                metadata={"node_id": node_id},
            )
        )
    return records
124
+
125
+
126
def _linear_storage_snapshots(
    storage: Any,
    *,
    session_id: str,
    source: str,
) -> list[MemorySnapshotRecord]:
    """Snapshot every element of a linear ``memory_list`` storage."""
    return [
        linear_element_to_snapshot(
            element,
            # Fall back to the list position when no counter id was recorded.
            memory_id=str(element.get("counter_id", index)),
            session_id=session_id,
            source=source,
        )
        for index, element in enumerate(storage.memory_list)
    ]
144
+
145
+
146
def collect_memgallery_snapshots(
    memory: Any,
    baseline_name: str,
    session_id: str,
) -> list[MemorySnapshotRecord]:
    """Best-effort snapshot of backend-visible memories.

    Dispatches on *baseline_name* because each Mem-Gallery memory class
    exposes stored items through a different attribute layout; the final
    ``hasattr`` branch covers the common linear-storage baselines, and an
    unrecognized layout yields an empty list rather than raising.
    """
    source = baseline_name
    if baseline_name == "MGMemory":
        out: list[MemorySnapshotRecord] = []
        # store_op/recall_op have their own main_context references;
        # prefer store_op's view as it holds the actual stored data.
        mc = getattr(memory.store_op, "main_context", None) or memory.main_context
        # recall/archival storages may hang off recall_op or the memory itself.
        recall_storage = getattr(memory.recall_op, "recall_storage",
                                 getattr(memory, "recall_storage", None))
        archival_storage = getattr(memory.recall_op, "archival_storage",
                                   getattr(memory, "archival_storage", None))
        # Prefixes keep memory_ids unique across the four storage tiers.
        storages = [("wm", mc["working_context"]), ("fifo", mc["FIFO_queue"])]
        if recall_storage is not None:
            storages.append(("recall", recall_storage))
        if archival_storage is not None:
            storages.append(("archival", archival_storage))
        for prefix, st in storages:
            for i, m in enumerate(st.memory_list):
                cid = m.get("counter_id", i)
                mid = f"{prefix}-{cid}"
                # NOTE(review): despite the plural name, this is a single record.
                rows = linear_element_to_snapshot(
                    m,
                    memory_id=mid,
                    session_id=session_id,
                    source=source,
                )
                out.append(rows)
        # The global recursive summary can hold the literal string "None";
        # skip both falsy values and that sentinel.
        gsum = mc.get("recursive_summary", {}).get("global")
        if gsum and str(gsum) != "None":
            out.append(
                MemorySnapshotRecord(
                    memory_id="recursive_summary",
                    text=str(gsum),
                    session_id=session_id,
                    status="active",
                    source=source,
                    raw_backend_id=None,
                    raw_backend_type="mg_summary",
                    metadata={},
                )
            )
        return out

    if baseline_name == "RFMemory":
        # Linear storage plus one synthetic record for the global insight.
        rows = _linear_storage_snapshots(
            memory.storage, session_id=session_id, source=source
        )
        insight = getattr(memory, "insight", {}).get("global_insight", "")
        if insight:
            rows.append(
                MemorySnapshotRecord(
                    memory_id="rf_insight",
                    text=str(insight),
                    session_id=session_id,
                    status="active",
                    source=source,
                    raw_backend_id=None,
                    raw_backend_type="rf_insight",
                    metadata={},
                )
            )
        return rows

    if baseline_name == "NGMemory":
        # Graph-backed storage; nodes are walked in insertion order.
        return _graph_nodes_to_snapshots(
            memory.storage, session_id=session_id, source=source
        )

    if baseline_name == "AUGUSTUSMemory":
        # AUGUSTUS stores nodes in `contextual_memory` and tags concepts.
        return _graph_nodes_to_snapshots(
            memory.contextual_memory, session_id=session_id, source=source,
            include_concepts=True,
        )

    if baseline_name == "UniversalRAGMemory":
        return _linear_storage_snapshots(
            memory.storage, session_id=session_id, source=source
        )

    # Generic fallback for any baseline exposing a linear memory_list.
    if hasattr(memory, "storage") and hasattr(memory.storage, "memory_list"):
        return _linear_storage_snapshots(
            memory.storage, session_id=session_id, source=source
        )

    return []
236
+
237
+
238
class MemGalleryNativeAdapter(MemoryAdapter):
    """Thin wrapper that forwards to Mem-Gallery memories and normalizes I/O."""

    def __init__(self, memory: Any, *, baseline_name: str) -> None:
        self._memory = memory
        self._baseline_name = baseline_name
        self._session_id = ""
        # memory_ids present at the previous export; used to diff deltas.
        self._prev_snapshot_ids: set[str] = set()
        # Last user turn waiting to be paired with the next assistant turn.
        self._pending_user_turn: NormalizedTurn | None = None
        self._session_turns: list[str] = []  # collect turn texts for RF optimize

    @classmethod
    def from_baseline(
        cls,
        baseline_name: str,
        *,
        config: dict[str, Any] | None = None,
    ) -> MemGalleryNativeAdapter:
        """Instantiate the named Mem-Gallery baseline and wrap it in an adapter."""
        mem = instantiate_memgallery_memory(baseline_name, config)
        return cls(mem, baseline_name=baseline_name)

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Buffer user turns; store merged user+assistant pair on assistant turn.

        This matches the original Mem-Gallery benchmark behavior where each
        dialogue round (user + assistant) is merged into a single observation
        before calling store().
        """
        self._session_id = turn.session_id
        if turn.role == "user":
            # Flush any prior unpaired user turn, then buffer this one
            if self._pending_user_turn is not None:
                self._store_observation(self._pending_user_turn, assistant_turn=None)
            self._pending_user_turn = turn
        else:
            # Assistant turn: merge with buffered user turn and store
            self._store_observation(self._pending_user_turn, assistant_turn=turn)
            self._pending_user_turn = None

    def _store_observation(
        self,
        user_turn: NormalizedTurn | None,
        assistant_turn: NormalizedTurn | None,
    ) -> None:
        """Build a merged observation dict (matching original benchmark format) and store."""
        parts: list[str] = []
        timestamp = None
        dialogue_id = ""
        if user_turn is not None:
            parts.append(f"user: {user_turn.text}")
            for att in user_turn.attachments:
                parts.append(f"[{att.type}] {att.caption}")
            timestamp = user_turn.timestamp
            dialogue_id = f"{user_turn.session_id}:{user_turn.turn_index}"
        if assistant_turn is not None:
            parts.append(f"assistant: {assistant_turn.text}")
            for att in assistant_turn.attachments:
                parts.append(f"[{att.type}] {att.caption}")
            # User-turn metadata wins when both sides are present.
            if timestamp is None:
                timestamp = assistant_turn.timestamp
            if not dialogue_id:
                dialogue_id = f"{assistant_turn.session_id}:{assistant_turn.turn_index}"

        obs: dict[str, Any] = {"text": "\n".join(parts)}
        if timestamp:
            obs["timestamp"] = timestamp
        obs["dialogue_id"] = dialogue_id
        self._memory.store(obs)
        self._session_turns.append(obs["text"])

    def end_session(self, session_id: str) -> None:
        """Flush buffered input and run baseline-specific post-session hooks."""
        # Flush any remaining unpaired user turn
        if self._pending_user_turn is not None:
            self._store_observation(self._pending_user_turn, assistant_turn=None)
            self._pending_user_turn = None

        # --- Trigger backend-specific post-session processing ---
        # GAMemory: self-reflection generates insights and stores them
        if self._baseline_name == "GAMemory":
            try:
                self._memory.manage("reflect")
            except Exception:
                pass  # reflection may fail if accumulated importance < threshold

        # RFMemory: optimize generates a global insight from the session trial
        if self._baseline_name == "RFMemory" and self._session_turns:
            try:
                trial = "\n".join(self._session_turns)
                self._memory.optimize(new_trial=trial)
            except Exception:
                pass

        self._session_turns = []

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Snapshot whatever backend-visible storage the wrapped baseline exposes."""
        sid = self._session_id or ""
        return collect_memgallery_snapshots(
            self._memory, self._baseline_name, sid
        )

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current backend snapshot against previous snapshot.

        This reflects what the backend ACTUALLY stores, not what was fed in.
        For FU/ST/LT/GA/RF (LinearStorage), this is the raw observations added.
        For MGMemory, this includes FIFO items, summaries, and archival entries.
        """
        current_snapshot = self.snapshot_memories()
        prev_ids = self._prev_snapshot_ids
        deltas: list[MemoryDeltaRecord] = []
        current_ids: set[str] = set()

        for snap in current_snapshot:
            current_ids.add(snap.memory_id)
            if snap.memory_id not in prev_ids:
                deltas.append(
                    MemoryDeltaRecord(
                        session_id=session_id,
                        op="add",
                        text=snap.text,
                        linked_previous=(),
                        raw_backend_id=snap.raw_backend_id,
                        metadata={
                            "baseline": self._baseline_name,
                            "source": snap.source,
                            "backend_type": snap.raw_backend_type,
                        },
                    )
                )

        self._prev_snapshot_ids = current_ids
        return deltas

    def reset(self) -> None:
        """Clear backend memory and all adapter-side buffers."""
        self._memory.reset()
        self._prev_snapshot_ids = set()
        self._pending_user_turn = None
        self._session_turns = []

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Recall from the backend and normalize the result to a RetrievalRecord."""
        raw = self._memory.recall(query)
        trace: dict[str, Any] = {"baseline": self._baseline_name}
        ro = getattr(self._memory, "recall_op", None)
        if ro is not None and hasattr(ro, "last_retrieved_ids"):
            trace["last_retrieved_ids"] = list(ro.last_retrieved_ids)
        return normalize_recall_to_retrieval(query, top_k, raw, raw_trace=trace)

    def get_capabilities(self) -> dict[str, Any]:
        # FIX: previously advertised delta_granularity="ingest_turn_only" with
        # notes claiming deltas record adapter ingest only, but
        # export_memory_delta() actually diffs successive backend snapshots.
        # Report the real behavior so downstream reporting is not misled.
        return {
            "backend": "MemGallery",
            "baseline": self._baseline_name,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "conservative",
            "notes": (
                "Deltas are computed by diffing backend-visible snapshots between "
                "export calls; backend-internal rewrite, reflection, or graph "
                "reshaping that edits or removes items is not emitted as deltas. "
                "Snapshots read observable storage where supported."
            ),
        }
memory_adapters/memoryos.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapter for the external MemoryOS baseline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import importlib
6
+ import os
7
+ import shutil
8
+ import sys
9
+ import tempfile
10
+ from pathlib import Path
11
+ from typing import Any, Callable
12
+
13
+ from eval_framework.datasets.schemas import (
14
+ MemoryDeltaRecord,
15
+ MemorySnapshotRecord,
16
+ NormalizedTurn,
17
+ RetrievalItem,
18
+ RetrievalRecord,
19
+ )
20
+ from eval_framework.memory_adapters.base import MemoryAdapter
21
+
22
# Identifier stamped onto snapshots, deltas, traces, and error messages.
_BACKEND_ID = "MemoryOS"

# Generic fallback message used when no specific integration error was captured.
INTEGRATION_ERROR = (
    f"{_BACKEND_ID} backend unavailable."
)
27
+
28
+
29
class MemoryOSAdapter(MemoryAdapter):
    """Thin wrapper around MemoryOS's local Python API.

    The backend is built lazily through a factory; any import or construction
    failure is captured into ``_integration_error`` instead of raising here,
    and surfaced later via ``_require_backend`` / ``get_capabilities``.
    """

    def __init__(
        self,
        *,
        backend: Any | None = None,
        backend_factory: Callable[[], Any] | None = None,
        source_root: str | os.PathLike[str] | None = None,
        storage_root: str | os.PathLike[str] | None = None,
        user_id: str = "eval_user",
        assistant_id: str = "eval_assistant",
        llm_model: str | None = None,
        embedding_model_name: str = "all-MiniLM-L6-v2",
        openai_api_key: str | None = None,
        openai_base_url: str | None = None,
    ) -> None:
        self._source_root = Path(source_root).resolve() if source_root else self._default_source_root()
        # Without an explicit storage_root, each adapter gets a throwaway temp dir.
        self._storage_root = Path(storage_root).resolve() if storage_root else Path(
            tempfile.mkdtemp(prefix="memoryos_eval_")
        )
        self._user_id = user_id
        self._assistant_id = assistant_id
        self._llm_model = llm_model or os.getenv("OPENAI_MODEL") or "gpt-5.1"
        self._embedding_model_name = embedding_model_name
        self._openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        self._openai_base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
        self._backend_factory = backend_factory
        self._backend: Any | None = None
        self._integration_error: str | None = None
        self._session_id = ""
        # memory_ids seen at the previous export; used to diff deltas.
        self._prev_snapshot_ids: set[str] = set()
        # User turns buffered until the next assistant turn completes the pair.
        self._pending_user_turns: list[NormalizedTurn] = []

        if backend is not None:
            # Caller supplied a pre-built backend (e.g. for tests); use as-is.
            self._backend = backend
        else:
            try:
                if self._backend_factory is None:
                    self._backend_factory = self._build_backend_factory()
                self._backend = self._backend_factory()
            except Exception as exc:
                # Defer the failure: operations will raise via _require_backend().
                self._integration_error = str(exc)

    @staticmethod
    def _default_source_root() -> Path:
        """Locate the vendored MemoryOS sources relative to this file."""
        here = Path(__file__).resolve()
        # memory_adapters/ -> eval_framework/ -> nips26/ -> baselines/MemoryOS/memoryos-pypi
        return (here.parents[2] / "baselines" / "MemoryOS" / "memoryos-pypi").resolve()

    def _build_backend_factory(self) -> Callable[[], Any]:
        """Import the vendored ``memoryos`` package and return a zero-arg builder.

        Raises:
            RuntimeError: if the source tree is missing.
        """
        if not self._source_root.is_dir():
            raise RuntimeError(
                f"{_BACKEND_ID}: source root not found at {self._source_root}"
            )
        src = str(self._source_root)
        if src not in sys.path:
            sys.path.insert(0, src)
        mod = importlib.import_module("memoryos")
        backend_cls = getattr(mod, "Memoryos")

        def _factory() -> Any:
            # Fresh runtime directory each build so reset() starts from scratch.
            run_root = self._storage_root / "runtime"
            shutil.rmtree(run_root, ignore_errors=True)
            run_root.mkdir(parents=True, exist_ok=True)
            return backend_cls(
                user_id=self._user_id,
                openai_api_key=self._openai_api_key or "",
                openai_base_url=self._openai_base_url,
                data_storage_path=str(run_root),
                llm_model=self._llm_model,
                assistant_id=self._assistant_id,
                embedding_model_name=self._embedding_model_name,
            )

        return _factory

    def _runtime_error(self) -> RuntimeError:
        """Build the error raised whenever the backend is unavailable."""
        detail = self._integration_error or INTEGRATION_ERROR
        return RuntimeError(
            f"{_BACKEND_ID}: backend unavailable — {detail}"
        )

    def reset(self) -> None:
        """Rebuild the backend from scratch and clear adapter-side state."""
        if self._backend_factory is None and self._backend is None:
            raise self._runtime_error()
        if self._backend_factory is not None:
            # Factory wipes the runtime dir, so this is a full reset. An
            # injected backend (no factory) is kept as-is.
            self._backend = self._backend_factory()
        self._prev_snapshot_ids = set()
        self._pending_user_turns = []
        self._session_id = ""

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Buffer user turns; store a (users, assistant) pair on assistant turn."""
        self._require_backend()
        self._session_id = turn.session_id
        if turn.role == "assistant":
            self._store_pair(turn)
        else:
            self._pending_user_turns.append(turn)

    def end_session(self, session_id: str) -> None:
        """Flush any trailing user turns as a pair with an empty assistant reply."""
        self._require_backend()
        self._session_id = session_id
        if self._pending_user_turns:
            # Only the last buffered turn's timestamp is used for the flush.
            synthetic = self._pending_user_turns[-1]
            self._store_memory(
                session_id=session_id,
                user_input=self._joined_user_text(),
                agent_response="",
                timestamp=synthetic.timestamp,
            )
            self._pending_user_turns = []

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Read short-term, mid-term, and long-term stores into snapshot records."""
        backend = self._require_backend()
        rows: list[MemorySnapshotRecord] = []
        sid = self._session_id

        # Short-term memory: recent QA pairs, addressed by position.
        for idx, qa in enumerate(backend.short_term_memory.get_all()):
            rows.append(
                MemorySnapshotRecord(
                    memory_id=f"st:{idx}",
                    text=self._format_qa_text(qa),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id=f"st:{idx}",
                    raw_backend_type="short_term",
                    metadata={"timestamp": qa.get("timestamp")},
                )
            )

        # Mid-term memory: pages grouped into MemoryOS-internal sessions.
        for internal_session_id, session in getattr(backend.mid_term_memory, "sessions", {}).items():
            for page_idx, page in enumerate(session.get("details", [])):
                rows.append(
                    MemorySnapshotRecord(
                        memory_id=f"mt:{internal_session_id}:{page_idx}",
                        text=self._format_qa_text(page),
                        session_id=sid,
                        status="active",
                        source=_BACKEND_ID,
                        raw_backend_id=str(page.get("page_id", f"{internal_session_id}:{page_idx}")),
                        raw_backend_type="mid_term_page",
                        metadata={"memoryos_session_id": internal_session_id},
                    )
                )

        # Long-term user profile; backend may store the literal string "None".
        user_profile = backend.user_long_term_memory.get_raw_user_profile(backend.user_id)
        if user_profile and str(user_profile).lower() != "none":
            rows.append(
                MemorySnapshotRecord(
                    memory_id="lt:user_profile",
                    text=str(user_profile),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id="user_profile",
                    raw_backend_type="user_profile",
                    metadata={},
                )
            )

        # Long-term user knowledge entries.
        for idx, item in enumerate(backend.user_long_term_memory.get_user_knowledge()):
            rows.append(
                MemorySnapshotRecord(
                    memory_id=f"lt:user:{idx}",
                    text=str(item.get("knowledge", "")),
                    session_id=sid,
                    status="active",
                    source=_BACKEND_ID,
                    raw_backend_id=f"user:{idx}",
                    raw_backend_type="user_knowledge",
                    metadata={"timestamp": item.get("timestamp")},
                )
            )

        # Assistant knowledge is optional across MemoryOS versions — probe for it.
        assistant_ltm = getattr(backend, "assistant_long_term_memory", None)
        if assistant_ltm is not None and hasattr(assistant_ltm, "get_assistant_knowledge"):
            for idx, item in enumerate(assistant_ltm.get_assistant_knowledge()):
                rows.append(
                    MemorySnapshotRecord(
                        memory_id=f"lt:assistant:{idx}",
                        text=str(item.get("knowledge", "")),
                        session_id=sid,
                        status="active",
                        source=_BACKEND_ID,
                        raw_backend_id=f"assistant:{idx}",
                        raw_backend_type="assistant_knowledge",
                        metadata={"timestamp": item.get("timestamp")},
                    )
                )
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current snapshot against previous snapshot."""
        self._require_backend()
        current_snapshot = self.snapshot_memories()
        deltas: list[MemoryDeltaRecord] = []
        current_ids: set[str] = set()

        for snap in current_snapshot:
            current_ids.add(snap.memory_id)
            if snap.memory_id not in self._prev_snapshot_ids:
                deltas.append(
                    MemoryDeltaRecord(
                        session_id=session_id,
                        op="add",
                        text=snap.text,
                        linked_previous=(),
                        raw_backend_id=snap.raw_backend_id,
                        metadata={
                            "baseline": _BACKEND_ID,
                            "backend_type": snap.raw_backend_type,
                        },
                    )
                )

        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Query the MemoryOS retriever; ranks follow concatenation order.

        MemoryOS does not expose similarity scores here, so a synthetic
        1/(rank+1) score is assigned to preserve ordering.
        """
        backend = self._require_backend()
        raw = backend.retriever.retrieve_context(query, user_id=backend.user_id)
        items: list[RetrievalItem] = []

        for page in raw.get("retrieved_pages", []):
            items.append(
                RetrievalItem(
                    rank=len(items),
                    memory_id=f"page:{len(items)}",
                    text=self._format_qa_text(page),
                    score=1.0 / float(len(items) + 1),
                    raw_backend_id=page.get("page_id"),
                )
            )
        for item in raw.get("retrieved_user_knowledge", []):
            items.append(
                RetrievalItem(
                    rank=len(items),
                    memory_id=f"user:{len(items)}",
                    text=str(item.get("knowledge", "")),
                    score=1.0 / float(len(items) + 1),
                    raw_backend_id=None,
                )
            )
        for item in raw.get("retrieved_assistant_knowledge", []):
            items.append(
                RetrievalItem(
                    rank=len(items),
                    memory_id=f"assistant:{len(items)}",
                    text=str(item.get("knowledge", "")),
                    score=1.0 / float(len(items) + 1),
                    raw_backend_id=None,
                )
            )
        return RetrievalRecord(
            query=query,
            top_k=top_k,
            items=items[:top_k],
            raw_trace={"baseline": _BACKEND_ID, "retrieved_at": raw.get("retrieved_at")},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Report availability plus integration metadata."""
        available = self._backend is not None or self._backend_factory is not None
        # NOTE(review): export_memory_delta() diffs snapshots, so
        # "ingest_pair_only" understates the actual granularity — confirm
        # whether this label is intentional before relying on it downstream.
        return {
            "backend": _BACKEND_ID,
            "baseline": _BACKEND_ID,
            "available": available and self._integration_error is None,
            "integration_status": "integrated" if available and self._integration_error is None else "unavailable",
            "integration_error": self._integration_error or INTEGRATION_ERROR,
            "delta_granularity": "ingest_pair_only",
            "snapshot_mode": "short_mid_long_term",
        }

    def _require_backend(self) -> Any:
        """Return the backend or raise the recorded integration error."""
        if self._backend is None:
            raise self._runtime_error()
        return self._backend

    def _store_pair(self, assistant_turn: NormalizedTurn) -> None:
        """Store all buffered user turns joined with this assistant reply."""
        user_input = self._joined_user_text()
        self._store_memory(
            session_id=assistant_turn.session_id,
            user_input=user_input,
            agent_response=self._turn_text(assistant_turn),
            timestamp=assistant_turn.timestamp,
        )
        self._pending_user_turns = []

    def _store_memory(
        self,
        *,
        session_id: str,
        user_input: str,
        agent_response: str,
        timestamp: str | None,
    ) -> None:
        """Forward one user/assistant exchange into the MemoryOS backend."""
        backend = self._require_backend()
        backend.add_memory(
            user_input=user_input,
            agent_response=agent_response,
            timestamp=timestamp,
            meta_data={"session_id": session_id},
        )

    def _joined_user_text(self) -> str:
        """Concatenate all buffered user turns (empty string if none)."""
        if not self._pending_user_turns:
            return ""
        return "\n".join(self._turn_text(turn) for turn in self._pending_user_turns)

    @staticmethod
    def _turn_text(turn: NormalizedTurn) -> str:
        """Render a turn's text plus one bracketed line per attachment caption."""
        parts = [turn.text]
        for att in turn.attachments:
            parts.append(f"[{att.type}] {att.caption}")
        return "\n".join(parts)

    @staticmethod
    def _format_qa_text(item: dict[str, Any]) -> str:
        """Render a MemoryOS QA dict as 'user: …\\nassistant: …' text.

        Falls back to ``str(item)`` when neither field is populated.
        """
        parts: list[str] = []
        user_text = item.get("user_input", "")
        if user_text:
            parts.append(f"user: {user_text}")
        assistant_text = item.get("agent_response", "")
        if assistant_text:
            parts.append(f"assistant: {assistant_text}")
        if not parts:
            parts.append(str(item))
        return "\n".join(parts)
memory_adapters/memverse_adapter.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapter for MemVerse — uses build_memory for storage + cosine retrieval."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ import sys
8
+ import shutil
9
+ import tempfile
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+ from dotenv import load_dotenv
15
+
16
+ load_dotenv(Path(__file__).resolve().parents[2] / ".env")
17
+
18
+ from eval_framework.datasets.schemas import (
19
+ MemoryDeltaRecord,
20
+ MemorySnapshotRecord,
21
+ NormalizedTurn,
22
+ RetrievalItem,
23
+ RetrievalRecord,
24
+ )
25
+ from eval_framework.memory_adapters.base import MemoryAdapter
26
+
27
# NOTE(review): hard-coded absolute path to one specific machine; pass
# MemVerseAdapter(source_root=...) (or make this configurable) on other hosts.
_DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/MemVerse")
28
+
29
+
30
class MemVerseAdapter(MemoryAdapter):
    """Adapter for MemVerse using build_memory + cosine retrieval.

    Bypasses the async orchestrator/LightRAG and uses the core
    memory building + embedding-based retrieval directly.
    """

    def __init__(
        self,
        *,
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,  # extra registry kwargs are accepted and ignored
    ) -> None:
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))

        # Imported lazily so the module loads even when openai is absent.
        from openai import OpenAI

        self._client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_BASE_URL"),
        )
        self._model = os.getenv("OPENAI_MODEL") or "gpt-4o"

        # Working directory for memory files
        self._work_dir = Path(tempfile.mkdtemp(prefix="memverse_eval_"))
        self._root = root
        self._session_id = ""
        # memory ids seen at the previous export; used to diff deltas.
        self._prev_snapshot_ids: set[str] = set()
        self._memories: list[dict[str, Any]] = []  # {id, text, embedding, output}
        self._conversation: list[dict[str, Any]] = []
        self._turn_counter = 0

        # Load system prompts for memory agents
        self._prompts: dict[str, str] = {}
        for name in ["core_memory_agent", "episodic_memory_agent", "semantic_memory_agent"]:
            prompt_path = root / "MemoryKB" / "Long_Term_Memory" / "system" / f"{name}.txt"
            if prompt_path.exists():
                self._prompts[name] = prompt_path.read_text(encoding="utf-8").strip()

    def _get_embedding(self, text: str) -> np.ndarray:
        """Embed *text* via the OpenAI embeddings API (network call)."""
        resp = self._client.embeddings.create(
            model="text-embedding-3-small",
            input=text,
        )
        return np.array(resp.data[0].embedding)

    def _cosine_sim(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two 1-D vectors; 0.0 when either is all-zero."""
        norm = np.linalg.norm(a) * np.linalg.norm(b)
        if norm == 0:
            return 0.0
        return float(np.dot(a, b) / norm)

    def reset(self) -> None:
        """Drop all stored memories and start a fresh working directory."""
        self._memories = []
        self._conversation = []
        self._prev_snapshot_ids = set()
        self._turn_counter = 0
        if self._work_dir.exists():
            shutil.rmtree(self._work_dir, ignore_errors=True)
        self._work_dir = Path(tempfile.mkdtemp(prefix="memverse_eval_"))

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one turn: embed it and run LLM fact extraction per turn.

        Note: this issues one embeddings call and one chat call per turn.
        """
        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        for att in turn.attachments:
            text += f"\n[{att.type}] {att.caption}"

        entry_id = f"turn_{self._turn_counter}"
        self._turn_counter += 1

        # Caption fields mirror MemVerse's native entry schema; unused here.
        entry = {
            "id": entry_id,
            "query": text,
            "videocaption": None,
            "audiocaption": None,
            "imagecaption": None,
        }
        self._conversation.append(entry)

        # Build memory: get embedding + LLM extraction for each memory type
        embedding = self._get_embedding(text)

        # Use the first available prompt (core memory agent) for extraction
        prompt = next(iter(self._prompts.values()), "Extract key facts from this text.")
        try:
            resp = self._client.chat.completions.create(
                model=self._model,
                messages=[
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": text},
                ],
                temperature=0,
                max_tokens=512,
            )
            output = resp.choices[0].message.content or ""
        except Exception:
            # Best-effort: fall back to the raw turn text on any API failure.
            output = text

        self._memories.append({
            "id": entry_id,
            "text": text,
            "output": output,
            "embedding": embedding,
            "session_id": turn.session_id,
        })

    def end_session(self, session_id: str) -> None:
        # No per-session consolidation; only the current session id is tracked.
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Expose every stored memory's extracted text as a snapshot record."""
        return [
            MemorySnapshotRecord(
                memory_id=m["id"],
                text=m["output"],
                session_id=m.get("session_id", self._session_id),
                status="active",
                source="MemVerse",
                raw_backend_id=m["id"],
                raw_backend_type="memverse",
                metadata={},
            )
            for m in self._memories
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Export delta by diffing current snapshot ids against the previous set."""
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "MemVerse"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Rank all memories by cosine similarity to the embedded query."""
        if not self._memories:
            return RetrievalRecord(query=query, top_k=top_k, items=[], raw_trace={})

        query_emb = self._get_embedding(query)
        scored: list[tuple[float, dict[str, Any]]] = []
        for m in self._memories:
            sim = self._cosine_sim(query_emb, m["embedding"])
            scored.append((sim, m))
        # key= on the score avoids comparing the dict payloads on ties.
        scored.sort(key=lambda x: x[0], reverse=True)

        items = [
            RetrievalItem(
                rank=i,
                memory_id=m["id"],
                text=m["output"],
                score=float(sim),
                raw_backend_id=m["id"],
            )
            for i, (sim, m) in enumerate(scored[:top_k])
        ]
        return RetrievalRecord(
            query=query, top_k=top_k, items=items,
            raw_trace={"baseline": "MemVerse"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Static capability metadata for this adapter."""
        return {
            "backend": "MemVerse",
            "baseline": "MemVerse",
            "available": True,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }
memory_adapters/registry.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Registry and factory for native Mem-Gallery and external placeholder adapters."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import sys
7
+ import types
8
+ from contextlib import nullcontext
9
+ from functools import partial
10
+ import importlib
11
+ from pathlib import Path
12
+ from typing import Any, Callable
13
+
14
+ from eval_framework.memory_adapters.amem import AMemAdapter
15
+ from eval_framework.memory_adapters.base import MemoryAdapter
16
+ from eval_framework.memory_adapters.dummy import DummyAdapter
17
+ from eval_framework.memory_adapters.memgallery_native import MemGalleryNativeAdapter
18
+ from eval_framework.memory_adapters.memoryos import MemoryOSAdapter
19
+
20
# Baselines implemented natively inside the co-located Mem-Gallery/MemEngine
# codebase; these are served by MemGalleryNativeAdapter rather than an
# external backend adapter.
MEMGALLERY_NATIVE_BASELINES: frozenset[str] = frozenset(
    {
        "FUMemory",
        "STMemory",
        "LTMemory",
        "GAMemory",
        "MGMemory",
        "RFMemory",
        "MMMemory",
        "MMFUMemory",
        "NGMemory",
        "AUGUSTUSMemory",
        "UniversalRAGMemory",
    }
)
35
+
36
+
37
+ def _word_mode_truncation(number: int = 50_000) -> dict[str, Any]:
38
+ return {
39
+ "method": "LMTruncation",
40
+ "mode": "word",
41
+ "number": number,
42
+ "path": "",
43
+ }
44
+
45
+
46
+ def _text_encoder_override() -> dict[str, Any]:
47
+ return {
48
+ "method": "STEncoder",
49
+ "path": "all-MiniLM-L6-v2",
50
+ }
51
+
52
+
53
+ def _openai_llm_override() -> dict[str, Any]:
54
+ return {
55
+ "method": "APILLM",
56
+ "name": os.getenv("OPENAI_MODEL") or "gpt-5.1",
57
+ "api_key": os.getenv("OPENAI_API_KEY") or "",
58
+ "base_url": os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1",
59
+ "temperature": float(os.getenv("OPENAI_TEMPERATURE", "0.0")),
60
+ }
61
+
62
+
63
def _default_memgallery_runtime_overrides(baseline_name: str) -> dict[str, Any]:
    """Per-baseline config overrides applied when instantiating Mem-Gallery memories.

    Swaps token-based truncation for word-mode, pins retrieval encoders to a
    local MiniLM model, and routes judge/summarizer/reflector LLM calls through
    the OpenAI-compatible endpoint configured in the environment.

    FIX: removed the dead ``overrides.setdefault("optimize", {})`` in the
    RFMemory branch — its result was immediately clobbered by the assignment
    on the next line.
    """
    overrides: dict[str, Any] = {}

    # --- text-only baselines ---
    if baseline_name in {"FUMemory", "STMemory", "LTMemory", "RFMemory"}:
        overrides["recall"] = {"truncation": _word_mode_truncation()}
    if baseline_name == "LTMemory":
        overrides.setdefault("recall", {})
        overrides["recall"]["text_retrieval"] = {"encoder": _text_encoder_override()}
    if baseline_name == "GAMemory":
        overrides = {
            "recall": {
                "truncation": _word_mode_truncation(),
                "text_retrieval": {"encoder": _text_encoder_override()},
                "importance_judge": {"LLM_config": _openai_llm_override()},
            },
            "reflect": {
                "reflector": {"LLM_config": _openai_llm_override()},
            },
        }
    if baseline_name == "MGMemory":
        overrides = {
            "recall": {
                "truncation": _word_mode_truncation(),
                "recall_retrieval": {"encoder": _text_encoder_override()},
                "archival_retrieval": {"encoder": _text_encoder_override()},
                "trigger": {"LLM_config": _openai_llm_override()},
            },
            "store": {
                "flush_checker": _word_mode_truncation(),
                "summarizer": {"LLM_config": _openai_llm_override()},
            },
        }
    if baseline_name == "RFMemory":
        # RFMemory already received its "recall" override in the first branch;
        # this only adds the optimize-stage reflector.
        overrides["optimize"] = {
            "reflector": {"LLM_config": _openai_llm_override()},
        }

    # --- multimodal baselines ---
    if baseline_name == "MMMemory":
        overrides = {
            "recall": {
                "truncation": _word_mode_truncation(),
            },
        }
    if baseline_name == "MMFUMemory":
        overrides = {
            "recall": {
                "truncation": _word_mode_truncation(),
            },
        }
    if baseline_name == "NGMemory":
        overrides = {
            "recall": {
                "truncation": _word_mode_truncation(),
            },
        }
    if baseline_name == "AUGUSTUSMemory":
        overrides = {
            "recall": {
                "truncation": _word_mode_truncation(),
            },
            "concept_extractor": {
                "llm": _openai_llm_override(),
            },
        }
    if baseline_name == "UniversalRAGMemory":
        overrides = {
            "recall": {
                "truncation": _word_mode_truncation(),
                "text_retrieval": {"encoder": _text_encoder_override()},
            },
            "routing": {
                "llm": _openai_llm_override(),
            },
        }
    return overrides
141
+
142
+
143
def _resolve_baselines_root() -> Path:
    """Return the ``baselines/`` directory (sibling of eval_framework/).

    Layout::

        nips26/
        ├── eval_framework/
        └── baselines/
            ├── memengine/
            └── default_config/
    """
    # Walk up from this module: registry.py -> memory_adapters/ ->
    # eval_framework/ -> nips26/, then descend into baselines/.
    project_root = Path(__file__).resolve().parents[2]
    return project_root / "baselines"
156
+
157
+
158
def _ensure_memgallery_benchmark_on_path() -> Path:
    """Add ``baselines/`` to sys.path so that ``memengine`` and
    ``default_config`` packages are importable.

    Returns the resolved ``baselines/`` directory.
    """
    baselines_root = _resolve_baselines_root()
    # Fail fast with an actionable message when the benchmark tree is absent.
    if not (baselines_root / "memengine").is_dir():
        raise FileNotFoundError(
            f"memengine/ not found under {baselines_root}. "
            f"Clone MemEngine into baselines/memengine."
        )
    root_str = str(baselines_root)
    if root_str not in sys.path:
        # Prepend so the co-located packages shadow any globally installed copy.
        sys.path.insert(0, root_str)
    _bootstrap_memengine_namespace(baselines_root)
    return baselines_root
172
+
173
+
174
def _bootstrap_memengine_namespace(root: Path) -> None:
    """
    Pre-seed lightweight namespace packages for the co-located memengine package.

    memengine's package-level ``__init__.py`` eagerly imports all memories and function
    modules, which pulls in heavyweight optional dependencies like ``torch`` even for
    simple baselines such as ``FUMemory``. By registering package shells in ``sys.modules``
    first, we can import only the specific submodules we need.

    *root* is the directory that contains ``memengine/`` (the ``baselines/`` tree).
    """
    # Map every (sub)package name to its on-disk directory so that
    # ``importlib.import_module("memengine.function.X")`` resolves without ever
    # executing memengine/__init__.py.
    package_paths = {
        "memengine": root / "memengine",
        "memengine.config": root / "memengine" / "config",
        "memengine.memory": root / "memengine" / "memory",
        "memengine.function": root / "memengine" / "function",
        "memengine.operation": root / "memengine" / "operation",
        "memengine.utils": root / "memengine" / "utils",
    }
    # Pass 1: register a bare ModuleType shell per package, but never replace a
    # module that was already imported (real or previously stubbed).
    for package_name, package_path in package_paths.items():
        existing = sys.modules.get(package_name)
        if existing is not None:
            continue
        module = types.ModuleType(package_name)
        # __path__ makes the shell behave like a package for submodule imports.
        module.__path__ = [str(package_path)]  # type: ignore[attr-defined]
        module.__package__ = package_name
        sys.modules[package_name] = module

    # Pass 2: wire each child shell onto its parent as an attribute so
    # ``memengine.function`` etc. are reachable via attribute access too.
    for package_name in package_paths:
        if "." not in package_name:
            continue
        parent_name, child_name = package_name.rsplit(".", 1)
        parent = sys.modules.get(parent_name)
        child = sys.modules.get(package_name)
        if parent is not None and child is not None and not hasattr(parent, child_name):
            setattr(parent, child_name, child)

    # Finally install stubs for optional heavy deps, then export the safe
    # memengine.function symbols onto the shell package.
    _bootstrap_optional_dependency_stubs()
    _populate_safe_memengine_function_exports()
213
+
214
+
215
def _bootstrap_optional_dependency_stubs() -> None:
    """Provide narrow stubs for optional imports needed only on unused code paths.

    Tries the real ``torch`` / ``transformers`` first; only when they are not
    installed does it register minimal fake modules whose entry points raise a
    clear RuntimeError if a baseline actually exercises them.
    """
    # --- torch ---
    if "torch" not in sys.modules:
        try:
            sys.modules["torch"] = importlib.import_module("torch")
        except Exception:
            pass
    if "torch" not in sys.modules:
        torch_module = types.ModuleType("torch")

        def _torch_unavailable(*args: Any, **kwargs: Any) -> Any:
            del args, kwargs
            raise RuntimeError(
                "PyTorch is required for encoder-backed or tensor-based Mem-Gallery "
                "baselines, but `torch` is not installed in this environment."
            )

        # Cheap attributes that callers may probe without doing tensor math.
        torch_module.cuda = types.SimpleNamespace(is_available=lambda: False)  # type: ignore[attr-defined]
        torch_module.device = lambda spec: spec  # type: ignore[attr-defined]
        # NOTE(review): relies on ``nullcontext`` being imported at module top
        # (e.g. ``from contextlib import nullcontext``) — confirm in this file's
        # import block.
        torch_module.no_grad = lambda: nullcontext()  # type: ignore[attr-defined]
        # Tensor-producing entry points all fail loudly.
        torch_module.from_numpy = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.stack = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.sort = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.matmul = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.ones = _torch_unavailable  # type: ignore[attr-defined]
        torch_module.nn = types.SimpleNamespace(  # type: ignore[attr-defined]
            functional=types.SimpleNamespace(normalize=_torch_unavailable)
        )
        sys.modules["torch"] = torch_module

    # --- transformers ---
    if "transformers" not in sys.modules:
        try:
            sys.modules["transformers"] = importlib.import_module("transformers")
        except Exception:
            pass
    if "transformers" not in sys.modules:
        transformers_module = types.ModuleType("transformers")

        class _UnavailableAutoTokenizer:
            @classmethod
            def from_pretrained(cls, *args: Any, **kwargs: Any) -> Any:
                del args, kwargs
                raise RuntimeError(
                    "transformers.AutoTokenizer is required for token-mode truncation "
                    "or encoder-backed baselines, but `transformers` is not installed."
                )

        transformers_module.AutoTokenizer = _UnavailableAutoTokenizer  # type: ignore[attr-defined]
        sys.modules["transformers"] = transformers_module
264
+
265
+
266
def _populate_safe_memengine_function_exports() -> None:
    """Expose all function symbols for complete baseline deployment without running package __init__."""
    function_pkg = sys.modules.get("memengine.function")
    if function_pkg is None:
        # Namespace shell not bootstrapped yet; nothing to populate.
        return

    # Complete list — covers every module referenced by any of the 11 baselines:
    # FU/ST/LT/GA/MG/RF (text-only) + MM/MMFU/NG/AUGUSTUS/UniversalRAG (multimodal)
    module_names = (
        # --- text-only baselines ---
        "memengine.function.Encoder",
        "memengine.function.Retrieval",
        "memengine.function.LLM",
        "memengine.function.Judge",
        "memengine.function.Reflector",
        "memengine.function.Summarizer",
        "memengine.function.Truncation",
        "memengine.function.Trigger",
        "memengine.function.Utilization",
        "memengine.function.Forget",
        # --- multimodal / graph / concept baselines ---
        "memengine.function.MultiModalEncoder",
        "memengine.function.MultiModalRetrieval",
        "memengine.function.ConceptExtractor",
        "memengine.function.ConceptBasedRetrieval",
        "memengine.function.EntityExtractor",
        "memengine.function.FactExtractor",
        "memengine.function.UniversalRAGRouting",
        "memengine.function.UniversalRAGRetrieval",
    )
    for module_name in module_names:
        try:
            module = importlib.import_module(module_name)
        except Exception:
            # Some modules may depend on optional heavy deps (torch, transformers).
            # Skip gracefully — they will fail loudly if the baseline actually needs them.
            continue
        for attr_name, value in vars(module).items():
            if attr_name.startswith("_"):
                continue
            # First module to define a public name wins; never clobber exports.
            if not hasattr(function_pkg, attr_name):
                setattr(function_pkg, attr_name, value)
307
+
308
+
309
def create_memgallery_adapter(
    baseline_name: str,
    *,
    config_overrides: dict[str, Any] | None = None,
) -> MemGalleryNativeAdapter:
    """
    Instantiate a native Mem-Gallery adapter for a known baseline name.

    Loads default_config + memengine from the Mem-Gallery benchmark tree.
    Caller-supplied ``config_overrides`` take precedence over the per-baseline
    runtime defaults.
    """
    if baseline_name not in MEMGALLERY_NATIVE_BASELINES:
        raise KeyError(f"unknown Mem-Gallery baseline: {baseline_name!r}")
    _ensure_memgallery_benchmark_on_path()
    merged = dict(_default_memgallery_runtime_overrides(baseline_name))
    if config_overrides:
        merged.update(config_overrides)
    # An empty override dict collapses to None so the adapter uses pure defaults.
    return MemGalleryNativeAdapter.from_baseline(
        baseline_name, config=merged or None
    )
331
+
332
+
333
# Factory registry for adapters backed by the co-located memengine package.
# Each entry is a no-argument-friendly callable that defers to
# create_memgallery_adapter for path bootstrapping and override merging.
MEMGALLERY_NATIVE_REGISTRY: dict[str, Callable[..., MemGalleryNativeAdapter]] = {
    name: partial(create_memgallery_adapter, name) for name in MEMGALLERY_NATIVE_BASELINES
}

# Baselines served by standalone adapter modules rather than memengine.
# Keep this set in sync with EXTERNAL_ADAPTER_REGISTRY below.
EXTERNAL_ADAPTER_KEYS: frozenset[str] = frozenset({
    "A-Mem", "MemoryOS", "Dummy",
    "Mem0", "Mem0-Graph",
    "SimpleMem", "Omni-SimpleMem",
    "MemVerse",
    "Zep",
})
344
+
345
+
346
def create_amem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Lazily import and construct the A-Mem v2 baseline adapter."""
    from eval_framework.memory_adapters.amem_v2 import AMemV2Adapter

    adapter = AMemV2Adapter(**kwargs)
    return adapter
349
+
350
+
351
def create_memoryos_adapter(**kwargs: Any) -> MemoryOSAdapter:
    """Construct the MemoryOS baseline adapter, forwarding all options."""
    adapter = MemoryOSAdapter(**kwargs)
    return adapter
353
+
354
+
355
def create_dummy_adapter(**kwargs: Any) -> DummyAdapter:
    """Construct the no-op Dummy adapter; options are accepted but ignored."""
    del kwargs  # accepted only for registry signature uniformity
    return DummyAdapter()
357
+
358
+
359
def create_mem0_adapter(**kwargs: Any) -> MemoryAdapter:
    """Lazily build the Mem0 baseline adapter with graph mode disabled."""
    from eval_framework.memory_adapters.mem0_adapter import Mem0Adapter

    adapter = Mem0Adapter(use_graph=False, **kwargs)
    return adapter
362
+
363
+
364
def create_mem0_graph_adapter(**kwargs: Any) -> MemoryAdapter:
    """Lazily build the Mem0 baseline adapter with graph mode enabled."""
    from eval_framework.memory_adapters.mem0_adapter import Mem0Adapter

    adapter = Mem0Adapter(use_graph=True, **kwargs)
    return adapter
367
+
368
+
369
def create_simplemem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Lazily build the text-mode SimpleMem baseline adapter."""
    from eval_framework.memory_adapters.simplemem_adapter import SimpleMemAdapter

    adapter = SimpleMemAdapter(mode="text", **kwargs)
    return adapter
372
+
373
+
374
def create_omni_simplemem_adapter(**kwargs: Any) -> MemoryAdapter:
    """Lazily build the multimodal (omni-mode) SimpleMem baseline adapter."""
    from eval_framework.memory_adapters.simplemem_adapter import SimpleMemAdapter

    adapter = SimpleMemAdapter(mode="omni", **kwargs)
    return adapter
377
+
378
+
379
def create_memverse_adapter(**kwargs: Any) -> MemoryAdapter:
    """Lazily import and construct the MemVerse baseline adapter."""
    from eval_framework.memory_adapters.memverse_adapter import MemVerseAdapter

    adapter = MemVerseAdapter(**kwargs)
    return adapter
382
+
383
+
384
def create_zep_adapter(**kwargs: Any) -> MemoryAdapter:
    """Lazily import and construct the Zep baseline adapter."""
    from eval_framework.memory_adapters.zep_adapter import ZepAdapter

    adapter = ZepAdapter(**kwargs)
    return adapter
387
+
388
+
389
# Maps each external baseline name to its factory. All factories import the
# backing module lazily so an unused baseline never pays its dependency cost.
# Keys must mirror EXTERNAL_ADAPTER_KEYS above.
EXTERNAL_ADAPTER_REGISTRY: dict[str, Callable[..., MemoryAdapter]] = {
    "A-Mem": create_amem_adapter,
    "MemoryOS": create_memoryos_adapter,
    "Dummy": create_dummy_adapter,
    "Mem0": create_mem0_adapter,
    "Mem0-Graph": create_mem0_graph_adapter,
    "SimpleMem": create_simplemem_adapter,
    "Omni-SimpleMem": create_omni_simplemem_adapter,
    "MemVerse": create_memverse_adapter,
    "Zep": create_zep_adapter,
}
400
+
401
+
402
def create_external_adapter(
    name: str,
    *,
    config_overrides: dict[str, Any] | None = None,
) -> MemoryAdapter:
    """Instantiate an external adapter for a known baseline name.

    ``config_overrides`` are forwarded as keyword arguments to the factory.
    Raises KeyError for names outside EXTERNAL_ADAPTER_KEYS.
    """
    if name not in EXTERNAL_ADAPTER_KEYS:
        raise KeyError(f"unknown external adapter: {name!r}")
    factory = EXTERNAL_ADAPTER_REGISTRY[name]
    kwargs = config_overrides or {}
    return factory(**kwargs)
memory_adapters/simplemem_adapter.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapter for SimpleMem and Omni-SimpleMem baselines."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ from eval_framework.datasets.schemas import (
11
+ MemoryDeltaRecord,
12
+ MemorySnapshotRecord,
13
+ NormalizedTurn,
14
+ RetrievalItem,
15
+ RetrievalRecord,
16
+ )
17
+ from eval_framework.memory_adapters.base import MemoryAdapter
18
+
19
# Default checkout location of the SimpleMem source tree on the eval host;
# override per-instance via ``source_root``.
_DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/SimpleMem")
20
+
21
+
22
class SimpleMemAdapter(MemoryAdapter):
    """Adapter for SimpleMem (text mode) or Omni-SimpleMem (omni mode).

    Wraps the external ``simplemem_router`` module and mirrors every ingested
    turn into ``self._stored_texts`` so snapshots/deltas and the lexical
    retrieval fallback work even when the backend is opaque.
    """

    def __init__(
        self,
        *,
        mode: str = "text",
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,
    ) -> None:
        # mode selects the backend flavor; extra kwargs are accepted for
        # registry uniformity and otherwise ignored.
        self._mode = mode  # "text" or "omni"
        root = Path(source_root or _DEFAULT_SOURCE).resolve()
        if str(root) not in sys.path:
            sys.path.insert(0, str(root))

        # Imported lazily, after sys.path is extended with the source checkout.
        import simplemem_router as simplemem
        self._simplemem = simplemem
        self._mem: Any = None
        self._session_id = ""
        # Snapshot ids seen at the previous export; used to diff deltas.
        self._prev_snapshot_ids: set[str] = set()
        # Local mirror of every ingested turn (id, text, session_id).
        self._stored_texts: list[dict[str, str]] = []
        self._init_mem()

    def _init_mem(self) -> None:
        # Fresh backend instance; clear_db=True wipes any persisted state.
        self._mem = self._simplemem.create(mode=self._mode, clear_db=True)
        self._stored_texts = []

    def reset(self) -> None:
        """Discard all memory state and start a fresh backend instance."""
        if self._mem is not None:
            try:
                self._mem.close()
            except Exception:
                # Best-effort close; a failed close must not block the reset.
                pass
        self._init_mem()
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Store one dialogue turn (text plus attachment captions) in memory."""
        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        # Attachments are flattened into the text via their captions.
        for att in turn.attachments:
            text += f"\n[{att.type}] {att.caption}"

        # Local ids are just the insertion index, as strings.
        mid = str(len(self._stored_texts))
        if self._mode == "omni":
            self._mem.add_text(text, tags=[f"session:{turn.session_id}"])
        else:
            speaker = "User" if turn.role == "user" else "Assistant"
            ts = turn.timestamp or ""
            self._mem.add_dialogue(speaker, text, ts)
        self._stored_texts.append({"id": mid, "text": text, "session_id": turn.session_id})

    def end_session(self, session_id: str) -> None:
        """Mark end of a session; text mode flushes via backend finalize()."""
        self._session_id = session_id
        if self._mode == "text":
            try:
                self._mem.finalize()
            except Exception:
                # Finalize is best-effort; some backend versions may lack it.
                pass

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return the full memory state from the local mirror (one row per turn)."""
        return [
            MemorySnapshotRecord(
                memory_id=t["id"], text=t["text"],
                session_id=t["session_id"], status="active",
                source=f"SimpleMem-{self._mode}",
                raw_backend_id=t["id"], raw_backend_type="simplemem",
                metadata={},
            )
            for t in self._stored_texts
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Diff the current snapshot against the previous export; adds only."""
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": f"SimpleMem-{self._mode}"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Query the backend; fall back to word-overlap search on failure/empty."""
        items: list[RetrievalItem] = []
        try:
            if self._mode == "omni":
                result = self._mem.query(query, top_k=top_k)
                if isinstance(result, list):
                    for i, r in enumerate(result[:top_k]):
                        # Result rows may be dicts or arbitrary objects.
                        text = r.get("text", str(r)) if isinstance(r, dict) else str(r)
                        items.append(RetrievalItem(
                            rank=i, memory_id=str(i), text=text,
                            # Backend exposes no score; use reciprocal rank.
                            score=1.0 / (i + 1), raw_backend_id=None,
                        ))
            else:
                # Text mode exposes a QA-style interface, not ranked retrieval.
                answer = self._mem.ask(query)
                if answer:
                    items.append(RetrievalItem(
                        rank=0, memory_id="answer", text=str(answer),
                        score=1.0, raw_backend_id=None,
                    ))
        except Exception:
            # Backend errors degrade to the lexical fallback below.
            pass

        if not items:
            # Fallback: simple text search over stored memories
            query_lower = query.lower()
            scored = []
            for t in self._stored_texts:
                # Score = count of shared lowercase tokens with the query.
                overlap = len(set(query_lower.split()) & set(t["text"].lower().split()))
                scored.append((overlap, t))
            scored.sort(key=lambda x: x[0], reverse=True)
            for i, (sc, t) in enumerate(scored[:top_k]):
                items.append(RetrievalItem(
                    rank=i, memory_id=t["id"], text=t["text"],
                    # Normalize overlap by query length (guarding empty query).
                    score=float(sc) / max(len(query.split()), 1),
                    raw_backend_id=t["id"],
                ))

        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:top_k],
            raw_trace={"baseline": f"SimpleMem-{self._mode}"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this backend for the pipeline's availability checks."""
        name = "Omni-SimpleMem" if self._mode == "omni" else "SimpleMem"
        return {
            "backend": name, "baseline": name,
            "available": self._mem is not None,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }
memory_adapters/zep_adapter.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Adapter for Zep memory system (community/self-hosted edition)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ import uuid as _uuid
7
+ from typing import Any
8
+
9
+ from eval_framework.datasets.schemas import (
10
+ MemoryDeltaRecord,
11
+ MemorySnapshotRecord,
12
+ NormalizedTurn,
13
+ RetrievalItem,
14
+ RetrievalRecord,
15
+ )
16
+ from eval_framework.memory_adapters.base import MemoryAdapter
17
+
18
+
19
class ZepAdapter(MemoryAdapter):
    """Adapter for Zep community edition (self-hosted).

    Each adapter lifetime maps to one Zep "thread" (a fresh random id per
    reset); turns are appended as messages and retrieval uses Zep's
    memory search endpoint.
    """

    def __init__(self, *, base_url: str | None = None, **kwargs: Any) -> None:
        # Extra kwargs are accepted for registry uniformity and ignored.
        from zep_python import ZepClient

        # Precedence: explicit arg > ZEP_BASE_URL env var > localhost default.
        self._base_url = base_url or os.getenv("ZEP_BASE_URL", "http://localhost:8000")
        self._client = ZepClient(base_url=self._base_url)
        self._session_id = ""
        # Random thread id keeps concurrent eval runs isolated server-side.
        self._thread_id = f"eval_{_uuid.uuid4().hex[:8]}"
        # Snapshot ids seen at the previous export; used to diff deltas.
        self._prev_snapshot_ids: set[str] = set()

    def reset(self) -> None:
        """Drop the server-side thread (best-effort) and start a new one."""
        try:
            self._client.memory.delete_memory(self._thread_id)
        except Exception:
            # Deletion failures (e.g. thread never created) are non-fatal.
            pass
        self._thread_id = f"eval_{_uuid.uuid4().hex[:8]}"
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Append one dialogue turn (text plus attachment captions) to Zep."""
        from zep_python.memory import Memory
        from zep_python.message import Message

        self._session_id = turn.session_id
        text = f"{turn.role}: {turn.text}"
        # Attachments are flattened into the text via their captions.
        for att in turn.attachments:
            text += f"\n[{att.type}] {att.caption}"

        # Zep distinguishes only user vs ai role types.
        role_type = "user" if turn.role == "user" else "ai"
        msg = Message(role=turn.role, role_type=role_type, content=text)
        memory = Memory(messages=[msg])
        self._client.memory.add_memory(self._thread_id, memory)

    def end_session(self, session_id: str) -> None:
        # No explicit flush needed; Zep consolidates server-side.
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Fetch the thread's messages; an unreachable server yields []."""
        try:
            memory = self._client.memory.get_memory(self._thread_id)
        except Exception:
            return []

        rows: list[MemorySnapshotRecord] = []
        if memory and memory.messages:
            for i, msg in enumerate(memory.messages):
                # Prefer the server-assigned message uuid; fall back to index.
                mid = str(getattr(msg, "uuid", i))
                rows.append(MemorySnapshotRecord(
                    memory_id=mid,
                    text=msg.content or "",
                    session_id=self._session_id,
                    status="active",
                    source="Zep",
                    raw_backend_id=mid,
                    raw_backend_type="zep_message",
                    metadata={},
                ))
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Diff the current snapshot against the previous export; adds only."""
        current = self.snapshot_memories()
        current_ids = {s.memory_id for s in current}
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id, op="add", text=s.text,
                linked_previous=(), raw_backend_id=s.raw_backend_id,
                metadata={"baseline": "Zep"},
            )
            for s in current if s.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = current_ids
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search thread memory; server failures degrade to an empty result."""
        try:
            results = self._client.memory.search_memory(
                self._thread_id, query, limit=top_k,
            )
        except Exception:
            results = []

        items = [
            RetrievalItem(
                rank=i,
                memory_id=str(getattr(r.message, "uuid", i)) if r.message else str(i),
                text=r.message.content if r.message else str(r),
                # Missing scores fall back to reciprocal rank.
                score=float(getattr(r, "score", 1.0 / (i + 1))),
                raw_backend_id=str(getattr(r.message, "uuid", "")) if r.message else None,
            )
            for i, r in enumerate(results[:top_k])
        ]
        return RetrievalRecord(
            query=query, top_k=top_k, items=items,
            raw_trace={"baseline": "Zep"},
        )

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this backend for the pipeline's availability checks."""
        # NOTE(review): "available" is hard-coded True even when the server is
        # unreachable; reachability only surfaces as empty snapshots/results.
        return {
            "backend": "Zep",
            "baseline": "Zep",
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
openai_compat.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compatibility helpers for OpenAI chat completions across model families."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+
8
def rewrite_chat_completion_kwargs(payload: dict[str, Any]) -> dict[str, Any]:
    """Translate deprecated chat completion parameters for reasoning models.

    For gpt-5-family models, renames legacy ``max_tokens`` to
    ``max_completion_tokens`` unless the latter is already set. The input
    payload is never mutated; a shallow copy is returned.
    """
    rewritten = dict(payload)
    model_name = str(rewritten.get("model") or "")
    needs_rename = (
        model_name.startswith("gpt-5")
        and "max_tokens" in rewritten
        and "max_completion_tokens" not in rewritten
    )
    if needs_rename:
        rewritten["max_completion_tokens"] = rewritten.pop("max_tokens")
    return rewritten
19
+
20
+
21
def patch_openai_chat_completions() -> bool:
    """Monkeypatch the OpenAI SDK so GPT-5 chat calls accept legacy max_tokens.

    Returns:
        True when the patch is installed (or already was), False when the
        OpenAI SDK cannot be imported in this environment.
    """
    try:
        from openai.resources.chat.completions.completions import Completions
    except Exception:
        return False

    current = Completions.create
    # Idempotent: never stack wrappers on repeated calls.
    if getattr(current, "_eval_framework_patched", False):
        return True

    original_create = current

    def _patched_create(self: Any, *args: Any, **kwargs: Any) -> Any:
        rewritten = rewrite_chat_completion_kwargs(kwargs)
        try:
            return original_create(self, *args, **rewritten)
        except Exception as exc:
            if (
                "Unsupported parameter: 'max_tokens'" in str(exc)
                and "max_tokens" in kwargs
            ):
                # BUGFIX: the old retry re-ran rewrite_chat_completion_kwargs,
                # whose output is identical to the payload that just failed
                # (the rewrite is deterministic and gpt-5-gated), so the retry
                # resent the same bad request. Other reasoning models (e.g.
                # o-series) also reject max_tokens, so retry with an
                # unconditional rename instead.
                retried = dict(kwargs)
                legacy_limit = retried.pop("max_tokens")
                retried.setdefault("max_completion_tokens", legacy_limit)
                return original_create(self, *args, **retried)
            raise

    _patched_create._eval_framework_patched = True  # type: ignore[attr-defined]
    Completions.create = _patched_create  # type: ignore[assignment]
    return True
pipeline/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Session and checkpoint orchestration."""
pipeline/gold_state.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Cumulative gold memory state from staged memory-point annotations."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Any, Mapping, Sequence
7
+
8
+
9
@dataclass(frozen=True)
class GoldMemoryPoint:
    """One annotated gold memory point, normalized from a raw annotation row."""

    memory_id: str  # annotation id, "" when missing upstream
    memory_content: str  # the memory text itself
    memory_type: str  # annotation taxonomy label (free-form from the dataset)
    memory_source: str  # e.g. "interference" points are excluded from gold state
    is_update: bool  # True when this point revises earlier memories
    original_memories: tuple[str, ...]  # ids/texts this update supersedes
    importance: float  # annotator importance score; 0.0 when unparsable
    timestamp: str | None = None  # raw timestamp string, if annotated
    update_type: str = ""  # kind of update; "" for new memories
20
+
21
+
22
@dataclass(frozen=True)
class SessionGoldState:
    """Gold memory state after one session, as built by build_session_gold_states."""

    session_id: str
    cumulative_gold_memories: tuple[GoldMemoryPoint, ...]  # all non-interference points so far
    session_new_memories: tuple[GoldMemoryPoint, ...]  # this session's is_update=False points
    session_update_memories: tuple[GoldMemoryPoint, ...]  # this session's is_update=True points
    session_interference_memories: tuple[GoldMemoryPoint, ...]  # excluded from cumulative state
29
+
30
+
31
def _as_str_tuple(val: Any) -> tuple[str, ...]:
    """Coerce an annotation field into a tuple of strings.

    None -> (), a bare string -> one-element tuple, any non-string sequence
    -> element-wise str(), anything else -> a one-element str() tuple.
    """
    if val is None:
        return ()
    if isinstance(val, str):
        return (val,)
    if isinstance(val, Sequence) and not isinstance(val, (str, bytes)):
        return tuple(item if isinstance(item, str) else str(item) for item in val)
    return (str(val),)
42
+
43
+
44
def _parse_is_update(raw: Mapping[str, Any]) -> bool:
    """Read the ``is_update`` flag, tolerating bools, None, and string forms.

    Only the (case-insensitive, trimmed) string "true" counts as True;
    any other value falls back to Python truthiness.
    """
    flag = raw.get("is_update")
    if flag is True:
        return True
    if flag is False or flag is None:
        return False
    if isinstance(flag, str):
        return flag.strip().lower() == "true"
    return bool(flag)
53
+
54
+
55
def _parse_importance(raw: Mapping[str, Any]) -> float:
    """Read the ``importance`` score as a float, defaulting to 0.0 when absent
    or unparsable."""
    try:
        return float(raw.get("importance", 0.0))
    except (TypeError, ValueError):
        return 0.0
61
+
62
+
63
def gold_point_from_raw(raw: Mapping[str, Any]) -> GoldMemoryPoint:
    """Normalize one raw annotation mapping into a typed GoldMemoryPoint.

    Missing id/content become empty strings; timestamp stays None when absent;
    every other field is coerced via the module's parsing helpers.
    """
    raw_id = raw.get("memory_id")
    raw_content = raw.get("memory_content")
    raw_ts = raw.get("timestamp")
    return GoldMemoryPoint(
        memory_id="" if raw_id is None else str(raw_id),
        memory_content="" if raw_content is None else str(raw_content),
        memory_type=str(raw.get("memory_type", "")),
        memory_source=str(raw.get("memory_source", "")),
        is_update=_parse_is_update(raw),
        original_memories=_as_str_tuple(raw.get("original_memories")),
        importance=_parse_importance(raw),
        timestamp=None if raw_ts is None else str(raw_ts),
        update_type=str(raw.get("update_type", "") or ""),
    )
81
+
82
+
83
def build_session_gold_states(
    ordered_session_ids: Sequence[str],
    *,
    s00_memory_points: Sequence[Mapping[str, Any]],
    stage4_by_session_id: Mapping[str, Sequence[Mapping[str, Any]]],
) -> tuple[SessionGoldState, ...]:
    """Accumulate non-interference gold memories through sessions in order.

    S00 is taken from the domain JSON session; later sessions prefer ``stage4``
    rows when present, since those drive staged evaluation labels.

    Returns one SessionGoldState per id in *ordered_session_ids*, where each
    state's ``cumulative_gold_memories`` contains every non-interference point
    from that session and all earlier ones.
    """
    # Shared accumulator: grows monotonically across the session loop, and each
    # emitted state snapshots it via tuple(cumulative).
    cumulative: list[GoldMemoryPoint] = []
    states: list[SessionGoldState] = []

    for sid in ordered_session_ids:
        # Source selection: S00 annotations come from the domain JSON; all
        # later sessions read stage4 rows (empty when a session has none).
        if sid == "S00":
            raw_points: Sequence[Mapping[str, Any]] = s00_memory_points
        else:
            raw_points = stage4_by_session_id.get(sid)
            if raw_points is None:
                raw_points = ()

        news: list[GoldMemoryPoint] = []
        updates: list[GoldMemoryPoint] = []
        interference: list[GoldMemoryPoint] = []

        for raw in raw_points:
            gp = gold_point_from_raw(raw)
            # Interference points are tracked per-session but never enter the
            # cumulative gold state.
            if gp.memory_source == "interference":
                interference.append(gp)
                continue
            cumulative.append(gp)
            if gp.is_update:
                updates.append(gp)
            else:
                news.append(gp)

        states.append(
            SessionGoldState(
                session_id=sid,
                # tuple() snapshots the accumulator at this point in time.
                cumulative_gold_memories=tuple(cumulative),
                session_new_memories=tuple(news),
                session_update_memories=tuple(updates),
                session_interference_memories=tuple(interference),
            )
        )

    return tuple(states)
pipeline/qa_runner.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared checkpoint QA: retrieval via adapter + answer from an injected callable.
2
+
3
+ ``AnswerFn`` may return either a plain ``str`` (legacy) or a
4
+ ``(str, list[str])`` tuple of ``(answer, cited_memories)``.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from collections.abc import Callable
10
+ from typing import Union
11
+
12
+ from eval_framework.datasets.domain_a_v2 import NormalizedCheckpoint, NormalizedCheckpointQuestion
13
+ from eval_framework.datasets.schemas import RetrievalRecord
14
+ from eval_framework.memory_adapters.base import MemoryAdapter
15
+ from eval_framework.pipeline.records import PipelineCheckpointQARecord
16
+
17
# answer_fn may return str (legacy) or (str, list[str])
# AnswerResult: either bare answer text, or (answer_text, cited_memory_texts).
AnswerResult = Union[str, tuple[str, list[str]]]
# Callable mapping a checkpoint question plus its retrieval trace to an answer.
AnswerFn = Callable[[NormalizedCheckpointQuestion, RetrievalRecord], AnswerResult]
20
+
21
+
22
def run_checkpoint_qa_records(
    adapter: MemoryAdapter,
    *,
    sample_id: str,
    sample_uuid: str,
    checkpoint: NormalizedCheckpoint,
    top_k: int,
    answer_fn: AnswerFn,
) -> tuple[PipelineCheckpointQARecord, ...]:
    """For each question, call ``retrieve`` then ``answer_fn`` (not ``adapter.answer``).

    Returns one PipelineCheckpointQARecord per question, preserving the
    checkpoint's question order.
    """
    records: list[PipelineCheckpointQARecord] = []
    for question in checkpoint.questions:
        retrieval = adapter.retrieve(question.question, top_k)
        produced = answer_fn(question, retrieval)

        # Legacy answer_fn returns bare text; newer ones also return citations.
        if isinstance(produced, tuple):
            answer_text, cited = produced
        else:
            answer_text, cited = produced, []

        record = PipelineCheckpointQARecord(
            sample_id=sample_id,
            sample_uuid=sample_uuid,
            checkpoint_id=checkpoint.checkpoint_id,
            question=question.question,
            gold_answer=question.gold_answer,
            gold_evidence_memory_ids=question.gold_evidence_memory_ids,
            gold_evidence_contents=question.gold_evidence_contents,
            question_type=question.question_type,
            question_type_abbrev=question.question_type_abbrev,
            difficulty=question.difficulty,
            retrieval=retrieval,
            generated_answer=answer_text,
            cited_memories=tuple(cited),
        )
        records.append(record)
    return tuple(records)
pipeline/records.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pipeline-facing aliases and runtime record types emitted by the eval runner."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ from eval_framework.datasets.schemas import (
8
+ Attachment,
9
+ MemoryDeltaRecord,
10
+ MemorySnapshotRecord,
11
+ NormalizedTurn,
12
+ RetrievalItem,
13
+ RetrievalRecord,
14
+ normalize_turn,
15
+ )
16
+ from eval_framework.pipeline.gold_state import SessionGoldState
17
+
18
# Public surface of this module: schema aliases re-exported from
# eval_framework.datasets.schemas plus the two pipeline record dataclasses
# defined below. Keep in sync with the imports above.
__all__ = [
    "Attachment",
    "MemoryDeltaRecord",
    "MemorySnapshotRecord",
    "NormalizedTurn",
    "PipelineCheckpointQARecord",
    "PipelineSessionRecord",
    "RetrievalItem",
    "RetrievalRecord",
    "SessionGoldState",
    "normalize_turn",
]
30
+
31
+
32
@dataclass(frozen=True)
class PipelineSessionRecord:
    """Normalized outputs after one dialogue session (one row per session)."""

    sample_id: str  # dataset sample this session belongs to
    sample_uuid: str  # stable unique id for the sample
    session_id: str  # e.g. "S00", "S01", ...
    memory_snapshot: tuple[MemorySnapshotRecord, ...]  # full memory state after the session
    memory_delta: tuple[MemoryDeltaRecord, ...]  # memory ops attributed to this session
    gold_state: SessionGoldState  # annotated gold memories up to this session
42
+
43
+
44
@dataclass(frozen=True)
class PipelineCheckpointQARecord:
    """One checkpoint question: retrieval trace plus model-generated answer."""

    sample_id: str
    sample_uuid: str
    checkpoint_id: str  # checkpoint this question was asked at
    question: str
    gold_answer: str
    gold_evidence_memory_ids: tuple[str, ...]  # annotated evidence memory ids
    gold_evidence_contents: tuple[str, ...]  # annotated evidence memory texts
    question_type: str
    question_type_abbrev: str
    difficulty: str
    retrieval: RetrievalRecord  # what the adapter retrieved for this question
    generated_answer: str  # answer produced by the answer_fn
    cited_memories: tuple[str, ...] = ()  # optional citations from the answer_fn
pipeline/runner.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session-by-session ingest, memory export, and checkpoint QA orchestration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable
6
+
7
+ from eval_framework.datasets.domain_a_v2 import (
8
+ DomainAV2AcademicSample,
9
+ NormalizedCheckpointQuestion,
10
+ )
11
+ from eval_framework.memory_adapters.base import MemoryAdapter
12
+ from eval_framework.pipeline.qa_runner import run_checkpoint_qa_records
13
+ from eval_framework.pipeline.records import PipelineCheckpointQARecord, PipelineSessionRecord
14
+ from eval_framework.datasets.schemas import RetrievalRecord
15
+
16
+
17
+ def ensure_adapter_available(adapter: MemoryAdapter) -> None:
18
+ caps = adapter.get_capabilities()
19
+ if caps.get("available") is False:
20
+ backend = caps.get("backend", type(adapter).__name__)
21
+ detail = caps.get("integration_error") or caps.get(
22
+ "integration_status", "available=False"
23
+ )
24
+ raise RuntimeError(
25
+ f"Memory adapter {backend!r} is not available for pipeline runs: {detail}"
26
+ )
27
+
28
+
29
+ def run_domain_a_v2_sample(
30
+ adapter: MemoryAdapter,
31
+ sample: DomainAV2AcademicSample,
32
+ *,
33
+ top_k: int = 5,
34
+ answer_fn: Callable | None = None,
35
+ ) -> tuple[tuple[PipelineSessionRecord, ...], tuple[PipelineCheckpointQARecord, ...]]:
36
+ """Run all sessions in order, emit one session record per session, then checkpoint QA when due."""
37
+ ensure_adapter_available(adapter)
38
+ if sample.normalized_checkpoints and answer_fn is None:
39
+ raise ValueError(
40
+ "answer_fn is required when the sample defines normalized checkpoints"
41
+ )
42
+
43
+ adapter.reset()
44
+ session_out: list[PipelineSessionRecord] = []
45
+ qa_out: list[PipelineCheckpointQARecord] = []
46
+ completed_sessions: set[str] = set()
47
+ session_order = {
48
+ session.session_id: index for index, session in enumerate(sample.sessions)
49
+ }
50
+
51
+ if len(sample.sessions) != len(sample.session_gold_states):
52
+ raise ValueError(
53
+ "sample.sessions and sample.session_gold_states length mismatch"
54
+ )
55
+
56
+ for sess, gold in zip(sample.sessions, sample.session_gold_states):
57
+ if sess.session_id != gold.session_id:
58
+ raise ValueError(
59
+ f"session / gold_state id mismatch: {sess.session_id!r} vs {gold.session_id!r}"
60
+ )
61
+ for turn in sess.turns:
62
+ adapter.ingest_turn(turn)
63
+ adapter.end_session(sess.session_id)
64
+
65
+ snapshot = tuple(adapter.snapshot_memories())
66
+ delta = tuple(adapter.export_memory_delta(sess.session_id))
67
+ session_out.append(
68
+ PipelineSessionRecord(
69
+ sample_id=sample.sample_id,
70
+ sample_uuid=sample.uuid,
71
+ session_id=sess.session_id,
72
+ memory_snapshot=snapshot,
73
+ memory_delta=delta,
74
+ gold_state=gold,
75
+ )
76
+ )
77
+ completed_sessions.add(sess.session_id)
78
+
79
+ for cp in sample.normalized_checkpoints:
80
+ covered = cp.covered_sessions
81
+ if not covered:
82
+ continue
83
+ missing = [sid for sid in covered if sid not in session_order]
84
+ if missing:
85
+ raise ValueError(
86
+ f"checkpoint {cp.checkpoint_id!r} references unknown sessions: {missing}"
87
+ )
88
+ if not set(covered).issubset(completed_sessions):
89
+ continue
90
+ trigger_session_id = max(covered, key=session_order.__getitem__)
91
+ if sess.session_id != trigger_session_id:
92
+ continue
93
+ qa_out.extend(
94
+ run_checkpoint_qa_records(
95
+ adapter,
96
+ sample_id=sample.sample_id,
97
+ sample_uuid=sample.uuid,
98
+ checkpoint=cp,
99
+ top_k=top_k,
100
+ answer_fn=answer_fn,
101
+ )
102
+ )
103
+
104
+ return tuple(session_out), tuple(qa_out)