| """Helpers to map turns, backend memory dicts, and recall outputs into shared schemas.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Any, Mapping |
|
|
| from eval_framework.datasets.schemas import ( |
| MemorySnapshotRecord, |
| NormalizedTurn, |
| RetrievalItem, |
| RetrievalRecord, |
| ) |
|
|
|
|
def turn_to_observation_dict(turn: NormalizedTurn) -> dict[str, Any]:
    """Convert a normalized turn into a Mem-Gallery store observation dict.

    The observation text is the turn text followed by one
    ``"[<type>] <caption>"`` line per attachment, joined with newlines.
    ``"timestamp"`` is only included when the turn has a truthy timestamp;
    ``"dialogue_id"`` is always ``"<session_id>:<turn_index>"``.
    """
    lines: list[str] = [turn.text]
    lines.extend(
        f"[{attachment.type}] {attachment.caption}"
        for attachment in turn.attachments
    )

    observation: dict[str, Any] = {"text": "\n".join(lines)}
    if turn.timestamp:
        observation["timestamp"] = turn.timestamp
    observation["dialogue_id"] = f"{turn.session_id}:{turn.turn_index}"
    return observation
|
|
|
|
def memory_element_text(element: Mapping[str, Any]) -> str:
    """Best-effort text extraction from a Mem-Gallery memory dict.

    The ``"text"`` value is interpreted as follows:

    * a list: items are stringified and joined with single spaces;
    * ``None`` or missing: treated as an empty string;
    * anything else: stringified with ``str``.

    If ``"image"`` is a mapping with a truthy ``"caption"``, a
    ``"[image] <caption>"`` line is appended and the combined text is
    stripped of surrounding whitespace. This applies to every text shape,
    including list-valued text.
    """
    raw = element.get("text", "")
    if isinstance(raw, list):
        # Do not return early here: the image caption below must still be
        # appended for list-valued text.
        base = " ".join(str(part) for part in raw)
    elif raw is None:
        base = ""
    else:
        base = str(raw)

    image = element.get("image")
    # Accept any mapping, mirroring the Mapping type accepted for `element`.
    if isinstance(image, Mapping):
        caption = image.get("caption")
        if caption:
            # strip() removes the leading newline when there is no base text.
            base = f"{base}\n[image] {caption}".strip()
    return base
|
|
|
|
def linear_element_to_snapshot(
    element: Mapping[str, Any],
    *,
    memory_id: str,
    session_id: str,
    source: str,
    status: str = "active",
) -> MemorySnapshotRecord:
    """Build a MemorySnapshotRecord from a linear-storage memory dict.

    The record's ``raw_backend_id`` is the stringified ``"counter_id"`` of
    the element when that key holds a non-``None`` value, otherwise it falls
    back to ``memory_id``. The text comes from ``memory_element_text``; the
    backend type is always ``"linear"`` and metadata is an empty dict.
    """
    counter = element.get("counter_id")
    backend_id = memory_id if counter is None else str(counter)
    extracted_text = memory_element_text(element)
    return MemorySnapshotRecord(
        memory_id=memory_id,
        text=extracted_text,
        session_id=session_id,
        status=status,
        source=source,
        raw_backend_id=backend_id,
        raw_backend_type="linear",
        metadata={},
    )
|
|
|
|
def normalize_recall_to_retrieval(
    query: str,
    top_k: int,
    raw: Any,
    *,
    raw_trace: dict[str, Any] | None = None,
) -> RetrievalRecord:
    """Normalize Mem-Gallery recall outputs into RetrievalRecord.

    Handles three shapes of ``raw``:

    * ``str``: wrapped as a single item with id ``"memgallery:string_bundle"``.
    * ``list``: at most ``top_k`` leading rows are converted. Mapping rows use
      their ``"counter_id"`` (when not ``None``) as both ``memory_id`` and
      ``raw_backend_id``, ``memory_element_text`` for the text, and
      ``"score"`` for the score; other rows are stringified and identified
      by their rank.
    * anything else: stringified into a single item with id
      ``"memgallery:object_bundle"``.

    Args:
        query: The query string, stored on the record unchanged.
        top_k: Maximum number of items to keep; values below zero behave
            like zero. The original value is stored on the record.
        raw: The recall output to normalize.
        raw_trace: Optional trace data; a shallow copy is stored on the
            record (an empty dict when omitted or empty).

    Returns:
        A RetrievalRecord whose items are ranked from zero in input order.

    Raises:
        TypeError, ValueError: If a mapping row has a non-``None`` ``"score"``
            that ``float`` cannot convert.
    """
    trace = dict(raw_trace or {})
    # Clamp once so the list slice and the final truncation agree.
    limit = max(0, top_k)
    items: list[RetrievalItem] = []

    if isinstance(raw, str):
        items.append(
            RetrievalItem(
                rank=0,
                memory_id="memgallery:string_bundle",
                text=raw,
                score=1.0,
                raw_backend_id=None,
            )
        )
    elif isinstance(raw, list):
        for rank, row in enumerate(raw[:limit]):
            # Any mapping is accepted, matching what memory_element_text takes.
            if isinstance(row, Mapping):
                counter = row.get("counter_id")
                # A present-but-None score is treated like a missing one;
                # float(None) would raise TypeError.
                raw_score = row.get("score")
                score = 1.0 if raw_score is None else float(raw_score)
                items.append(
                    RetrievalItem(
                        rank=rank,
                        memory_id=str(counter if counter is not None else rank),
                        text=memory_element_text(row),
                        score=score,
                        raw_backend_id=str(counter) if counter is not None else None,
                    )
                )
            else:
                items.append(
                    RetrievalItem(
                        rank=rank,
                        memory_id=str(rank),
                        text=str(row),
                        score=1.0,
                        raw_backend_id=None,
                    )
                )
    else:
        items.append(
            RetrievalItem(
                rank=0,
                memory_id="memgallery:object_bundle",
                text=str(raw),
                score=1.0,
                raw_backend_id=None,
            )
        )

    # The single-item bundle branches must also honor top_k <= 0.
    return RetrievalRecord(query=query, top_k=top_k, items=items[:limit], raw_trace=trace)
|
|