import pickle
from pathlib import Path
from typing import List, Dict, Any

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from core.utils.logger import logger


class OptimizedGazaKnowledgeBase:
    """Optimized knowledge base that loads a pre-made FAISS index and assets."""

    def __init__(self, vector_store_dir: str = "./vector_store"):
        self.vector_store_dir = Path(vector_store_dir)
        self.faiss_index = None
        self.embedding_model = None
        self.chunks = []
        self.metadata = []
        self.is_initialized = False

    def initialize(self):
        """Load the pre-made FAISS index and associated data."""
        try:
            logger.info("Loading pre-made FAISS index and assets...")

            # 1. Load FAISS index
            index_path = self.vector_store_dir / "index.faiss"
            if not index_path.exists():
                raise FileNotFoundError(f"FAISS index not found at {index_path}")
            self.faiss_index = faiss.read_index(str(index_path))
            logger.info(
                f"Loaded FAISS index: {self.faiss_index.ntotal} vectors, "
                f"{self.faiss_index.d} dimensions"
            )

            # 2. Load chunks
            chunks_path = self.vector_store_dir / "chunks.txt"
            if not chunks_path.exists():
                raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
            with open(chunks_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            # Parse chunks from the formatted file
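            # The parser below implies chunks.txt was written in a layout like
            # the following (inferred from the prefixes it skips, not from a
            # documented spec):
            #
            #   === Chunk 1 ===
            #   Source: some_document.pdf
            #   Length: 512
            #   <one or more lines of chunk text>
            #
            # "=== Chunk" headers delimit chunks; "Source:" and "Length:" lines
            # are skipped, and the remaining lines are joined into chunk text.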
            current_chunk = ""
            for line in lines:
                line = line.strip()
                if line.startswith("=== Chunk") and current_chunk:
                    self.chunks.append(current_chunk.strip())
                    current_chunk = ""
                elif not line.startswith("===") and not line.startswith("Source:") and not line.startswith("Length:"):
                    current_chunk += line + " "

            # Add the last chunk
            if current_chunk:
                self.chunks.append(current_chunk.strip())
            logger.info(f"Loaded {len(self.chunks)} text chunks")

            # 3. Load metadata
            metadata_path = self.vector_store_dir / "metadata.pkl"
            if metadata_path.exists():
                with open(metadata_path, 'rb') as f:
                    metadata_dict = pickle.load(f)
                if isinstance(metadata_dict, dict) and 'metadata' in metadata_dict:
                    self.metadata = metadata_dict['metadata']
                    logger.info(f"Loaded {len(self.metadata)} metadata entries")
                else:
                    logger.warning("Metadata format not recognized, using empty metadata")
                    # One independent dict per chunk ([{}] * n would alias a single dict)
                    self.metadata = [{} for _ in self.chunks]
            else:
                logger.warning("No metadata file found, using empty metadata")
                self.metadata = [{} for _ in self.chunks]

            # 4. Initialize embedding model for query encoding
            logger.info("Loading embedding model for queries...")
            self.embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
            logger.info("Embedding model loaded")

            # 5. Verify data consistency
            if len(self.chunks) != self.faiss_index.ntotal:
                logger.warning(
                    f"Mismatch: {len(self.chunks)} chunks vs {self.faiss_index.ntotal} vectors"
                )
                # Trim chunks to match index size
                self.chunks = self.chunks[:self.faiss_index.ntotal]
                self.metadata = self.metadata[:self.faiss_index.ntotal]
                logger.info(f"Trimmed to {len(self.chunks)} chunks to match index")

            self.is_initialized = True
            logger.info("Knowledge base initialization complete!")
        except Exception as e:
            logger.error(f"Failed to initialize knowledge base: {e}")
            raise

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search using the pre-made FAISS index."""
        if not self.is_initialized:
            raise RuntimeError("Knowledge base not initialized")
        try:
            # 1. Encode query
            query_embedding = self.embedding_model.encode([query])
            query_vector = np.array(query_embedding, dtype=np.float32)

            # 2. Search FAISS index
            distances, indices = self.faiss_index.search(query_vector, k)
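
            # Note on scoring (an assumption about how the index was built):
            # the 1 / (1 + distance) conversion below treats `distances` as L2
            # distances (e.g. from faiss.IndexFlatL2), where smaller means more
            # similar. An inner-product index returns similarities instead, and
            # this mapping would invert the ranking.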

            # 3. Prepare results
            results = []
            for distance, idx in zip(distances[0], indices[0]):
                if 0 <= idx < len(self.chunks):  # FAISS pads missing hits with -1
                    chunk_metadata = self.metadata[idx] if idx < len(self.metadata) else {}
                    results.append({
                        "text": self.chunks[idx],
                        "score": float(1.0 / (1.0 + distance)),  # Convert distance to similarity score
                        "source": chunk_metadata.get("source", "unknown"),
                        "chunk_index": int(idx),
                        "distance": float(distance),
                        "metadata": chunk_metadata,
                    })
            logger.info(f"Search for '{query}' returned {len(results)} results")
            return results
        except Exception as e:
            logger.error(f"Search error: {e}")
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Get knowledge base statistics."""
        if not self.is_initialized:
            return {"status": "not_initialized"}
        return {
            "status": "initialized",
            "total_chunks": len(self.chunks),
            "total_vectors": self.faiss_index.ntotal,
            "embedding_dimension": self.faiss_index.d,
            "index_type": type(self.faiss_index).__name__,
            "sources": list(set(meta.get("source", "unknown") for meta in self.metadata)),
        }
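

# Minimal usage sketch (illustrative; the query string and store path are
# placeholders): assumes ./vector_store contains index.faiss, chunks.txt,
# and optionally metadata.pkl in the layout parsed above.
if __name__ == "__main__":
    kb = OptimizedGazaKnowledgeBase(vector_store_dir="./vector_store")
    kb.initialize()
    print(kb.get_stats())
    for hit in kb.search("humanitarian aid deliveries", k=3):
        print(f"[{hit['score']:.3f}] ({hit['source']}) {hit['text'][:80]}")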