| """Adapter for Zep memory system (community/self-hosted edition).""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| import uuid as _uuid |
| from typing import Any |
|
|
| from eval_framework.datasets.schemas import ( |
| MemoryDeltaRecord, |
| MemorySnapshotRecord, |
| NormalizedTurn, |
| RetrievalItem, |
| RetrievalRecord, |
| ) |
| from eval_framework.memory_adapters.base import MemoryAdapter |
|
|
|
|
class ZepAdapter(MemoryAdapter):
    """Adapter for Zep community edition (self-hosted).

    All turns ingested between two ``reset`` calls are written to a single Zep
    session (``self._thread_id``: ``eval_`` followed by 8 random hex chars).
    The evaluation's own session id is attached to each message as metadata so
    snapshots can attribute messages to the session that produced them.
    """

    # Zep role types form a fixed vocabulary ("user", "assistant", "system",
    # ...); "ai" is not part of it, so non-user turns are sent as "assistant".
    _USER_ROLE_TYPE = "user"
    _AI_ROLE_TYPE = "assistant"

    def __init__(self, *, base_url: str | None = None, **kwargs: Any) -> None:
        """Connect to a Zep server.

        Args:
            base_url: Server URL. Falls back to ``$ZEP_BASE_URL``, then to
                ``http://localhost:8000``.
            **kwargs: Accepted for interface compatibility; ignored.
        """
        # Imported here rather than at module level, so importing this module
        # does not require zep_python to be installed.
        from zep_python import ZepClient

        self._base_url = base_url or os.getenv("ZEP_BASE_URL", "http://localhost:8000")
        self._client = ZepClient(base_url=self._base_url)
        self._session_id = ""
        self._thread_id = self._new_thread_id()
        self._prev_snapshot_ids: set[str] = set()

    @staticmethod
    def _new_thread_id() -> str:
        """Return a fresh, randomly named Zep session id."""
        return f"eval_{_uuid.uuid4().hex[:8]}"

    @staticmethod
    def _field(obj: Any, name: str, default: Any = None) -> Any:
        """Read ``name`` from either a model object or a plain dict.

        Depending on the endpoint, zep-python hands back messages either as
        model objects or as plain dicts (e.g. ``MemorySearchResult.message``).
        A ``None`` object, a missing key/attribute, or a ``None`` value all
        yield ``default``.
        """
        if obj is None:
            return default
        if isinstance(obj, dict):
            value = obj.get(name)
        else:
            value = getattr(obj, name, None)
        return default if value is None else value

    @classmethod
    def _message_id(cls, msg: Any, fallback: int) -> str:
        """Return the message's Zep uuid, or ``str(fallback)`` if it has none.

        Checking the value (not just attribute presence) avoids every
        uuid-less message collapsing to the same id ``"None"``.
        """
        uuid = cls._field(msg, "uuid")
        return str(uuid) if uuid else str(fallback)

    def reset(self) -> None:
        """Discard stored memory and start over with a new Zep session."""
        try:
            self._client.memory.delete_memory(self._thread_id)
        except Exception:
            # Deliberately best-effort: the session may never have been
            # created (no turns ingested yet). Either way, we switch to a
            # brand-new session id below, so stale data is never read back.
            pass
        self._session_id = ""
        self._thread_id = self._new_thread_id()
        self._prev_snapshot_ids = set()

    def ingest_turn(self, turn: NormalizedTurn) -> None:
        """Append one conversational turn to the current Zep session.

        The stored content is ``"<role>: <text>"``, followed by one
        ``"[<type>] <caption>"`` line per attachment.
        """
        from zep_python.memory import Memory
        from zep_python.message import Message

        self._session_id = turn.session_id
        lines = [f"{turn.role}: {turn.text}"]
        lines.extend(f"[{att.type}] {att.caption}" for att in turn.attachments)

        role_type = self._USER_ROLE_TYPE if turn.role == "user" else self._AI_ROLE_TYPE
        msg = Message(
            role=turn.role,
            role_type=role_type,
            content="\n".join(lines),
            # Lets snapshot_memories attribute each message to its session.
            metadata={"session_id": turn.session_id},
        )
        self._client.memory.add_memory(self._thread_id, Memory(messages=[msg]))

    def end_session(self, session_id: str) -> None:
        """Record ``session_id`` as the current session; nothing is sent to Zep."""
        self._session_id = session_id

    def snapshot_memories(self) -> list[MemorySnapshotRecord]:
        """Return one record per message Zep returns for the current session.

        Returns an empty list if the memory cannot be fetched.
        """
        try:
            memory = self._client.memory.get_memory(self._thread_id)
        except Exception:
            # NOTE(review): presumably this mostly covers a session that has
            # not been created yet. Any other failure is also reported as
            # "no memories".
            return []

        rows: list[MemorySnapshotRecord] = []
        for i, msg in enumerate(self._field(memory, "messages") or []):
            mid = self._message_id(msg, i)
            meta = self._field(msg, "metadata")
            origin = meta.get("session_id") if isinstance(meta, dict) else None
            rows.append(
                MemorySnapshotRecord(
                    memory_id=mid,
                    text=self._field(msg, "content", ""),
                    # Messages ingested without metadata fall back to the
                    # most recently seen session, as before.
                    session_id=origin or self._session_id,
                    status="active",
                    source="Zep",
                    raw_backend_id=mid,
                    raw_backend_type="zep_message",
                    metadata={},
                )
            )
        return rows

    def export_memory_delta(self, session_id: str) -> list[MemoryDeltaRecord]:
        """Report the messages that appeared since the previous delta export.

        The delta is computed as a diff of memory ids between consecutive
        snapshots, so every record is an ``"add"``; removals are not reported.
        """
        current = self.snapshot_memories()
        deltas = [
            MemoryDeltaRecord(
                session_id=session_id,
                op="add",
                text=snap.text,
                linked_previous=(),
                raw_backend_id=snap.raw_backend_id,
                metadata={"baseline": "Zep"},
            )
            for snap in current
            if snap.memory_id not in self._prev_snapshot_ids
        ]
        self._prev_snapshot_ids = {snap.memory_id for snap in current}
        return deltas

    def retrieve(self, query: str, top_k: int) -> RetrievalRecord:
        """Search the current session for ``query`` and return up to ``top_k`` hits.

        Ranks are 0-based. The score is Zep's ``dist`` similarity when it is
        present, otherwise ``1 / (rank + 1)``. If the search fails, no items
        are returned and the error is recorded under ``raw_trace["error"]``.
        """
        trace: dict[str, Any] = {"baseline": "Zep"}
        results: list[Any] = []
        if top_k > 0:
            try:
                from zep_python.memory import MemorySearchPayload

                # search_memory takes a payload object, not a bare string.
                payload = MemorySearchPayload(text=query)
                results = list(
                    self._client.memory.search_memory(self._thread_id, payload, limit=top_k)
                    or []
                )
            except Exception as exc:
                # Stay best-effort, but leave a trace instead of silently
                # reporting "no results".
                trace["error"] = repr(exc)

        items: list[RetrievalItem] = []
        for i, res in enumerate(results[:top_k]):
            msg = self._field(res, "message")
            if msg is not None:
                text = self._field(msg, "content", "")
                uuid = self._field(msg, "uuid")
            else:
                # Result without a message (e.g. a summary hit).
                text = self._field(self._field(res, "summary"), "content") or str(res)
                uuid = None
            raw_score = self._field(res, "dist")
            if raw_score is None:
                raw_score = self._field(res, "score")
            items.append(
                RetrievalItem(
                    rank=i,
                    memory_id=str(uuid) if uuid else str(i),
                    text=text,
                    score=float(raw_score) if raw_score is not None else 1.0 / (i + 1),
                    raw_backend_id=str(uuid) if uuid else None,
                )
            )
        return RetrievalRecord(query=query, top_k=top_k, items=items, raw_trace=trace)

    def get_capabilities(self) -> dict[str, Any]:
        """Describe this adapter to the evaluation framework."""
        # NOTE(review): "full" assumes get_memory returns every stored
        # message; confirm against the server's message-window setting.
        return {
            "backend": "Zep",
            "baseline": "Zep",
            "available": True,
            "delta_granularity": "snapshot_diff",
            "snapshot_mode": "full",
        }
|
|