Spaces:
Running
Running
| from __future__ import annotations | |
| """Context memory management for the agent.""" | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from typing import Any | |
@dataclass
class MemoryEntry:
    """A single entry in short-term memory.

    The ``@dataclass`` decorator is required: it generates the keyword
    constructor used by ``ContextMemory.store`` and turns
    ``field(default_factory=datetime.now)`` into a per-instance timestamp.
    """

    key: str
    value: Any
    # Creation time (naive local time from datetime.now); used to pick
    # which entries to evict first.
    timestamp: datetime = field(default_factory=datetime.now)
    source: str = "unknown"
    relevance: float = 1.0
class ContextMemory:
    """Manages context and working memory for the agent.

    Three independent stores are kept:

    * short-term memory: ``MemoryEntry`` records keyed by name, capped at
      ``max_entries`` (the oldest entries are evicted first);
    * working memory: a plain key/value scratchpad with no size cap;
    * conversation history: an append-only list of role/content turns.
      NOTE: this list is not bounded by ``max_entries``.
    """

    def __init__(self, max_entries: int = 100):
        """Initialize memory.

        Args:
            max_entries: Maximum number of short-term entries to keep. When a
                ``store`` pushes the count above this, roughly the oldest 10%
                (at least one) are evicted.
        """
        self.max_entries = max_entries
        self._short_term: dict[str, MemoryEntry] = {}
        self._working: dict[str, Any] = {}
        self._conversation: list[dict[str, str]] = []

    def store(self, key: str, value: Any, source: str = "agent") -> None:
        """Store a value in short-term memory.

        Re-storing an existing key replaces its entry and refreshes its
        timestamp.

        Args:
            key: Memory key
            value: Value to store
            source: Source of the information
        """
        # Remove any previous entry first so the dict's insertion order always
        # reflects recency; _trim_oldest relies on that order to break ties
        # between entries whose timestamps compare equal.
        self._short_term.pop(key, None)
        self._short_term[key] = MemoryEntry(key=key, value=value, source=source)

        # Trim if over capacity
        if len(self._short_term) > self.max_entries:
            self._trim_oldest()

    def retrieve(self, key: str) -> Any | None:
        """Retrieve a value from short-term memory.

        Args:
            key: Memory key

        Returns:
            Stored value, or None if the key is absent (a stored value of
            None is indistinguishable from a missing key).
        """
        entry = self._short_term.get(key)
        return entry.value if entry is not None else None

    def update_working(self, key: str, value: Any) -> None:
        """Set a value in working memory, overwriting any previous value.

        Args:
            key: Memory key
            value: Value to store
        """
        self._working[key] = value

    def get_working(self, key: str, default: Any = None) -> Any:
        """Get a value from working memory.

        Args:
            key: Memory key
            default: Value returned if the key is absent

        Returns:
            Stored value or ``default``
        """
        return self._working.get(key, default)

    def add_conversation_turn(self, role: str, content: str) -> None:
        """Append a turn to the conversation history.

        Args:
            role: Message role (e.g. user/assistant)
            content: Message content
        """
        self._conversation.append({
            "role": role,
            "content": content,
            # Stored as an ISO-8601 string so every turn value is a str.
            "timestamp": datetime.now().isoformat(),
        })

    def get_conversation_history(self, limit: int = 10) -> list[dict[str, str]]:
        """Get the most recent conversation turns, oldest first.

        Args:
            limit: Maximum number of turns to return; values <= 0 return
                an empty list.

        Returns:
            List of conversation turns (a new list; the turn dicts are shared
            with internal state).
        """
        # Guard explicitly: a slice of [-0:] would return the whole history,
        # and a negative limit would silently drop turns from the front.
        if limit <= 0:
            return []
        return self._conversation[-limit:]

    def get_context_summary(self) -> dict[str, Any]:
        """Get a summary of the current context.

        Returns:
            Dictionary with the short-term keys, working-memory keys and the
            number of conversation turns.
        """
        return {
            "short_term_keys": list(self._short_term.keys()),
            "working_memory_keys": list(self._working.keys()),
            "conversation_length": len(self._conversation),
        }

    def clear_working(self) -> None:
        """Clear working memory only."""
        self._working.clear()

    def clear_all(self) -> None:
        """Clear short-term memory, working memory and conversation history."""
        self._short_term.clear()
        self._working.clear()
        self._conversation.clear()

    def _trim_oldest(self) -> None:
        """Evict the oldest ~10% (at least one) of short-term entries."""
        if not self._short_term:
            return

        # sorted() is stable, so entries with equal timestamps keep their
        # dict insertion order, which store() keeps in recency order.
        oldest_first = sorted(
            self._short_term,
            key=lambda k: self._short_term[k].timestamp,
        )
        to_remove = max(1, len(oldest_first) // 10)
        for key in oldest_first[:to_remove]:
            del self._short_term[key]

    def search(self, query: str) -> list[MemoryEntry]:
        """Search short-term memory with case-insensitive substring matching.

        An entry matches when ``query`` occurs in its key or in ``str()`` of
        its value. An empty query therefore matches every entry.

        Args:
            query: Search query

        Returns:
            Matching entries, newest first.
        """
        query_lower = query.lower()
        results = [
            entry
            for entry in self._short_term.values()
            if query_lower in str(entry.value).lower()
            or query_lower in entry.key.lower()
        ]
        # Sort by relevance (for now, just by timestamp)
        results.sort(key=lambda e: e.timestamp, reverse=True)
        return results