| """ | |
| Working Hyper RAG System - FINAL FIXED VERSION. | |
| Proper ID mapping between keyword index and FAISS. | |
| """ | |
| import time | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import sqlite3 | |
| import hashlib | |
| from typing import List, Tuple, Optional, Dict, Any | |
| from pathlib import Path | |
| from datetime import datetime, timedelta | |
| import re | |
| from collections import defaultdict | |
| import psutil | |
| import os | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from config import ( | |
| EMBEDDING_MODEL, DATA_DIR, FAISS_INDEX_PATH, DOCSTORE_PATH, | |
| EMBEDDING_CACHE_PATH, CHUNK_SIZE, TOP_K_DYNAMIC_HYPER, | |
| MAX_TOKENS, ENABLE_EMBEDDING_CACHE, ENABLE_QUERY_CACHE, | |
| ENABLE_PRE_FILTER, ENABLE_PROMPT_COMPRESSION | |
| ) | |
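
# The constants imported above are expected to live in config.py.  The values
# below are illustrative assumptions only (the real config may differ), shown
# here to document the types this module relies on:
#
#     EMBEDDING_MODEL = "all-MiniLM-L6-v2"
#     DATA_DIR = Path("data")
#     FAISS_INDEX_PATH = DATA_DIR / "index.faiss"
#     DOCSTORE_PATH = DATA_DIR / "docstore.db"
#     EMBEDDING_CACHE_PATH = DATA_DIR / "embedding_cache.db"
#     CHUNK_SIZE = 512
#     TOP_K_DYNAMIC_HYPER = {"short": 2, "medium": 3, "long": 5}
#     MAX_TOKENS = 1500
#     ENABLE_EMBEDDING_CACHE = True
#     ENABLE_QUERY_CACHE = True
#     ENABLE_PRE_FILTER = True
#     ENABLE_PROMPT_COMPRESSION = True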


class WorkingHyperRAG:
    """
    Working Hyper RAG - FINAL FIXED VERSION with proper ID mapping.
    """

    def __init__(self, metrics_tracker=None):
        self.metrics_tracker = metrics_tracker
        self.embedder = None
        self.faiss_index = None
        self.docstore_conn = None
        self._initialized = False
        self.process = psutil.Process(os.getpid())

        # Use ThreadPoolExecutor for embedding and pre-filter work
        self.thread_pool = ThreadPoolExecutor(
            max_workers=2,
            thread_name_prefix="HyperRAGWorker"
        )

        # Adaptive parameters
        self.performance_history = []
        self.avg_latency = 0
        self.total_queries = 0

        # In-memory cache for hot embeddings
        self._embedding_cache = {}

        # ID mapping: FAISS index (0-based) -> Database ID (1-based)
        self._id_mapping = {}
    def initialize(self):
        """Initialize all components - MAIN THREAD ONLY."""
        if self._initialized:
            return

        print("🚀 Initializing WorkingHyperRAG...")
        start_time = time.perf_counter()

        # 1. Load embedding model and warm it up
        self.embedder = SentenceTransformer(EMBEDDING_MODEL)
        self.embedder.encode(["warmup"])

        # 2. Load FAISS index
        if FAISS_INDEX_PATH.exists():
            self.faiss_index = faiss.read_index(str(FAISS_INDEX_PATH))
            print(f"   Loaded FAISS index with {self.faiss_index.ntotal} vectors")
        else:
            print("   ⚠ FAISS index not found, retrieval will be limited")

        # 3. Connect to document store (main thread only)
        self.docstore_conn = sqlite3.connect(DOCSTORE_PATH)
        self._init_docstore_indices()

        # 4. Initialize embedding cache schema (create if not exists)
        self._init_cache_schema()

        # 5. Build keyword index for filtering WITH PROPER ID MAPPING
        self.keyword_index = self._build_keyword_index_with_mapping()

        init_time = (time.perf_counter() - start_time) * 1000
        memory_mb = self.process.memory_info().rss / 1024 / 1024
        print(f"✅ WorkingHyperRAG initialized in {init_time:.2f}ms")
        print(f"   Memory: {memory_mb:.2f}MB")
        print(f"   Keyword index: {len(self.keyword_index)} unique words")
        print(f"   ID mapping: {len(self._id_mapping)} entries")

        self._initialized = True
    def _init_docstore_indices(self):
        """Create performance indices."""
        cursor = self.docstore_conn.cursor()
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunk_hash ON chunks(chunk_hash)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_doc_id ON chunks(doc_id)")
        self.docstore_conn.commit()

    def _init_cache_schema(self):
        """Initialize cache schema - called once from main thread."""
        if not ENABLE_EMBEDDING_CACHE:
            return

        # Create cache table if it doesn't exist
        conn = sqlite3.connect(EMBEDDING_CACHE_PATH)
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embedding_cache (
                text_hash TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                access_count INTEGER DEFAULT 0
            )
        """)
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON embedding_cache(created_at)")
        conn.commit()
        conn.close()
    def _build_keyword_index_with_mapping(self) -> Dict[str, List[int]]:
        """Build keyword index with proper FAISS ID mapping."""
        cursor = self.docstore_conn.cursor()
        # Get chunks in the SAME ORDER they were added to FAISS
        cursor.execute("SELECT id, chunk_text FROM chunks ORDER BY id")
        chunks = cursor.fetchall()

        keyword_index = defaultdict(list)
        self._id_mapping = {}

        # FAISS IDs are 0-based, added in order.
        # Database IDs are 1-based, also in order.
        for faiss_id, (db_id, text) in enumerate(chunks):
            # Map FAISS ID (0-based) to Database ID (1-based)
            self._id_mapping[faiss_id] = db_id
            words = set(re.findall(r'\b\w{3,}\b', text.lower()))
            for word in words:
                # Store FAISS ID (0-based) in keyword index
                keyword_index[word].append(faiss_id)

        print(f"   Built mapping: {len(self._id_mapping)} FAISS IDs -> DB IDs")
        return keyword_index

    def _faiss_id_to_db_id(self, faiss_id: int) -> int:
        """Convert FAISS ID (0-based) to Database ID (1-based)."""
        return self._id_mapping.get(faiss_id, faiss_id + 1)

    def _db_id_to_faiss_id(self, db_id: int) -> int:
        """Convert Database ID (1-based) to FAISS ID (0-based)."""
        # Search for the mapping (inefficient but works for small datasets)
        for faiss_id, mapped_db_id in self._id_mapping.items():
            if mapped_db_id == db_id:
                return faiss_id
        return db_id - 1  # Fallback
    def _get_thread_safe_cache_connection(self):
        """Open a fresh cache connection that is safe to use from worker threads."""
        return sqlite3.connect(
            EMBEDDING_CACHE_PATH,
            check_same_thread=False,
            timeout=10.0
        )
    def _get_cached_embedding(self, text: str) -> Optional[np.ndarray]:
        """Get embedding from cache - THREAD-SAFE."""
        if not ENABLE_EMBEDDING_CACHE:
            return None

        text_hash = hashlib.md5(text.encode()).hexdigest()

        # Try in-memory first (fast path)
        if text_hash in self._embedding_cache:
            return self._embedding_cache[text_hash]

        # Check disk cache (short-lived per-call connection)
        conn = self._get_thread_safe_cache_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT embedding FROM embedding_cache WHERE text_hash = ?",
                (text_hash,)
            )
            result = cursor.fetchone()
            if result:
                cursor.execute(
                    "UPDATE embedding_cache SET access_count = access_count + 1 WHERE text_hash = ?",
                    (text_hash,)
                )
                conn.commit()
                embedding = np.frombuffer(result[0], dtype=np.float32)
                self._embedding_cache[text_hash] = embedding
                return embedding
            return None
        finally:
            conn.close()

    def _cache_embedding(self, text: str, embedding: np.ndarray):
        """Cache an embedding - THREAD-SAFE."""
        if not ENABLE_EMBEDDING_CACHE:
            return

        text_hash = hashlib.md5(text.encode()).hexdigest()
        embedding_blob = embedding.astype(np.float32).tobytes()

        # Cache in memory
        self._embedding_cache[text_hash] = embedding

        # Cache on disk
        conn = self._get_thread_safe_cache_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """INSERT OR REPLACE INTO embedding_cache
                   (text_hash, embedding, access_count) VALUES (?, ?, 1)""",
                (text_hash, embedding_blob)
            )
            conn.commit()
        finally:
            conn.close()
    def _get_dynamic_top_k(self, question: str) -> int:
        """Determine top_k based on query complexity."""
        words = len(question.split())
        if words < 5:
            return TOP_K_DYNAMIC_HYPER["short"]
        elif words < 15:
            return TOP_K_DYNAMIC_HYPER["medium"]
        else:
            return TOP_K_DYNAMIC_HYPER["long"]

    def _pre_filter_chunks(self, question: str) -> Optional[List[int]]:
        """Intelligent pre-filtering - SIMPLIFIED VERSION."""
        if not ENABLE_PRE_FILTER:
            return None

        question_words = set(re.findall(r'\b\w{3,}\b', question.lower()))
        if not question_words:
            return None

        candidate_ids = set()
        # Find chunks that match ANY question word
        for word in question_words:
            if word in self.keyword_index:
                candidate_ids.update(self.keyword_index[word])

        if candidate_ids:
            print(f"   [Filter] Matched {len(candidate_ids)} chunks")
            return list(candidate_ids)

        print("   [Filter] No matches")
        return None
    def _search_faiss_intelligent(self, query_embedding: np.ndarray,
                                  top_k: int,
                                  filter_ids: Optional[List[int]] = None) -> List[int]:
        """Intelligent FAISS search - SIMPLIFIED AND CORRECT."""
        if self.faiss_index is None:
            return []

        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)

        # Always search for at least 1 chunk
        min_k = max(1, top_k)

        # If we have filter IDs, search MORE then filter
        if filter_ids and len(filter_ids) > 0:
            filter_id_set = set(filter_ids)  # set membership instead of list scans

            # Search more broadly
            search_k = min(top_k * 5, self.faiss_index.ntotal)
            distances, indices = self.faiss_index.search(query_embedding, search_k)

            # Get FAISS results
            faiss_results = [int(idx) for idx in indices[0] if idx >= 0]

            # Keep only IDs that passed the keyword pre-filter
            filtered_results = [idx for idx in faiss_results if idx in filter_id_set]

            if filtered_results:
                print(f"   [Search] Filtered to {len(filtered_results)} chunks")
                return filtered_results[:min_k]
            else:
                # If filtering removed everything, use top unfiltered results
                print(f"   [Search] No filtered matches, using top {min_k} results")
                return faiss_results[:min_k]
        else:
            # Regular search
            distances, indices = self.faiss_index.search(query_embedding, min_k)
            results = [int(idx) for idx in indices[0] if idx >= 0]
            return results
    def _retrieve_chunks_by_faiss_ids(self, faiss_ids: List[int]) -> List[str]:
        """Retrieve chunks by FAISS IDs, preserving the FAISS ranking order."""
        if not faiss_ids:
            return []

        # Convert FAISS IDs to Database IDs
        db_ids = [self._faiss_id_to_db_id(faiss_id) for faiss_id in faiss_ids]

        cursor = self.docstore_conn.cursor()
        placeholders = ','.join('?' for _ in db_ids)
        query = f"SELECT id, chunk_text FROM chunks WHERE id IN ({placeholders})"
        cursor.execute(query, db_ids)
        text_by_id = dict(cursor.fetchall())

        # Return texts in the same (relevance-ranked) order as the FAISS results
        return [text_by_id[db_id] for db_id in db_ids if db_id in text_by_id]
    def _compress_prompt(self, chunks: List[str]) -> List[str]:
        """Intelligent prompt compression."""
        if not ENABLE_PROMPT_COMPRESSION or not chunks:
            return chunks

        compressed = []
        total_tokens = 0
        for chunk in chunks:
            chunk_tokens = len(chunk.split())
            if total_tokens + chunk_tokens <= MAX_TOKENS:
                compressed.append(chunk)
                total_tokens += chunk_tokens
            else:
                break
        return compressed

    def _generate_hyper_response(self, question: str, chunks: List[str]) -> str:
        """Generate response - FAST AND SIMPLE."""
        if not chunks:
            return "I don't have enough specific information to answer that question."

        # Compress prompt
        compressed_chunks = self._compress_prompt(chunks)

        # Simulate faster generation
        time.sleep(0.08)

        # Simple response
        context = "\n\n".join(compressed_chunks[:3])
        return f"Based on the information: {context[:300]}..."
    async def query_async(self, question: str, top_k: Optional[int] = None) -> Tuple[str, int]:
        """Async query processing - OPTIMIZED FOR SPEED."""
        if not self._initialized:
            self.initialize()

        start_time = time.perf_counter()

        # Run embedding and filtering concurrently in the thread pool
        loop = asyncio.get_running_loop()
        embed_future = loop.run_in_executor(
            self.thread_pool,
            self._embed_and_cache_sync,
            question
        )
        filter_future = loop.run_in_executor(
            self.thread_pool,
            self._pre_filter_chunks,
            question
        )
        query_embedding, cache_status = await embed_future
        filter_ids = await filter_future

        # Determine top-k
        dynamic_k = self._get_dynamic_top_k(question)
        effective_k = top_k or dynamic_k

        # Search
        faiss_ids = self._search_faiss_intelligent(query_embedding, effective_k, filter_ids)

        # Retrieve chunks
        chunks = self._retrieve_chunks_by_faiss_ids(faiss_ids)

        # Generate response
        answer = self._generate_hyper_response(question, chunks)

        total_time = (time.perf_counter() - start_time) * 1000

        # Update rolling performance stats
        self.total_queries += 1
        self.performance_history.append(total_time)
        self.avg_latency = sum(self.performance_history) / len(self.performance_history)

        # Log metrics
        print(f"[Hyper RAG] Query: '{question[:50]}...'")
        print(f"  - Cache: {cache_status}")
        print(f"  - Filtered: {'Yes' if filter_ids else 'No'}")
        print(f"  - Top-K: {effective_k}")
        print(f"  - Chunks used: {len(chunks)}")
        print(f"  - Time: {total_time:.1f}ms")

        # Track metrics
        if self.metrics_tracker:
            self.metrics_tracker.record_query(
                model="hyper",
                latency_ms=total_time,
                memory_mb=0.0,  # Minimal memory
                chunks_used=len(chunks),
                question_length=len(question)
            )

        return answer, len(chunks)
    def _embed_and_cache_sync(self, text: str) -> Tuple[np.ndarray, str]:
        """Synchronous embedding with caching."""
        cached = self._get_cached_embedding(text)
        if cached is not None:
            return cached, "HIT"

        embedding = self.embedder.encode([text])[0]
        self._cache_embedding(text, embedding)
        return embedding, "MISS"

    def query(self, question: str, top_k: Optional[int] = None) -> Tuple[str, int]:
        """Synchronous query wrapper."""
        return asyncio.run(self.query_async(question, top_k))
    def get_performance_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        return {
            "total_queries": self.total_queries,
            "avg_latency_ms": self.avg_latency,
            "memory_cache_size": len(self._embedding_cache),
            "keyword_index_size": len(self.keyword_index),
            "faiss_vectors": self.faiss_index.ntotal if self.faiss_index else 0
        }

    def close(self):
        """Cleanup."""
        if self.thread_pool:
            self.thread_pool.shutdown(wait=True)
        if self.docstore_conn:
            self.docstore_conn.close()
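
# Usage note (illustrative): query() drives query_async() via asyncio.run(),
# which raises RuntimeError when called from a thread that is already running
# an event loop (e.g. a Jupyter notebook or an async web handler).  In that
# situation, await the coroutine directly, for example:
#
#     rag = WorkingHyperRAG()
#     answer, n_chunks = await rag.query_async("What is machine learning?")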

# Quick test
if __name__ == "__main__":
    print("\n🧪 Quick test of Fixed Hyper RAG...")
    from app.metrics import MetricsTracker

    metrics = MetricsTracker()
    rag = WorkingHyperRAG(metrics)

    # Test a simple query
    query = "What is machine learning?"
    print(f"\n📝 Query: {query}")
    answer, chunks = rag.query(query)
    print(f"   Answer: {answer[:100]}...")
    print(f"   Chunks used: {chunks}")

    rag.close()
    print("\n✅ Test complete!")