import pickle
from pathlib import Path
from typing import List, Dict, Any

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from core.utils.logger import logger


class OptimizedGazaKnowledgeBase:
    """Optimized knowledge base that loads a pre-built FAISS index and assets."""

    def __init__(self, vector_store_dir: str = "./vector_store"):
        self.vector_store_dir = Path(vector_store_dir)
        self.faiss_index = None
        self.embedding_model = None
        self.chunks = []
        self.metadata = []
        self.is_initialized = False

    def initialize(self):
        """Load the pre-built FAISS index and associated data."""
        try:
            logger.info("🔄 Loading pre-made FAISS index and assets...")

            # 1. Load FAISS index
            index_path = self.vector_store_dir / "index.faiss"
            if not index_path.exists():
                raise FileNotFoundError(f"FAISS index not found at {index_path}")
            self.faiss_index = faiss.read_index(str(index_path))
            logger.info(
                f"✅ Loaded FAISS index: {self.faiss_index.ntotal} vectors, "
                f"{self.faiss_index.d} dimensions"
            )

            # 2. Load chunks
            chunks_path = self.vector_store_dir / "chunks.txt"
            if not chunks_path.exists():
                raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
            with open(chunks_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            # Parse chunks from the formatted file: "=== Chunk N ===" headers
            # delimit chunks; "Source:" and "Length:" lines are skipped as noise.
            current_chunk = ""
            for line in lines:
                line = line.strip()
                if line.startswith("=== Chunk") and current_chunk:
                    self.chunks.append(current_chunk.strip())
                    current_chunk = ""
                elif not line.startswith("===") and not line.startswith("Source:") and not line.startswith("Length:"):
                    current_chunk += line + " "

            # Add the last chunk
            if current_chunk:
                self.chunks.append(current_chunk.strip())
            logger.info(f"✅ Loaded {len(self.chunks)} text chunks")

            # 3. Load metadata (use independent empty dicts as a fallback,
            # not a shared one, so later per-chunk mutation stays safe)
            metadata_path = self.vector_store_dir / "metadata.pkl"
            if metadata_path.exists():
                with open(metadata_path, 'rb') as f:
                    metadata_dict = pickle.load(f)
                if isinstance(metadata_dict, dict) and 'metadata' in metadata_dict:
                    self.metadata = metadata_dict['metadata']
                    logger.info(f"✅ Loaded {len(self.metadata)} metadata entries")
                else:
                    logger.warning("⚠️ Metadata format not recognized, using empty metadata")
                    self.metadata = [{} for _ in self.chunks]
            else:
                logger.warning("⚠️ No metadata file found, using empty metadata")
                self.metadata = [{} for _ in self.chunks]

            # 4. Initialize embedding model for query encoding
            logger.info("🔄 Loading embedding model for queries...")
            self.embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
            logger.info("✅ Embedding model loaded")

            # 5. Verify data consistency
            if len(self.chunks) != self.faiss_index.ntotal:
                logger.warning(
                    f"⚠️ Mismatch: {len(self.chunks)} chunks vs {self.faiss_index.ntotal} vectors"
                )
                # Trim chunks to match index size
                self.chunks = self.chunks[:self.faiss_index.ntotal]
                self.metadata = self.metadata[:self.faiss_index.ntotal]
                logger.info(f"✅ Trimmed to {len(self.chunks)} chunks to match index")

            self.is_initialized = True
            logger.info("🎉 Knowledge base initialization complete!")
        except Exception as e:
            logger.error(f"❌ Failed to initialize knowledge base: {e}")
            raise

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search using the pre-built FAISS index."""
        if not self.is_initialized:
            raise RuntimeError("Knowledge base not initialized")
        try:
            # 1. Encode query (FAISS expects a float32 matrix of shape (n, d))
            query_embedding = self.embedding_model.encode([query])
            query_vector = np.array(query_embedding, dtype=np.float32)

            # 2. Search FAISS index
            distances, indices = self.faiss_index.search(query_vector, k)

            # 3. Prepare results
            results = []
            for distance, idx in zip(distances[0], indices[0]):
                if 0 <= idx < len(self.chunks):  # Skip FAISS padding ids (-1) and out-of-range ids
                    chunk_metadata = self.metadata[idx] if idx < len(self.metadata) else {}
                    results.append({
                        "text": self.chunks[idx],
                        "score": float(1.0 / (1.0 + distance)),  # Convert distance to similarity score
                        "source": chunk_metadata.get("source", "unknown"),
                        "chunk_index": int(idx),
                        "distance": float(distance),
                        "metadata": chunk_metadata,
                    })

            logger.info(f"🔍 Search for '{query}' returned {len(results)} results")
            return results
        except Exception as e:
            logger.error(f"❌ Search error: {e}")
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Get knowledge base statistics."""
        if not self.is_initialized:
            return {"status": "not_initialized"}
        return {
            "status": "initialized",
            "total_chunks": len(self.chunks),
            "total_vectors": self.faiss_index.ntotal,
            "embedding_dimension": self.faiss_index.d,
            "index_type": type(self.faiss_index).__name__,
            "sources": list(set(meta.get("source", "unknown") for meta in self.metadata)),
        }
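

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of driving the class above. It assumes a ./vector_store
# directory containing index.faiss, chunks.txt, and optionally metadata.pkl,
# produced by a separate build step with the same all-mpnet-base-v2 model;
# the directory path and query string are placeholders.
if __name__ == "__main__":
    kb = OptimizedGazaKnowledgeBase(vector_store_dir="./vector_store")
    kb.initialize()

    # Inspect what was loaded before querying.
    print(kb.get_stats())

    # Retrieve the top-3 chunks for a sample query; "score" is the
    # 1 / (1 + distance) similarity computed in search().
    for hit in kb.search("humanitarian aid corridors", k=3):
        print(f"[{hit['score']:.3f}] {hit['source']}: {hit['text'][:120]}")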