# src/memory.py
import sqlite3
from datetime import datetime, timedelta
import json
from typing import Any, Dict, List, Optional, Tuple
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import logging
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class MemoryManager:
    def __init__(self, db_path: str):
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
        self.create_tables()
        self.vectorizer = TfidfVectorizer(stop_words='english')
        logging.info("MemoryManager initialized and tables created.")

    def create_tables(self):
        # Create tables if they don't exist
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS semantic_memory
                               (id INTEGER PRIMARY KEY, concept TEXT, description TEXT, last_accessed DATETIME, tags TEXT, importance REAL DEFAULT 0.5)''')
        # Add tags and importance columns if they don't exist
        self.cursor.execute("PRAGMA table_info(semantic_memory)")
        columns = [column[1] for column in self.cursor.fetchall()]
        if 'tags' not in columns:
            self.cursor.execute("ALTER TABLE semantic_memory ADD COLUMN tags TEXT")
        if 'importance' not in columns:
            self.cursor.execute("ALTER TABLE semantic_memory ADD COLUMN importance REAL DEFAULT 0.5")
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_concept ON semantic_memory (concept)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_last_accessed ON semantic_memory (last_accessed)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_semantic_tags ON semantic_memory (tags)''')
        # Create table for user interactions
        self.cursor.execute('''CREATE TABLE IF NOT EXISTS user_interactions
                               (user_id TEXT, query TEXT, response TEXT, timestamp DATETIME)''')
        self.cursor.execute('''CREATE INDEX IF NOT EXISTS idx_user_interactions_timestamp ON user_interactions (timestamp)''')
        self.conn.commit()
        logging.info("Tables and indexes created successfully.")

    def add_semantic_memory(self, concept: str, description: str, tags: Optional[List[str]] = None):
        if tags is None:
            tags = []
        tags_str = json.dumps(tags)
        self.cursor.execute("INSERT INTO semantic_memory (concept, description, last_accessed, tags) VALUES (?, ?, ?, ?)",
                            (concept, description, datetime.now().isoformat(), tags_str))
        self.conn.commit()
        logging.info("Semantic memory added.")

    def retrieve_relevant_memories(self, query: str, limit: int = 30) -> List[Dict[str, Any]]:
        all_memories = self._get_all_memories()
        # Handle empty or stop-word-only queries, which would otherwise yield an empty TF-IDF vocabulary
        stop_words = self.vectorizer.get_stop_words() or set()
        if not query.strip() or all(word in stop_words for word in query.lower().split()):
            return []
        scored_memories = self._score_memories(query, all_memories)
        return [memory for memory, score in sorted(scored_memories, key=lambda x: x[1], reverse=True)[:limit]]

    def _get_all_memories(self) -> List[Tuple[Dict[str, Any], datetime, Optional[List[str]]]]:
        self.cursor.execute("SELECT concept, description, importance, last_accessed, tags FROM semantic_memory ORDER BY importance DESC, last_accessed DESC")
        semantic_memories = self.cursor.fetchall()
        all_memories = [({"concept": concept, "description": description, "importance": importance},
                         datetime.fromisoformat(last_accessed),
                         json.loads(tags) if tags else None)
                        for concept, description, importance, last_accessed, tags in semantic_memories]
        return all_memories

    def _score_memories(self, query: str, memories: List[Tuple[Dict[str, Any], datetime, Optional[List[str]]]]) -> List[Tuple[Dict[str, Any], float]]:
        if not memories:
            return []
        # Fit the vectorizer on the memory texts (not the query) so the TF-IDF weights are meaningful,
        # then project the query into the same vocabulary space.
        texts = [f"{memory['concept']} {memory['description']}" for memory, _, _ in memories]
        memory_vectors = self.vectorizer.fit_transform(texts)
        query_vector = self.vectorizer.transform([query])
        similarities = cosine_similarity(query_vector, memory_vectors)[0]
        scored_memories = []
        for (memory, timestamp, _tags), similarity in zip(memories, similarities):
            importance = memory.get('importance', 0.5)
            if timestamp:
                recency = 1 / (1 + (datetime.now() - timestamp).total_seconds() / 60)  # Favor recent memories
            else:
                recency = 0.5  # Neutral recency when no timestamp is available
            score = (similarity + importance + recency) / 3
            scored_memories.append((memory, score))
        return scored_memories

    def section_exists(self, concept: str) -> bool:
        # Normalize the concept to lowercase
        concept = concept.lower()
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE ?", (f"{concept}%",))
        count = self.cursor.fetchone()[0]
        return count > 0

    def add_user_interaction(self, user_id: str, query: str, response: str):
        self.cursor.execute("INSERT INTO user_interactions (user_id, query, response, timestamp) VALUES (?, ?, ?, ?)",
                            (user_id, query, response, datetime.now().isoformat()))
        self.conn.commit()
        logging.info(f"User interaction added: User ID: {user_id}, Query: {query}, Response: {response}")

    def get_user_interactions(self, user_id: str) -> List[Dict[str, Any]]:
        self.cursor.execute("SELECT query, response, timestamp FROM user_interactions WHERE user_id = ?", (user_id,))
        interactions = self.cursor.fetchall()
        return [{"query": query, "response": response, "timestamp": timestamp} for query, response, timestamp in interactions]

    def cleanup_expired_interactions(self):
        cutoff_time = datetime.now() - timedelta(minutes=5)
        self.cursor.execute("DELETE FROM user_interactions WHERE timestamp < ?", (cutoff_time.isoformat(),))
        self.conn.commit()
        logging.info(f"Expired user interactions cleaned up. Cutoff time: {cutoff_time}")

    def get_section_description(self, section_name: str) -> str:
        # Normalize the section name to lowercase
        section_name = section_name.lower()
        # Retrieve the specific section from the database
        self.cursor.execute("SELECT description FROM semantic_memory WHERE concept LIKE ?", (f"{section_name}%",))
        result = self.cursor.fetchone()
        if result:
            logging.info(f"Found section: {section_name}")
            return result[0]
        else:
            logging.warning(f"Section not found: {section_name}")
            return ""

    def count_chroniques(self) -> int:
        # Count the number of chroniques in the database
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'chronique #%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of chroniques: {count}")
        return count

    def count_flash_infos(self) -> int:
        # Count the number of flash infos in the database
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'flash info fl-%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of flash infos: {count}")
        return count

    def count_chronique_faqs(self) -> int:
        # Count the number of chronique-faqs in the database
        self.cursor.execute("SELECT COUNT(*) FROM semantic_memory WHERE concept LIKE 'chronique-faq #%'")
        count = self.cursor.fetchone()[0]
        logging.info(f"Number of chronique-faqs: {count}")
        return count

if __name__ == "__main__":
    db_path = "agent.db"
    memory_manager = MemoryManager(db_path)
    memory_manager.cleanup_expired_interactions()
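    # Illustrative usage sketch: store one semantic memory and retrieve it by a free-text query.
    # The concept, description, and tags below are hypothetical example data, not part of the original module.
    memory_manager.add_semantic_memory(
        "chronique #1",
        "Example chronicle entry about local weather trends.",
        tags=["chronique", "weather"],
    )
    for memory in memory_manager.retrieve_relevant_memories("weather chronicle", limit=5):
        print(memory["concept"], "->", memory["description"])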