# src/memory.py
import sqlite3
from datetime import datetime, timedelta
import json
from typing import List, Dict, Any, Tuple, Optional
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


class MemoryManager:
    def __init__(self, db_path: str):
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.create_tables()
        self.vectorizer = TfidfVectorizer(stop_words='english')
        logging.info("MemoryManager initialized and tables created.")

    def create_tables(self):
        # Create tables if they don't exist
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS semantic_memory
                               (id INTEGER PRIMARY KEY, concept TEXT, description TEXT,
                                last_accessed DATETIME, tags TEXT, importance REAL DEFAULT 0.5)''')

        # Add tags and importance columns if an older database is missing them
        self.cursor.execute("PRAGMA table_info(semantic_memory)")
        columns = [column[1] for column in self.cursor.fetchall()]
        if 'tags' not in columns:
            self.cursor.execute("ALTER TABLE semantic_memory ADD COLUMN tags TEXT")
        if 'importance' not in columns:
            self.cursor.execute("ALTER TABLE semantic_memory ADD COLUMN importance REAL DEFAULT 0.5")

        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_concept ON semantic_memory (concept)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_last_accessed ON semantic_memory (last_accessed)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_tags ON semantic_memory (tags)''')

        # Create table for user interactions
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS user_interactions
                               (user_id TEXT, query TEXT, response TEXT, timestamp DATETIME)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_user_interactions_timestamp ON user_interactions (timestamp)''')
        self.conn.commit()
        logging.info("Tables and indexes created successfully.")

    def add_semantic_memory(self, concept: str, description: str, tags: List[str] = None):
        if tags is None:
            tags = []
        tags_str = json.dumps(tags)
        self.cursor.execute(
            "INSERT INTO semantic_memory (concept, description, last_accessed, tags) VALUES (?, ?, ?, ?)",
            (concept, description, datetime.now().isoformat(), tags_str))
        self.conn.commit()
        logging.info("Semantic memory added.")

    def retrieve_relevant_memories(self, query: str, limit: int = 30) -> List[Dict[str, Any]]:
        all_memories = self._get_all_memories()
        # Handle empty or stop-word-only queries. get_stop_words() returns the actual
        # stop-word set; the stop_words attribute is only the string 'english'.
        stop_words = self.vectorizer.get_stop_words() or set()
        if not query.strip() or all(word in stop_words for word in query.lower().split()):
            return []
        scored_memories = self._score_memories(query, all_memories)
        return [memory for memory, score in sorted(scored_memories, key=lambda x: x[1], reverse=True)[:limit]]

    def _get_all_memories(self) -> List[Tuple[Dict[str, Any], Optional[datetime], Optional[List[str]]]]:
        self.cursor.execute(
            "SELECT concept, description, importance, last_accessed, tags FROM semantic_memory "
            "ORDER BY importance DESC, last_accessed DESC")
        semantic_memories = self.cursor.fetchall()
        all_memories = [({"concept": concept, "description": description, "importance": importance},
                         datetime.fromisoformat(last_accessed) if last_accessed else None,
                         json.loads(tags) if tags else None)
                        for concept, description, importance, last_accessed, tags in semantic_memories]
        return all_memories

    def _score_memories(self, query: str, memories: List[Tuple[Dict[str, Any], Optional[datetime], Optional[List[str]]]]) -> List[Tuple[Dict[str, Any], float]]:
        # Fit the vectorizer on the query together with the memory texts so both sides
        # share one vocabulary; fitting on the query alone reduces the score to raw term overlap.
        memory_texts = [f"{memory['concept']} {memory['description']}" for memory, _, _ in memories]
        self.vectorizer.fit([query] + memory_texts)
        query_vector = self.vectorizer.transform([query])
        scored_memories = []
        for memory, timestamp, tags in memories:
            text = f"{memory['concept']} {memory['description']}"
            importance = memory.get('importance', 0.5)
            memory_vector = self.vectorizer.transform([text])
            similarity = cosine_similarity(query_vector, memory_vector)[0][0]
            if timestamp:
                # Favor recently accessed memories: weight decays with age in minutes
                recency = 1 / (1 + (datetime.now() - timestamp).total_seconds() / 60)
            else:
                recency = 0.5  # Neutral recency when no access time is recorded
            score = (similarity + importance + recency) / 3
            scored_memories.append((memory, score))
        return scored_memories

    def section_exists(self, concept: str) -> bool:
        # Normalize the concept to lowercase
        concept = concept.lower()
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE ?", (f"{concept}%",))
        count = self.cursor.fetchone()[0]
        return count > 0

    def add_user_interaction(self, user_id: str, query: str, response: str):
        self.cursor.execute(
            "INSERT INTO user_interactions (user_id, query, response, timestamp) VALUES (?, ?, ?, ?)",
            (user_id, query, response, datetime.now().isoformat()))
        self.conn.commit()
        logging.info(f"User interaction added: User ID: {user_id}, Query: {query}, Response: {response}")

    def get_user_interactions(self, user_id: str) -> List[Dict[str, Any]]:
        self.cursor.execute("SELECT query, response, timestamp FROM user_interactions WHERE user_id = ?", (user_id,))
        interactions = self.cursor.fetchall()
        return [{"query": query, "response": response, "timestamp": timestamp}
                for query, response, timestamp in interactions]

    def cleanup_expired_interactions(self):
        # Interactions older than five minutes are considered expired
        cutoff_time = datetime.now() - timedelta(minutes=5)
        self.cursor.execute("DELETE FROM user_interactions WHERE timestamp < ?", (cutoff_time.isoformat(),))
        self.conn.commit()
        logging.info(f"Expired user interactions cleaned up. Cutoff time: {cutoff_time}")

    def get_section_description(self, section_name: str) -> str:
        # Normalize the section name to lowercase
        section_name = section_name.lower()
        # Retrieve the specific section from the database
        self.cursor.execute("SELECT description FROM semantic_memory WHERE concept LIKE ?", (f"{section_name}%",))
        result = self.cursor.fetchone()
        if result:
            logging.info(f"Found section: {section_name}")
            return result[0]
        else:
            logging.warning(f"Section not found: {section_name}")
            return ""

    def count_chroniques(self) -> int:
        # Count the number of chroniques in the database
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'chronique #%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of chroniques: {count}")
        return count

    def count_flash_infos(self) -> int:
        # Count the number of flash infos in the database
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'flash info fl-%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of flash infos: {count}")
        return count

    def count_chronique_faqs(self) -> int:
        # Count the number of chronique-faqs in the database
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'chronique-faq #%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of chronique-faqs: {count}")
        return count


if __name__ == "__main__":
    db_path = "agent.db"
    memory_manager = MemoryManager(db_path)
    memory_manager.cleanup_expired_interactions()
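
    # Illustrative usage sketch (the concept name and tags below are assumed for
    # demonstration only, not part of the original flow): store one memory,
    # then retrieve it again by keyword.
    memory_manager.add_semantic_memory(
        "chronique #1",
        "Example chronique entry used only for this demonstration.",
        tags=["chronique", "demo"])
    for memory in memory_manager.retrieve_relevant_memories("chronique", limit=5):
        logging.info(f"Retrieved: {memory['concept']} -> {memory['description']}")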