import json
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import yaml
from pydantic import BaseModel, Field
from swarm_models.tiktoken_wrapper import TikTokenizer

logger = logging.getLogger(__name__)

# File suffixes that select YAML (instead of JSON) for snapshots.
_YAML_SUFFIXES = (".yaml", ".yml")


class MemoryMetadata(BaseModel):
    """Metadata for memory entries.

    ``timestamp`` and ``message_id`` are produced by default factories so
    that every instance gets its own value. A plain default expression
    would be evaluated once, when the class body runs, and then shared by
    all entries created without explicit values.
    """

    timestamp: Optional[float] = Field(default_factory=time.time)
    role: Optional[str] = None
    agent_name: Optional[str] = None
    session_id: Optional[str] = None
    # 'short_term', 'long_term' or 'system'
    memory_type: Optional[str] = None
    token_count: Optional[int] = None
    message_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4())
    )


class MemoryEntry(BaseModel):
    """Single memory entry with content and metadata"""

    content: Optional[str] = None
    metadata: Optional[MemoryMetadata] = None


class MemoryConfig(BaseModel):
    """Configuration for memory manager"""

    max_short_term_tokens: Optional[int] = 4096
    max_entries: Optional[int] = None
    system_messages_token_buffer: Optional[int] = 1000
    enable_long_term_memory: Optional[bool] = False
    auto_archive: Optional[bool] = True
    archive_threshold: Optional[float] = 0.8  # Archive when 80% full


class MemoryManager:
    """
    Manages both short-term and long-term memory for an agent,
    handling token limits, archival, and context retrieval.

    Args:
        config (MemoryConfig): Configuration for memory management
        tokenizer (Optional[Any]): Tokenizer to use for token counting;
            must provide ``count_tokens(text)``. Defaults to ``TikTokenizer()``.
        long_term_memory (Optional[Any]): Vector store or database for
            long-term storage; must provide ``add(text)`` and ``query(text)``.
    """

    def __init__(
        self,
        config: MemoryConfig,
        tokenizer: Optional[Any] = None,
        long_term_memory: Optional[Any] = None,
    ):
        self.config = config
        self.tokenizer = tokenizer or TikTokenizer()
        self.long_term_memory = long_term_memory

        # Initialize memories
        self.short_term_memory: List[MemoryEntry] = []
        self.system_messages: List[MemoryEntry] = []

        # Memory statistics
        self.total_tokens_processed: int = 0
        self.archived_entries_count: int = 0

    @staticmethod
    def _entry_tokens(entry: MemoryEntry) -> int:
        """Return the entry's token count, treating missing values as 0."""
        if entry.metadata is None:
            return 0
        return entry.metadata.token_count or 0

    @staticmethod
    def _is_yaml_path(file_path: str) -> bool:
        """Return True when the file suffix selects the YAML format."""
        return file_path.endswith(_YAML_SUFFIXES)

    def create_memory_entry(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        memory_type: str = "short_term",
    ) -> MemoryEntry:
        """Create a new memory entry with metadata.

        The token count is computed from ``content`` with the configured
        tokenizer; the timestamp is the current time.
        """
        metadata = MemoryMetadata(
            timestamp=time.time(),
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type=memory_type,
            token_count=self.tokenizer.count_tokens(content),
        )
        return MemoryEntry(content=content, metadata=metadata)

    def add_memory(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        is_system: bool = False,
    ) -> None:
        """Add a new memory entry to appropriate storage.

        System entries go to ``system_messages``; everything else goes to
        ``short_term_memory``. Archiving is triggered afterwards when the
        configured threshold is reached.
        """
        entry = self.create_memory_entry(
            content=content,
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type="system" if is_system else "short_term",
        )

        if is_system:
            self.system_messages.append(entry)
        else:
            self.short_term_memory.append(entry)

        # Check if archiving is needed
        if self.should_archive():
            self.archive_old_memories()

        self.total_tokens_processed += self._entry_tokens(entry)

    def get_current_token_count(self) -> int:
        """Get total tokens in short-term memory"""
        return sum(
            self._entry_tokens(entry) for entry in self.short_term_memory
        )

    def get_system_messages_token_count(self) -> int:
        """Get total tokens in system messages"""
        return sum(
            self._entry_tokens(entry) for entry in self.system_messages
        )

    def should_archive(self) -> bool:
        """Check if archiving is needed based on configuration.

        Returns False when auto-archiving is disabled, or when the token
        limit or threshold is unset (or the limit is zero), since no usage
        ratio can be computed in that case.
        """
        if not self.config.auto_archive:
            return False

        max_tokens = self.config.max_short_term_tokens
        threshold = self.config.archive_threshold
        if not max_tokens or threshold is None:
            return False

        current_usage = self.get_current_token_count() / max_tokens
        return current_usage >= threshold

    def archive_old_memories(self) -> None:
        """Move older memories to long-term storage.

        Entries are taken from the front of short-term memory until usage
        drops below the archive threshold. The loop stops at the first
        entry that cannot be stored: a failed store puts the entry back at
        the front, so continuing would retry the same entry forever.
        """
        if self.long_term_memory is None:
            logger.warning(
                "No long-term memory storage configured for archiving"
            )
            return

        while self.short_term_memory and self.should_archive():
            # Get oldest non-system message
            oldest_entry = self.short_term_memory.pop(0)

            # Store in long-term memory; only count successful archives
            if not self.store_in_long_term_memory(oldest_entry):
                break
            self.archived_entries_count += 1

    def store_in_long_term_memory(self, entry: MemoryEntry) -> bool:
        """Store a memory entry in long-term memory.

        Returns:
            bool: True if the entry was stored. False if no long-term
            memory is configured, or if storing raised; in the latter case
            the entry is re-inserted at the front of short-term memory.
        """
        if self.long_term_memory is None:
            logger.warning(
                "Attempted to store in non-existent long-term memory"
            )
            return False

        try:
            self.long_term_memory.add(str(entry.model_dump()))
        except Exception as e:
            logger.error(f"Error storing in long-term memory: {e}")
            # Re-add to short-term if storage fails
            self.short_term_memory.insert(0, entry)
            return False
        return True

    def get_relevant_context(
        self, query: str, max_tokens: Optional[int] = None
    ) -> str:
        """
        Get relevant context from both memory types

        System messages come first, then short-term entries from newest
        to oldest, then the long-term query result (if any).

        Args:
            query (str): Query to match against memories
            max_tokens (Optional[int]): Maximum tokens to return

        Returns:
            str: Combined relevant context
        """
        contexts = []

        # Add system messages first
        for entry in self.system_messages:
            contexts.append(entry.content)

        # Add short-term memory
        for entry in reversed(self.short_term_memory):
            contexts.append(entry.content)

        # Query long-term memory if available
        if self.long_term_memory is not None:
            long_term_context = self.long_term_memory.query(query)
            if long_term_context:
                contexts.append(str(long_term_context))

        # Combine and truncate if needed
        combined = "\n".join(contexts)
        if max_tokens:
            combined = self.truncate_to_token_limit(
                combined, max_tokens
            )

        return combined

    def truncate_to_token_limit(
        self, text: str, max_tokens: int
    ) -> str:
        """Truncate text to fit within token limit.

        Text already within the limit is returned unchanged. Otherwise the
        text is split on ``". "`` and leading sentences are kept while the
        running sum of their individual token counts fits. Tokens of the
        re-inserted separators are not counted.
        """
        current_tokens = self.tokenizer.count_tokens(text)

        if current_tokens <= max_tokens:
            return text

        # Truncate by splitting into sentences and rebuilding
        sentences = text.split(". ")
        result = []
        current_count = 0

        for sentence in sentences:
            sentence_tokens = self.tokenizer.count_tokens(sentence)
            if current_count + sentence_tokens <= max_tokens:
                result.append(sentence)
                current_count += sentence_tokens
            else:
                break

        return ". ".join(result)

    def clear_short_term_memory(
        self, preserve_system: bool = True
    ) -> None:
        """Clear short-term memory with option to preserve system messages"""
        if not preserve_system:
            self.system_messages.clear()
        self.short_term_memory.clear()

        # Build the suffix separately: a conditional expression binds
        # looser than "+", so doing this inline would log an empty message
        # whenever preserve_system is False.
        suffix = (
            " (preserved system messages)" if preserve_system else ""
        )
        logger.info("Cleared short-term memory" + suffix)

    def get_memory_stats(self) -> Dict[str, Any]:
        """Get detailed memory statistics.

        ``token_usage_percent`` is 0.0 when no positive token limit is
        configured.
        """
        current_tokens = self.get_current_token_count()
        max_tokens = self.config.max_short_term_tokens
        if max_tokens:
            usage_percent = round(
                (current_tokens / max_tokens) * 100, 2
            )
        else:
            usage_percent = 0.0

        return {
            "short_term_messages": len(self.short_term_memory),
            "system_messages": len(self.system_messages),
            "current_tokens": current_tokens,
            "system_tokens": self.get_system_messages_token_count(),
            "max_tokens": max_tokens,
            "token_usage_percent": usage_percent,
            "has_long_term_memory": self.long_term_memory is not None,
            "archived_entries": self.archived_entries_count,
            "total_tokens_processed": self.total_tokens_processed,
        }

    def save_memory_snapshot(self, file_path: str) -> None:
        """Save current memory state to file.

        Writes YAML when ``file_path`` ends in ``.yaml`` or ``.yml``, JSON
        otherwise. Errors are logged and re-raised.
        """
        try:
            data = {
                "timestamp": datetime.now().isoformat(),
                "config": self.config.model_dump(),
                "system_messages": [
                    entry.model_dump()
                    for entry in self.system_messages
                ],
                "short_term_memory": [
                    entry.model_dump()
                    for entry in self.short_term_memory
                ],
                "stats": self.get_memory_stats(),
            }

            with open(file_path, "w", encoding="utf-8") as f:
                if self._is_yaml_path(file_path):
                    yaml.dump(data, f)
                else:
                    json.dump(data, f, indent=2)

            logger.info(f"Saved memory snapshot to {file_path}")

        except Exception as e:
            logger.error(f"Error saving memory snapshot: {e}")
            raise

    def load_memory_snapshot(self, file_path: str) -> None:
        """Load memory state from file.

        Replaces the config, system messages and short-term memory with
        the file's contents. Reads YAML when ``file_path`` ends in
        ``.yaml`` or ``.yml``, JSON otherwise. Errors are logged and
        re-raised.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                if self._is_yaml_path(file_path):
                    data = yaml.safe_load(f)
                else:
                    data = json.load(f)

            self.config = MemoryConfig(**data["config"])
            self.system_messages = [
                MemoryEntry(**entry)
                for entry in data["system_messages"]
            ]
            self.short_term_memory = [
                MemoryEntry(**entry)
                for entry in data["short_term_memory"]
            ]

            logger.info(f"Loaded memory snapshot from {file_path}")

        except Exception as e:
            logger.error(f"Error loading memory snapshot: {e}")
            raise

    def search_memories(
        self, query: str, memory_type: str = "all"
    ) -> List[MemoryEntry]:
        """
        Search through memories of specified type

        Short-term and system entries match on a case-insensitive
        substring test; long-term memory is searched through its own
        ``query`` method.

        Args:
            query (str): Search query
            memory_type (str): Type of memories to search
                ("short_term", "system", "long_term", or "all")

        Returns:
            List[MemoryEntry]: Matching memory entries
        """
        results: List[MemoryEntry] = []
        needle = query.lower()

        def matches(entry: MemoryEntry) -> bool:
            # Entries without content (possible after loading a snapshot)
            # can never match.
            return bool(entry.content) and needle in entry.content.lower()

        if memory_type in ["short_term", "all"]:
            results.extend(
                entry
                for entry in self.short_term_memory
                if matches(entry)
            )

        if memory_type in ["system", "all"]:
            results.extend(
                entry
                for entry in self.system_messages
                if matches(entry)
            )

        if (
            memory_type in ["long_term", "all"]
            and self.long_term_memory is not None
        ):
            long_term_results = self.long_term_memory.query(query)
            if long_term_results:
                # A plain string is one result; iterating it would yield
                # one entry per character.
                if isinstance(long_term_results, str):
                    long_term_results = [long_term_results]

                # Convert long-term results to MemoryEntry format
                for result in long_term_results:
                    content = str(result)
                    metadata = MemoryMetadata(
                        timestamp=time.time(),
                        role="long_term",
                        agent_name="system",
                        session_id="long_term",
                        memory_type="long_term",
                        token_count=self.tokenizer.count_tokens(
                            content
                        ),
                    )
                    results.append(
                        MemoryEntry(
                            content=content, metadata=metadata
                        )
                    )

        return results

    def get_memory_by_timeframe(
        self, start_time: float, end_time: float
    ) -> List[MemoryEntry]:
        """Get memories within a specific timeframe.

        Only short-term memory is searched; both bounds are inclusive.
        Entries without metadata or without a timestamp are skipped.
        """
        return [
            entry
            for entry in self.short_term_memory
            if entry.metadata is not None
            and entry.metadata.timestamp is not None
            and start_time <= entry.metadata.timestamp <= end_time
        ]

    def export_memories(
        self, file_path: str, format: str = "json"
    ) -> None:
        """Export memories to file in specified format.

        Args:
            file_path (str): Destination file
            format (str): "yaml" writes YAML; any other value writes JSON
        """
        data = {
            "system_messages": [
                entry.model_dump() for entry in self.system_messages
            ],
            "short_term_memory": [
                entry.model_dump() for entry in self.short_term_memory
            ],
            "stats": self.get_memory_stats(),
        }

        with open(file_path, "w", encoding="utf-8") as f:
            if format == "yaml":
                yaml.dump(data, f)
            else:
                json.dump(data, f, indent=2)