File size: 13,502 Bytes
d8d14f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
import json
import logging
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

import yaml
from pydantic import BaseModel, Field
from swarm_models.tiktoken_wrapper import TikTokenizer

logger = logging.getLogger(__name__)


class MemoryMetadata(BaseModel):
    """Metadata attached to a single memory entry.

    Uses ``default_factory`` for ``timestamp`` and ``message_id`` so each
    instance gets its own creation time and a unique id.  The previous
    defaults (``time.time()`` / ``str(uuid.uuid4())``) were evaluated once
    at class-definition time, so every entry relying on the default shared
    the same timestamp and the same "unique" message id.
    """

    timestamp: Optional[float] = Field(default_factory=time.time)
    role: Optional[str] = None
    agent_name: Optional[str] = None
    session_id: Optional[str] = None
    memory_type: Optional[str] = None  # 'short_term', 'long_term', or 'system'
    token_count: Optional[int] = None
    message_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4())
    )


class MemoryEntry(BaseModel):
    """Single memory entry with content and metadata"""

    content: Optional[str] = None  # raw text of the memory
    metadata: Optional[MemoryMetadata] = None  # provenance / token accounting


class MemoryConfig(BaseModel):
    """Configuration for memory manager"""

    max_short_term_tokens: Optional[int] = 4096  # token budget for short-term memory
    max_entries: Optional[int] = None  # optional cap on entry count (unbounded if None)
    system_messages_token_buffer: Optional[int] = 1000  # tokens reserved for system messages
    enable_long_term_memory: Optional[bool] = False  # requires a long-term store to be wired in
    auto_archive: Optional[bool] = True  # archive automatically when threshold is crossed
    archive_threshold: Optional[float] = 0.8  # Archive when 80% full


class MemoryManager:
    """
    Manages both short-term and long-term memory for an agent, handling token limits,
    archival, and context retrieval.

    Short-term memory is an in-process list of ``MemoryEntry`` objects bounded
    by ``config.max_short_term_tokens``.  When usage crosses
    ``config.archive_threshold`` (and ``auto_archive`` is on), the oldest
    entries are moved to the optional long-term store, which must expose
    ``add(str)`` and ``query(str)``.

    Args:
        config (MemoryConfig): Configuration for memory management
        tokenizer (Optional[Any]): Tokenizer to use for token counting
        long_term_memory (Optional[Any]): Vector store or database for long-term storage
    """

    def __init__(
        self,
        config: MemoryConfig,
        tokenizer: Optional[Any] = None,
        long_term_memory: Optional[Any] = None,
    ):
        self.config = config
        self.tokenizer = tokenizer or TikTokenizer()
        self.long_term_memory = long_term_memory

        # Active memories
        self.short_term_memory: List[MemoryEntry] = []
        self.system_messages: List[MemoryEntry] = []

        # Lifetime statistics
        self.total_tokens_processed: int = 0
        self.archived_entries_count: int = 0

    def create_memory_entry(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        memory_type: str = "short_term",
    ) -> MemoryEntry:
        """Create a new memory entry with freshly computed metadata."""
        metadata = MemoryMetadata(
            timestamp=time.time(),
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type=memory_type,
            token_count=self.tokenizer.count_tokens(content),
        )
        return MemoryEntry(content=content, metadata=metadata)

    def add_memory(
        self,
        content: str,
        role: str,
        agent_name: str,
        session_id: str,
        is_system: bool = False,
    ) -> None:
        """Add a new memory entry to the appropriate storage.

        System messages go to ``system_messages``; everything else goes to
        short-term memory and may trigger archival.
        """
        entry = self.create_memory_entry(
            content=content,
            role=role,
            agent_name=agent_name,
            session_id=session_id,
            memory_type="system" if is_system else "short_term",
        )

        if is_system:
            self.system_messages.append(entry)
        else:
            self.short_term_memory.append(entry)

        # Archive before accounting so the budget check reflects the new entry
        if self.should_archive():
            self.archive_old_memories()

        self.total_tokens_processed += entry.metadata.token_count

    def get_current_token_count(self) -> int:
        """Get total tokens in short-term memory."""
        return sum(
            entry.metadata.token_count
            for entry in self.short_term_memory
        )

    def get_system_messages_token_count(self) -> int:
        """Get total tokens in system messages."""
        return sum(
            entry.metadata.token_count
            for entry in self.system_messages
        )

    def should_archive(self) -> bool:
        """Check if archiving is needed based on configuration.

        Returns False when auto-archive is disabled or when no positive
        token budget is configured (avoids division by zero).
        """
        if not self.config.auto_archive:
            return False

        max_tokens = self.config.max_short_term_tokens or 0
        if max_tokens <= 0:
            return False

        current_usage = self.get_current_token_count() / max_tokens
        return current_usage >= self.config.archive_threshold

    def archive_old_memories(self) -> None:
        """Move oldest short-term memories to long-term storage.

        Stops immediately if a store operation fails: the failed entry is
        re-queued by ``store_in_long_term_memory``, so continuing would pop
        and fail on the same entry forever.
        """
        if not self.long_term_memory:
            logger.warning(
                "No long-term memory storage configured for archiving"
            )
            return

        while self.should_archive() and self.short_term_memory:
            oldest_entry = self.short_term_memory.pop(0)

            if not self.store_in_long_term_memory(oldest_entry):
                # Entry was re-queued on failure; bail out to avoid an
                # infinite pop/re-insert loop.
                break
            self.archived_entries_count += 1

    def store_in_long_term_memory(self, entry: MemoryEntry) -> bool:
        """Store a memory entry in long-term memory.

        Returns:
            bool: True on success; False if no store is configured or the
            store raised (in which case the entry is re-queued at the front
            of short-term memory).
        """
        if self.long_term_memory is None:
            logger.warning(
                "Attempted to store in non-existent long-term memory"
            )
            return False

        try:
            self.long_term_memory.add(str(entry.model_dump()))
            return True
        except Exception as e:
            logger.error("Error storing in long-term memory: %s", e)
            # Re-add to short-term if storage fails
            self.short_term_memory.insert(0, entry)
            return False

    def get_relevant_context(
        self, query: str, max_tokens: Optional[int] = None
    ) -> str:
        """
        Get relevant context from both memory types

        Args:
            query (str): Query to match against memories
            max_tokens (Optional[int]): Maximum tokens to return

        Returns:
            str: Combined relevant context
        """
        contexts: List[str] = []

        # System messages first, then short-term newest-first
        for entry in self.system_messages:
            if entry.content:
                contexts.append(entry.content)

        for entry in reversed(self.short_term_memory):
            if entry.content:
                contexts.append(entry.content)

        # Query long-term memory if available
        if self.long_term_memory is not None:
            long_term_context = self.long_term_memory.query(query)
            if long_term_context:
                contexts.append(str(long_term_context))

        combined = "\n".join(contexts)
        if max_tokens:
            combined = self.truncate_to_token_limit(
                combined, max_tokens
            )

        return combined

    def truncate_to_token_limit(
        self, text: str, max_tokens: int
    ) -> str:
        """Truncate text to fit within a token limit.

        Approximate: rebuilds the text sentence-by-sentence (split on
        '. ') until the budget is exhausted, so trailing partial
        sentences are dropped rather than cut mid-word.
        """
        if self.tokenizer.count_tokens(text) <= max_tokens:
            return text

        kept: List[str] = []
        used = 0
        for sentence in text.split(". "):
            cost = self.tokenizer.count_tokens(sentence)
            if used + cost > max_tokens:
                break
            kept.append(sentence)
            used += cost

        return ". ".join(kept)

    def clear_short_term_memory(
        self, preserve_system: bool = True
    ) -> None:
        """Clear short-term memory with option to preserve system messages."""
        if not preserve_system:
            self.system_messages.clear()
        self.short_term_memory.clear()
        # NOTE: the original ternary bound over the whole concatenation,
        # logging an empty string when preserve_system was False.
        logger.info(
            "Cleared short-term memory%s",
            " (preserved system messages)" if preserve_system else "",
        )

    def get_memory_stats(self) -> Dict[str, Any]:
        """Get detailed memory statistics.

        ``token_usage_percent`` is 0.0 when no positive token budget is
        configured (avoids division by zero).
        """
        max_tokens = self.config.max_short_term_tokens or 0
        usage_percent = (
            round(
                self.get_current_token_count() / max_tokens * 100, 2
            )
            if max_tokens > 0
            else 0.0
        )
        return {
            "short_term_messages": len(self.short_term_memory),
            "system_messages": len(self.system_messages),
            "current_tokens": self.get_current_token_count(),
            "system_tokens": self.get_system_messages_token_count(),
            "max_tokens": self.config.max_short_term_tokens,
            "token_usage_percent": usage_percent,
            "has_long_term_memory": self.long_term_memory is not None,
            "archived_entries": self.archived_entries_count,
            "total_tokens_processed": self.total_tokens_processed,
        }

    def save_memory_snapshot(self, file_path: str) -> None:
        """Save current memory state to a JSON or YAML file.

        Raises:
            Exception: re-raised after logging if serialization or I/O fails.
        """
        try:
            data = {
                "timestamp": datetime.now().isoformat(),
                "config": self.config.model_dump(),
                "system_messages": [
                    entry.model_dump()
                    for entry in self.system_messages
                ],
                "short_term_memory": [
                    entry.model_dump()
                    for entry in self.short_term_memory
                ],
                "stats": self.get_memory_stats(),
            }

            with open(file_path, "w") as f:
                if file_path.endswith(".yaml"):
                    yaml.dump(data, f)
                else:
                    json.dump(data, f, indent=2)

            logger.info("Saved memory snapshot to %s", file_path)

        except Exception as e:
            logger.error("Error saving memory snapshot: %s", e)
            raise

    def load_memory_snapshot(self, file_path: str) -> None:
        """Load memory state from a JSON or YAML snapshot file.

        Replaces the current config, system messages, and short-term memory.

        Raises:
            Exception: re-raised after logging if parsing or I/O fails.
        """
        try:
            with open(file_path, "r") as f:
                if file_path.endswith(".yaml"):
                    data = yaml.safe_load(f)
                else:
                    data = json.load(f)

            self.config = MemoryConfig(**data["config"])
            self.system_messages = [
                MemoryEntry(**entry)
                for entry in data["system_messages"]
            ]
            self.short_term_memory = [
                MemoryEntry(**entry)
                for entry in data["short_term_memory"]
            ]

            logger.info("Loaded memory snapshot from %s", file_path)

        except Exception as e:
            logger.error("Error loading memory snapshot: %s", e)
            raise

    def search_memories(
        self, query: str, memory_type: str = "all"
    ) -> List[MemoryEntry]:
        """
        Search through memories of specified type

        Args:
            query (str): Search query (case-insensitive substring match)
            memory_type (str): Type of memories to search ("short_term", "system", "long_term", or "all")

        Returns:
            List[MemoryEntry]: Matching memory entries
        """
        results: List[MemoryEntry] = []
        needle = query.lower()  # hoisted: lower-case once, not per entry

        if memory_type in ["short_term", "all"]:
            results.extend(
                entry
                for entry in self.short_term_memory
                if entry.content and needle in entry.content.lower()
            )

        if memory_type in ["system", "all"]:
            results.extend(
                entry
                for entry in self.system_messages
                if entry.content and needle in entry.content.lower()
            )

        if (
            memory_type in ["long_term", "all"]
            and self.long_term_memory is not None
        ):
            long_term_results = self.long_term_memory.query(query)
            if long_term_results:
                # Wrap raw long-term results in MemoryEntry objects
                for result in long_term_results:
                    content = str(result)
                    metadata = MemoryMetadata(
                        timestamp=time.time(),
                        role="long_term",
                        agent_name="system",
                        session_id="long_term",
                        memory_type="long_term",
                        token_count=self.tokenizer.count_tokens(
                            content
                        ),
                    )
                    results.append(
                        MemoryEntry(
                            content=content, metadata=metadata
                        )
                    )

        return results

    def get_memory_by_timeframe(
        self, start_time: float, end_time: float
    ) -> List[MemoryEntry]:
        """Get short-term memories whose timestamp falls within [start_time, end_time]."""
        return [
            entry
            for entry in self.short_term_memory
            if start_time <= entry.metadata.timestamp <= end_time
        ]

    def export_memories(
        self, file_path: str, format: str = "json"
    ) -> None:
        """Export system and short-term memories (plus stats) to a file.

        Args:
            file_path (str): Destination path
            format (str): "yaml" for YAML output; anything else writes JSON
        """
        data = {
            "system_messages": [
                entry.model_dump() for entry in self.system_messages
            ],
            "short_term_memory": [
                entry.model_dump() for entry in self.short_term_memory
            ],
            "stats": self.get_memory_stats(),
        }

        with open(file_path, "w") as f:
            if format == "yaml":
                yaml.dump(data, f)
            else:
                json.dump(data, f, indent=2)