| """Adapter for SimpleMem and Omni-SimpleMem baselines.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| import sys |
| from pathlib import Path |
| from typing import Any |
|
|
| from eval_framework.datasets.schemas import ( |
| MemoryDeltaRecord, |
| MemorySnapshotRecord, |
| NormalizedTurn, |
| RetrievalItem, |
| RetrievalRecord, |
| ) |
| from eval_framework.memory_adapters.base import MemoryAdapter |
|
|
# Fallback SimpleMem source directory; SimpleMemAdapter puts it on ``sys.path``
# when it is constructed without an explicit ``source_root``.
# NOTE(review): machine-specific absolute path -- confirm it exists on the target host.
| _DEFAULT_SOURCE = Path("/data1/toby/nips26/baselines/SimpleMem") |
|
|
|
|
class SimpleMemAdapter(MemoryAdapter):
    """Adapter for SimpleMem (text mode) or Omni-SimpleMem (omni mode).

    The backend object is created through ``simplemem_router.create``. In
    addition to writing into the backend, the adapter keeps its own list of
    every ingested text (``_stored_texts``); snapshots, deltas and the lexical
    retrieval fallback are all served from that local list rather than from
    the backend.

    Any ``mode`` other than ``"omni"`` is handled through the text-mode code
    path.
    """

    def __init__(
        self,
        *,
        mode: str = "text",
        source_root: str | os.PathLike[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Locate the SimpleMem sources and create a fresh backend instance.

        Args:
            mode: ``"omni"`` selects Omni-SimpleMem; anything else is treated
                as text-mode SimpleMem.
            source_root: Directory containing ``simplemem_router``. Falls back
                to ``_DEFAULT_SOURCE`` when omitted or empty.
            **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ImportError: If ``simplemem_router`` cannot be imported.
        """
        self._mode = mode
        root = str(Path(source_root or _DEFAULT_SOURCE).resolve())
        # The router lives outside the installed packages, so its directory
        # has to be on sys.path before it can be imported.
        if root not in sys.path:
            sys.path.insert(0, root)

        import simplemem_router as simplemem

        self._simplemem = simplemem
        self._mem: Any = None
        self._session_id = ""
        self._prev_snapshot_ids: set[str] = set()
        self._stored_texts: list[dict[str, str]] = []
        self._init_mem()

    @staticmethod
    def _logger() -> Any:
        """Return this module's logger (imported lazily to keep the block self-contained)."""
        import logging

        return logging.getLogger(__name__)

    def _init_mem(self) -> None:
        """Create a new backend (with ``clear_db=True``) and drop the local text log."""
        self._mem = self._simplemem.create(mode=self._mode, clear_db=True)
        self._stored_texts = []

    def reset(self) -> None:
        """Close the current backend (best effort) and start from an empty state."""
        if self._mem is not None:
            try:
                self._mem.close()
            except Exception:
                # Closing is best effort: a failure here must not prevent the
                # re-initialisation below, but it should leave a trace.
                self._logger().debug(
                    "SimpleMem-%s: closing backend failed", self._mode, exc_info=True
                )
        self._init_mem()
        self._prev_snapshot_ids = set()
        self._session_id = ""

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Write one turn into the backend and record it in the local text log.

        The stored text is ``"<role>: <text>"`` followed by one
        ``"[<type>] <caption>"`` line per attachment.
        """
        self._session_id = turn.session_id
        parts = [f"{turn.role}: {turn.text}"]
        parts.extend(f"[{att.type}] {att.caption}" for att in turn.attachments or ())
        text = "\n".join(parts)

        # Local ids are simply the insertion index into the text log.
        memory_id = str(len(self._stored_texts))
        if self._mode == "omni":
            self._mem.add_text(text, tags=[f"session:{turn.session_id}"])
        else:
            speaker = "User" if turn.role == "user" else "Assistant"
            self._mem.add_dialogue(speaker, text, turn.timestamp or "")
        self._stored_texts.append(
            {"id": memory_id, "text": text, "session_id": turn.session_id}
        )

    def end_session(self, session_id: str) -> None:
        """Mark the session as finished; text mode additionally calls ``finalize()``."""
        self._session_id = session_id
        # Same mode split as ingest_turn: everything that is not omni goes
        # through the dialogue API and therefore gets finalized.
        if self._mode != "omni":
            try:
                self._mem.finalize()
            except Exception:
                self._logger().warning(
                    "SimpleMem-%s: finalize() failed for session %r",
                    self._mode,
                    session_id,
                    exc_info=True,
                )

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one active snapshot record per locally logged text."""
        source = f"SimpleMem-{self._mode}"
        return [
            MemorySnapshotRecord(
                memory_id=entry["id"],
                text=entry["text"],
                session_id=entry["session_id"],
                status="active",
                source=source,
                raw_backend_id=entry["id"],
                raw_backend_type="simplemem",
                metadata={},
            )
            for entry in self._stored_texts
        ]

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Return ``add`` deltas for memories that appeared since the previous call.

        The set of ids seen so far is updated on every call, so each memory is
        reported exactly once (until ``reset``).
        """
        snapshot = self.snapshot_memories()
        baseline = f"SimpleMem-{self._mode}"
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=record.text,
                linked_previous=(),
                raw_backend_id=record.raw_backend_id,
                metadata={"baseline": baseline},
            )
            for record in snapshot
            if record.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = {record.memory_id for record in snapshot}
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Retrieve up to ``top_k`` items for ``query``.

        Omni mode calls the backend's ``query``; text mode calls ``ask`` and
        wraps a truthy answer as a single item. If the backend raises or
        yields nothing, a token-overlap ranking over the locally logged texts
        is used instead.

        Args:
            query: Free-text query.
            top_k: Maximum number of items; values ``<= 0`` yield no items.

        Returns:
            A ``RetrievalRecord`` carrying the original ``query`` and ``top_k``.
        """
        trace = {"baseline": f"SimpleMem-{self._mode}"}
        # A negative top_k would otherwise slice items off the *end* of the
        # result lists instead of returning nothing.
        limit = max(top_k, 0)
        if limit == 0:
            return RetrievalRecord(query=query, top_k=top_k, items=[], raw_trace=trace)

        items: list[RetrievalItem] = []
        try:
            if self._mode == "omni":
                result = self._mem.query(query, top_k=limit)
                if isinstance(result, list):
                    for rank, hit in enumerate(result[:limit]):
                        if isinstance(hit, dict):
                            text = hit.get("text", str(hit))
                        else:
                            text = str(hit)
                        # NOTE(review): memory_id is the rank position here, not
                        # an id from the local text log -- the two id spaces can
                        # collide; confirm what downstream consumers expect.
                        items.append(RetrievalItem(
                            rank=rank, memory_id=str(rank), text=text,
                            score=1.0 / (rank + 1), raw_backend_id=None,
                        ))
            else:
                answer = self._mem.ask(query)
                if answer:
                    items.append(RetrievalItem(
                        rank=0, memory_id="answer", text=str(answer),
                        score=1.0, raw_backend_id=None,
                    ))
        except Exception:
            # Retrieval stays best effort, but falling back silently would hide
            # the fact that results no longer come from the backend.
            self._logger().warning(
                "SimpleMem-%s: backend retrieval failed; using lexical fallback",
                self._mode,
                exc_info=True,
            )

        if not items:
            items = self._lexical_fallback(query, limit)

        return RetrievalRecord(
            query=query, top_k=top_k, items=items[:limit], raw_trace=trace,
        )

    def _lexical_fallback(self, query: str, limit: int) -> list[RetrievalItem]:
        """Rank the locally logged texts by whitespace-token overlap with ``query``.

        The score is the number of shared unique lower-cased tokens divided by
        the number of unique query tokens, so a text containing every query
        token scores 1.0. Ties keep insertion order.
        """
        query_tokens = set(query.lower().split())
        denominator = max(len(query_tokens), 1)
        scored = [
            (len(query_tokens & set(entry["text"].lower().split())), entry)
            for entry in self._stored_texts
        ]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [
            RetrievalItem(
                rank=rank, memory_id=entry["id"], text=entry["text"],
                score=float(overlap) / denominator,
                raw_backend_id=entry["id"],
            )
            for rank, (overlap, entry) in enumerate(scored[:limit])
        ]

    def get_capabilities(self) -> dict[str, Any]:
        """Describe the backend: its name, availability, and delta/snapshot granularity."""
        name = "Omni-SimpleMem" if self._mode == "omni" else "SimpleMem"
        return {
            "backend": name,
            "baseline": name,
            "available": self._mem is not None,
            "delta_granularity": "per_turn",
            "snapshot_mode": "full",
        }
|
|