# src/memory.py
import json
import logging
import sqlite3
from datetime import datetime, timedelta
from typing import Any, Dict, List, Tuple

import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Emit timestamped, level-tagged log lines for everything this module does.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class MemoryManager:
    """SQLite-backed memory store for an agent.

    Persists two kinds of data:

    * ``semantic_memory`` -- concept/description records with JSON-encoded
      tags and an importance score, retrieved by a blend of TF-IDF
      similarity, importance and recency.
    * ``user_interactions`` -- short-lived query/response pairs per user,
      purged by :meth:`cleanup_expired_interactions`.
    """

    def __init__(self, db_path: str):
        """Open (or create) the SQLite database at *db_path* and prepare the schema.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.create_tables()
        # English stop words are dropped when vectorizing queries and memories.
        self.vectorizer = TfidfVectorizer(stop_words='english')
        logging.info("MemoryManager initialized and tables created.")

    def close(self):
        """Close the underlying database connection (call once, at shutdown)."""
        self.conn.close()

    def create_tables(self):
        """Create tables and indexes if missing; migrate older schemas in place."""
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS semantic_memory
            (id INTEGER PRIMARY KEY, concept TEXT, description TEXT, last_accessed DATETIME, tags TEXT, importance REAL DEFAULT 0.5)''')
        # Databases created before the tags/importance columns existed are
        # migrated lazily here.
        self.cursor.execute("PRAGMA table_info(semantic_memory)")
        columns = [column[1] for column in self.cursor.fetchall()]
        if 'tags' not in columns:
            self.cursor.execute("ALTER TABLE semantic_memory ADD COLUMN tags TEXT")
        if 'importance' not in columns:
            self.cursor.execute("ALTER TABLE semantic_memory ADD COLUMN importance REAL DEFAULT 0.5")
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_concept ON semantic_memory (concept)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_last_accessed ON semantic_memory (last_accessed)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_tags ON semantic_memory (tags)''')
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS user_interactions
            (user_id TEXT, query TEXT, response TEXT, timestamp DATETIME)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_user_interactions_timestamp ON user_interactions (timestamp)''')
        self.conn.commit()
        logging.info("Tables and indexes created successfully.")

    def add_semantic_memory(self, concept: str, description: str, tags: List[str] = None):
        """Insert a semantic memory row; *tags* are stored as a JSON array.

        Args:
            concept: Short identifier/title for the memory.
            description: Free-text body of the memory.
            tags: Optional list of tag strings (defaults to empty).
        """
        if tags is None:
            tags = []
        tags_str = json.dumps(tags)
        self.cursor.execute(
            "INSERT INTO semantic_memory (concept, description, last_accessed, tags) VALUES (?, ?, ?, ?)",
            (concept, description, datetime.now().isoformat(), tags_str))
        self.conn.commit()
        logging.info("Semantic memory added.")

    def retrieve_relevant_memories(self, query: str, limit: int = 30) -> List[Dict[str, Any]]:
        """Return up to *limit* memory records ranked best-first by combined score.

        Returns an empty list for empty or stop-word-only queries.
        """
        all_memories = self._get_all_memories()
        # FIX: the previous guard compared words against the *string* 'english'
        # (the constructor argument), not the actual stop-word set.
        stop_words = self.vectorizer.get_stop_words() or frozenset()
        words = query.split()
        if not words or all(word.lower() in stop_words for word in words):
            return []
        scored_memories = self._score_memories(query, all_memories)
        return [memory for memory, score in sorted(scored_memories, key=lambda x: x[1], reverse=True)[:limit]]

    def _get_all_memories(self) -> List[Tuple[Dict[str, Any], Any, Any]]:
        """Fetch every semantic memory as ``(record, last_accessed, tags)`` triples.

        ``last_accessed`` is a ``datetime`` or ``None``; ``tags`` a list or ``None``.
        """
        self.cursor.execute(
            "SELECT concept, description, importance, last_accessed, tags FROM semantic_memory "
            "ORDER BY importance DESC, last_accessed DESC")
        rows = self.cursor.fetchall()
        all_memories = []
        for concept, description, importance, last_accessed, tags in rows:
            record = {"concept": concept, "description": description, "importance": importance}
            # FIX: rows from migrated/legacy schemas may carry NULL here; the
            # old code called fromisoformat unconditionally and would crash,
            # even though _score_memories already handles a None timestamp.
            accessed = datetime.fromisoformat(last_accessed) if last_accessed else None
            all_memories.append((record, accessed, json.loads(tags) if tags else None))
        return all_memories

    def _score_memories(self, query: str, memories: List[Tuple[Dict[str, Any], datetime, List[str]]]) -> List[Tuple[Dict[str, Any], float]]:
        """Score each memory as the mean of TF-IDF similarity, importance and recency."""
        try:
            # NOTE(review): the vectorizer is (re)fitted on the query alone, so
            # similarity measures overlap with the query's vocabulary only.
            query_vector = self.vectorizer.fit_transform([query])
        except ValueError:
            # FIX: a query with no usable terms (e.g. punctuation only) slips
            # past the stop-word guard and raised "empty vocabulary" here.
            return []
        now = datetime.now()  # Hoisted: one consistent clock for the whole pass.
        scored_memories = []
        for memory, timestamp, tags in memories:
            text = f"{memory['concept']} {memory['description']}"
            importance = memory.get('importance', 0.5)
            if importance is None:
                importance = 0.5  # Defensive: NULL importance from legacy rows.
            memory_vector = self.vectorizer.transform([text])
            similarity = cosine_similarity(query_vector, memory_vector)[0][0]
            if timestamp:
                # Hyperbolic decay with age in minutes favors recent memories.
                recency = 1 / (1 + (now - timestamp).total_seconds() / 60)
            else:
                recency = 0.5  # Neutral recency when last_accessed is unknown.
            score = (similarity + importance + recency) / 3
            scored_memories.append((memory, score))
        return scored_memories

    def section_exists(self, concept: str) -> bool:
        """Return True if any stored concept starts with *concept* (case-insensitive for ASCII)."""
        concept = concept.lower()
        # SQLite LIKE is case-insensitive for ASCII, so a lowered prefix matches.
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE ?", (f"{concept}%",))
        count = self.cursor.fetchone()[0]
        return count > 0

    def add_user_interaction(self, user_id: str, query: str, response: str):
        """Record one query/response exchange for *user_id* with the current time."""
        self.cursor.execute(
            "INSERT INTO user_interactions (user_id, query, response, timestamp) VALUES (?, ?, ?, ?)",
            (user_id, query, response, datetime.now().isoformat()))
        self.conn.commit()
        logging.info(f"User interaction added: User ID: {user_id}, Query: {query}, Response: {response}")

    def get_user_interactions(self, user_id: str) -> List[Dict[str, Any]]:
        """Return all stored interactions for *user_id* as dicts (insertion order)."""
        self.cursor.execute("SELECT query, response, timestamp FROM user_interactions WHERE user_id = ?", (user_id,))
        interactions = self.cursor.fetchall()
        return [{"query": query, "response": response, "timestamp": timestamp} for query, response, timestamp in interactions]

    def cleanup_expired_interactions(self, max_age_minutes: float = 5):
        """Delete user interactions older than *max_age_minutes* (default 5).

        Args:
            max_age_minutes: Age threshold in minutes; previously hard-coded to 5.
        """
        cutoff_time = datetime.now() - timedelta(minutes=max_age_minutes)
        # ISO-8601 strings compare lexicographically in chronological order,
        # so a plain string comparison against the cutoff is correct.
        self.cursor.execute("DELETE FROM user_interactions WHERE timestamp < ?", (cutoff_time.isoformat(),))
        self.conn.commit()
        logging.info(f"Expired user interactions cleaned up. Cutoff time: {cutoff_time}")

    def get_section_description(self, section_name: str) -> str:
        """Return the description of the first concept matching *section_name* as a prefix, or ''."""
        section_name = section_name.lower()
        self.cursor.execute("SELECT description FROM semantic_memory WHERE concept LIKE ?", (f"{section_name}%",))
        result = self.cursor.fetchone()
        if result:
            logging.info(f"Found section: {section_name}")
            return result[0]
        else:
            logging.warning(f"Section not found: {section_name}")
            return ""

    def count_chroniques(self) -> int:
        """Count concepts whose name starts with 'chronique #'."""
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'chronique #%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of chroniques: {count}")
        return count

    def count_flash_infos(self) -> int:
        """Count concepts whose name starts with 'flash info fl-'."""
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'flash info fl-%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of flash infos: {count}")
        return count

    def count_chronique_faqs(self) -> int:
        """Count concepts whose name starts with 'chronique-faq #'."""
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'chronique-faq #%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of chronique-faqs: {count}")
        return count
def _run_cleanup() -> None:
    # Script entry point: open the agent database and purge stale interactions.
    manager = MemoryManager("agent.db")
    manager.cleanup_expired_interactions()


if __name__ == "__main__":
    _run_cleanup()