Spaces:
Sleeping
Sleeping
File size: 13,502 Bytes
d8d14f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 |
import json
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import yaml
from pydantic import BaseModel, Field
from swarm_models.tiktoken_wrapper import TikTokenizer
logger = logging.getLogger(__name__)
class MemoryMetadata(BaseModel):
    """Metadata attached to a single memory entry.

    Bug fix: ``timestamp`` and ``message_id`` previously used plain defaults
    (``= time.time()`` / ``= str(uuid.uuid4())``), which are evaluated ONCE at
    class-definition time — every instance created without explicit values
    shared the same timestamp and the same "unique" message id. Using
    ``Field(default_factory=...)`` produces a fresh value per instance.
    """

    # Unix timestamp (seconds) of entry creation; fresh per instance.
    timestamp: Optional[float] = Field(default_factory=time.time)
    role: Optional[str] = None
    agent_name: Optional[str] = None
    session_id: Optional[str] = None
    # 'short_term' or 'long_term' (the manager also uses 'system' for
    # system messages — see MemoryManager.add_memory).
    memory_type: Optional[str] = None
    token_count: Optional[int] = None
    # Unique identifier for this message; fresh UUID per instance.
    message_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4())
    )
class MemoryEntry(BaseModel):
    """Single memory entry pairing free-text content with its metadata."""

    # Raw message text. The model permits None, but create_memory_entry
    # always supplies a string in this file.
    content: Optional[str] = None
    # Provenance/accounting info (timestamp, role, token count, ...).
    metadata: Optional[MemoryMetadata] = None
class MemoryConfig(BaseModel):
    """Configuration for memory manager"""

    # Token budget for short-term memory; archive_threshold is applied
    # against this value in MemoryManager.should_archive.
    max_short_term_tokens: Optional[int] = 4096
    # NOTE(review): not referenced anywhere in this file — presumably a cap
    # on entry count intended for callers; confirm before relying on it.
    max_entries: Optional[int] = None
    # NOTE(review): not referenced in this file; presumably tokens reserved
    # for system messages — confirm against callers.
    system_messages_token_buffer: Optional[int] = 1000
    # NOTE(review): not referenced in this file; archiving is instead gated
    # on auto_archive and the presence of a long_term_memory backend.
    enable_long_term_memory: Optional[bool] = False
    # When True, add_memory triggers archiving once usage crosses the
    # threshold below.
    auto_archive: Optional[bool] = True
    archive_threshold: Optional[float] = 0.8  # Archive when 80% full
class MemoryManager:
    """
    Manages both short-term and long-term memory for an agent, handling token
    limits, archival, and context retrieval.

    Args:
        config (MemoryConfig): Configuration for memory management.
        tokenizer (Optional[Any]): Tokenizer used for token counting; defaults
            to TikTokenizer.
        long_term_memory (Optional[Any]): Vector store or database for
            long-term storage; expected to expose ``add(str)`` and
            ``query(str)``.
    """

    def __init__(
        self,
        config: MemoryConfig,
        tokenizer: Optional[Any] = None,
        long_term_memory: Optional[Any] = None,
    ):
        self.config = config
        self.tokenizer = tokenizer or TikTokenizer()
        self.long_term_memory = long_term_memory

        # Working memories: rolling conversation plus pinned system messages.
        self.short_term_memory: List[MemoryEntry] = []
        self.system_messages: List[MemoryEntry] = []

        # Lifetime statistics.
        self.total_tokens_processed: int = 0
        self.archived_entries_count: int = 0

    def create_memory_entry(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        memory_type: str = "short_term",
    ) -> MemoryEntry:
        """Create a new memory entry with freshly computed metadata.

        Args:
            content (str): Raw message text.
            role (str): Role of the speaker (e.g. "user", "assistant").
            agent_name (str): Name of the agent that produced the message.
            session_id (str): Conversation/session identifier.
            memory_type (str): "short_term", "system", or "long_term".

        Returns:
            MemoryEntry: Entry with timestamp and token count populated.
        """
        metadata = MemoryMetadata(
            timestamp=time.time(),
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type=memory_type,
            token_count=self.tokenizer.count_tokens(content),
        )
        return MemoryEntry(content=content, metadata=metadata)

    def add_memory(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        is_system: bool = False,
    ) -> None:
        """Add a new memory entry to the appropriate storage.

        System messages go to ``system_messages`` (never archived);
        everything else goes to ``short_term_memory`` and may trigger
        auto-archiving.
        """
        entry = self.create_memory_entry(
            content=content,
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type="system" if is_system else "short_term",
        )

        if is_system:
            self.system_messages.append(entry)
        else:
            self.short_term_memory.append(entry)

        # Archive oldest entries if the token budget is nearly exhausted.
        if self.should_archive():
            self.archive_old_memories()

        self.total_tokens_processed += entry.metadata.token_count

    def get_current_token_count(self) -> int:
        """Get total tokens in short-term memory (excludes system messages)."""
        return sum(
            entry.metadata.token_count
            for entry in self.short_term_memory
        )

    def get_system_messages_token_count(self) -> int:
        """Get total tokens in system messages."""
        return sum(
            entry.metadata.token_count
            for entry in self.system_messages
        )

    def should_archive(self) -> bool:
        """Check whether auto-archiving should run.

        Returns False when auto-archive is disabled, or when no positive
        token budget is configured — ``max_short_term_tokens`` is
        Optional[int], and dividing by None/0 would raise.
        """
        if not self.config.auto_archive:
            return False

        max_tokens = self.config.max_short_term_tokens
        if not max_tokens:
            # Robustness fix: avoid TypeError/ZeroDivisionError on a
            # None or zero budget.
            return False

        current_usage = self.get_current_token_count() / max_tokens
        return current_usage >= self.config.archive_threshold

    def archive_old_memories(self) -> None:
        """Move older memories to long-term storage until usage drops below
        the archive threshold.

        Stops at the first storage failure: a failed store re-inserts the
        entry at the front of short-term memory, so ``should_archive()``
        would stay True and the original loop retried the same failing
        entry forever.
        """
        if not self.long_term_memory:
            logger.warning(
                "No long-term memory storage configured for archiving"
            )
            return

        while self.should_archive():
            if not self.short_term_memory:
                break
            # Oldest non-system message is at the front.
            oldest_entry = self.short_term_memory.pop(0)
            if not self.store_in_long_term_memory(oldest_entry):
                # Bug fix: break instead of spinning on a failing backend.
                break
            self.archived_entries_count += 1

    def store_in_long_term_memory(self, entry: MemoryEntry) -> bool:
        """Store a memory entry in long-term memory.

        Returns:
            bool: True on success; False if no backend is configured or the
            backend raised (the entry is then restored to the front of
            short-term memory so it is not lost). Returning a bool is
            backward-compatible — the original returned None and no caller
            used the result.
        """
        if self.long_term_memory is None:
            logger.warning(
                "Attempted to store in non-existent long-term memory"
            )
            return False

        try:
            self.long_term_memory.add(str(entry.model_dump()))
            return True
        except Exception as e:
            logger.error(f"Error storing in long-term memory: {e}")
            # Re-add to short-term if storage fails.
            self.short_term_memory.insert(0, entry)
            return False

    def get_relevant_context(
        self, query: str, max_tokens: Optional[int] = None
    ) -> str:
        """
        Get relevant context from both memory types.

        Args:
            query (str): Query to match against memories. Note: only the
                long-term store is actually queried with it; system and
                short-term messages are always included in full.
            max_tokens (Optional[int]): Maximum tokens to return.

        Returns:
            str: Combined relevant context, newline-joined.
        """
        contexts: List[str] = []

        # System messages first.
        contexts.extend(entry.content for entry in self.system_messages)

        # Short-term memory, newest first.
        contexts.extend(
            entry.content for entry in reversed(self.short_term_memory)
        )

        # Query long-term memory if available.
        if self.long_term_memory is not None:
            long_term_context = self.long_term_memory.query(query)
            if long_term_context:
                contexts.append(str(long_term_context))

        combined = "\n".join(contexts)
        if max_tokens:
            combined = self.truncate_to_token_limit(combined, max_tokens)
        return combined

    def truncate_to_token_limit(
        self, text: str, max_tokens: int
    ) -> str:
        """Truncate text to fit within a token limit.

        Splits on ". " and keeps whole sentences while they fit. NOTE: the
        ". " separators re-added by the final join are not counted, so the
        result may slightly exceed ``max_tokens`` — this preserves the
        original approximate behavior.
        """
        current_tokens = self.tokenizer.count_tokens(text)
        if current_tokens <= max_tokens:
            return text

        sentences = text.split(". ")
        kept: List[str] = []
        used = 0
        for sentence in sentences:
            sentence_tokens = self.tokenizer.count_tokens(sentence)
            if used + sentence_tokens > max_tokens:
                break
            kept.append(sentence)
            used += sentence_tokens
        return ". ".join(kept)

    def clear_short_term_memory(
        self, preserve_system: bool = True
    ) -> None:
        """Clear short-term memory, optionally preserving system messages."""
        if not preserve_system:
            self.system_messages.clear()
        self.short_term_memory.clear()

        # Bug fix: the original wrote
        #   "Cleared..." + " (preserved...)" if preserve_system else ""
        # `+` binds tighter than the conditional expression, so with
        # preserve_system=False the whole message collapsed to "" and an
        # empty line was logged.
        suffix = " (preserved system messages)" if preserve_system else ""
        logger.info("Cleared short-term memory" + suffix)

    def get_memory_stats(self) -> Dict[str, Any]:
        """Get detailed memory statistics as a plain dict."""
        max_tokens = self.config.max_short_term_tokens
        current = self.get_current_token_count()
        # Robustness fix: avoid dividing by a None/zero budget.
        usage_percent = (
            round(current / max_tokens * 100, 2) if max_tokens else 0.0
        )
        return {
            "short_term_messages": len(self.short_term_memory),
            "system_messages": len(self.system_messages),
            "current_tokens": current,
            "system_tokens": self.get_system_messages_token_count(),
            "max_tokens": max_tokens,
            "token_usage_percent": usage_percent,
            "has_long_term_memory": self.long_term_memory is not None,
            "archived_entries": self.archived_entries_count,
            "total_tokens_processed": self.total_tokens_processed,
        }

    def save_memory_snapshot(self, file_path: str) -> None:
        """Save current memory state to a file.

        Writes YAML when the path ends in ".yaml", JSON otherwise.

        Raises:
            Exception: Re-raises any error after logging it.
        """
        try:
            data = {
                "timestamp": datetime.now().isoformat(),
                "config": self.config.model_dump(),
                "system_messages": [
                    entry.model_dump()
                    for entry in self.system_messages
                ],
                "short_term_memory": [
                    entry.model_dump()
                    for entry in self.short_term_memory
                ],
                "stats": self.get_memory_stats(),
            }

            with open(file_path, "w") as f:
                if file_path.endswith(".yaml"):
                    yaml.dump(data, f)
                else:
                    json.dump(data, f, indent=2)

            logger.info(f"Saved memory snapshot to {file_path}")

        except Exception as e:
            logger.error(f"Error saving memory snapshot: {e}")
            raise

    def load_memory_snapshot(self, file_path: str) -> None:
        """Load memory state from a file written by save_memory_snapshot.

        Replaces config, system messages, and short-term memory in place.

        Raises:
            Exception: Re-raises any error after logging it.
        """
        try:
            with open(file_path, "r") as f:
                if file_path.endswith(".yaml"):
                    data = yaml.safe_load(f)
                else:
                    data = json.load(f)

            self.config = MemoryConfig(**data["config"])
            self.system_messages = [
                MemoryEntry(**entry)
                for entry in data["system_messages"]
            ]
            self.short_term_memory = [
                MemoryEntry(**entry)
                for entry in data["short_term_memory"]
            ]

            logger.info(f"Loaded memory snapshot from {file_path}")

        except Exception as e:
            logger.error(f"Error loading memory snapshot: {e}")
            raise

    def search_memories(
        self, query: str, memory_type: str = "all"
    ) -> List[MemoryEntry]:
        """
        Search through memories of the specified type.

        Args:
            query (str): Search query (case-insensitive substring match for
                short-term/system; forwarded to the long-term backend).
            memory_type (str): "short_term", "system", "long_term", or "all".

        Returns:
            List[MemoryEntry]: Matching memory entries.
        """
        results: List[MemoryEntry] = []
        # Hoisted out of the per-entry checks (was recomputed every entry);
        # entries with content=None are skipped instead of raising.
        needle = query.lower()

        if memory_type in ("short_term", "all"):
            results.extend(
                entry
                for entry in self.short_term_memory
                if entry.content and needle in entry.content.lower()
            )

        if memory_type in ("system", "all"):
            results.extend(
                entry
                for entry in self.system_messages
                if entry.content and needle in entry.content.lower()
            )

        if (
            memory_type in ("long_term", "all")
            and self.long_term_memory is not None
        ):
            long_term_results = self.long_term_memory.query(query)
            if long_term_results:
                # Wrap raw backend results in MemoryEntry for a uniform
                # return type.
                for result in long_term_results:
                    content = str(result)
                    metadata = MemoryMetadata(
                        timestamp=time.time(),
                        role="long_term",
                        agent_name="system",
                        session_id="long_term",
                        memory_type="long_term",
                        token_count=self.tokenizer.count_tokens(
                            content
                        ),
                    )
                    results.append(
                        MemoryEntry(
                            content=content, metadata=metadata
                        )
                    )

        return results

    def get_memory_by_timeframe(
        self, start_time: float, end_time: float
    ) -> List[MemoryEntry]:
        """Get short-term memories whose timestamp falls within
        [start_time, end_time] (inclusive). System messages and long-term
        storage are not searched.
        """
        return [
            entry
            for entry in self.short_term_memory
            if start_time <= entry.metadata.timestamp <= end_time
        ]

    def export_memories(
        self, file_path: str, format: str = "json"
    ) -> None:
        """Export memories to a file in the specified format.

        Args:
            file_path (str): Destination path.
            format (str): "yaml" for YAML output; anything else writes JSON.
        """
        data = {
            "system_messages": [
                entry.model_dump() for entry in self.system_messages
            ],
            "short_term_memory": [
                entry.model_dump() for entry in self.short_term_memory
            ],
            "stats": self.get_memory_stats(),
        }

        with open(file_path, "w") as f:
            if format == "yaml":
                yaml.dump(data, f)
            else:
                json.dump(data, f, indent=2)
|