# core/models/knowledge_base.py
import pickle
from pathlib import Path
from typing import List, Dict, Any
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from core.utils.logger import logger
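
# Expected contents of vector_store_dir (inferred from the loading code below;
# the exact layout depends on how the index was originally built):
#   index.faiss   - pre-built FAISS index over the chunk embeddings
#   chunks.txt    - chunk texts separated by "=== Chunk N ===" header blocks
#   metadata.pkl  - pickled dict whose "metadata" key holds per-chunk dicts
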
class OptimizedGazaKnowledgeBase:
"""Optimized knowledge base that loads pre-made FAISS index and assets"""
def __init__(self, vector_store_dir: str = "./vector_store"):
self.vector_store_dir = Path(vector_store_dir)
self.faiss_index = None
self.embedding_model = None
self.chunks = []
self.metadata = []
self.is_initialized = False
def initialize(self):
"""Load pre-made FAISS index and associated data"""
try:
logger.info("πŸ”„ Loading pre-made FAISS index and assets...")
# 1. Load FAISS index
index_path = self.vector_store_dir / "index.faiss"
if not index_path.exists():
raise FileNotFoundError(f"FAISS index not found at {index_path}")
self.faiss_index = faiss.read_index(str(index_path))
logger.info(f"βœ… Loaded FAISS index: {self.faiss_index.ntotal} vectors, {self.faiss_index.d} dimensions")
# 2. Load chunks
chunks_path = self.vector_store_dir / "chunks.txt"
if not chunks_path.exists():
raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
            with open(chunks_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
# Parse chunks from the formatted file
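            # Expected per-chunk header format (inferred from the checks below):
            #   === Chunk 1 ===
            #   Source: <document name>
            #   Length: <character count>
            #   <chunk text ...>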
current_chunk = ""
for line in lines:
line = line.strip()
if line.startswith("=== Chunk") and current_chunk:
self.chunks.append(current_chunk.strip())
current_chunk = ""
elif not line.startswith("===") and not line.startswith("Source:") and not line.startswith("Length:"):
current_chunk += line + " "
# Add the last chunk
if current_chunk:
self.chunks.append(current_chunk.strip())
logger.info(f"βœ… Loaded {len(self.chunks)} text chunks")
# 3. Load metadata
metadata_path = self.vector_store_dir / "metadata.pkl"
if metadata_path.exists():
with open(metadata_path, 'rb') as f:
metadata_dict = pickle.load(f)
                if isinstance(metadata_dict, dict) and 'metadata' in metadata_dict:
                    self.metadata = metadata_dict['metadata']
                    logger.info(f"✅ Loaded {len(self.metadata)} metadata entries")
                else:
                    logger.warning("⚠️ Metadata format not recognized, using empty metadata")
                    # Fresh dict per chunk; [{}] * n would alias one shared dict
                    self.metadata = [{} for _ in self.chunks]
            else:
                logger.warning("⚠️ No metadata file found, using empty metadata")
                self.metadata = [{} for _ in self.chunks]
# 4. Initialize embedding model for query encoding
logger.info("πŸ”„ Loading embedding model for queries...")
self.embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
logger.info("βœ… Embedding model loaded")
# 5. Verify data consistency
if len(self.chunks) != self.faiss_index.ntotal:
logger.warning(f"⚠️ Mismatch: {len(self.chunks)} chunks vs {self.faiss_index.ntotal} vectors")
# Trim chunks to match index size
self.chunks = self.chunks[:self.faiss_index.ntotal]
self.metadata = self.metadata[:self.faiss_index.ntotal]
logger.info(f"βœ… Trimmed to {len(self.chunks)} chunks to match index")
self.is_initialized = True
logger.info("πŸŽ‰ Knowledge base initialization complete!")
except Exception as e:
logger.error(f"❌ Failed to initialize knowledge base: {e}")
raise
def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
"""Search using pre-made FAISS index"""
if not self.is_initialized:
raise RuntimeError("Knowledge base not initialized")
try:
# 1. Encode query
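            # encode() on a one-element list returns a (1, d) matrix, the shape
            # FAISS expects for a single-query search; FAISS also requires float32,
            # hence the explicit cast below.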
query_embedding = self.embedding_model.encode([query])
query_vector = np.array(query_embedding, dtype=np.float32)
# 2. Search FAISS index
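            # NOTE: FAISS returns (distances, indices); with an L2 index a smaller
            # distance means a closer match. The 1 / (1 + distance) score further
            # down assumes L2 distances (it is not a cosine similarity).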
distances, indices = self.faiss_index.search(query_vector, k)
# 3. Prepare results
results = []
            for distance, idx in zip(distances[0], indices[0]):
                if 0 <= idx < len(self.chunks):  # FAISS pads missing hits with -1
chunk_metadata = self.metadata[idx] if idx < len(self.metadata) else {}
result = {
"text": self.chunks[idx],
"score": float(1.0 / (1.0 + distance)), # Convert distance to similarity score
"source": chunk_metadata.get("source", "unknown"),
"chunk_index": int(idx),
"distance": float(distance),
"metadata": chunk_metadata
}
results.append(result)
logger.info(f"πŸ” Search for '{query}...' returned {len(results)} results")
return results
except Exception as e:
logger.error(f"❌ Search error: {e}")
return []
def get_stats(self) -> Dict[str, Any]:
"""Get knowledge base statistics"""
if not self.is_initialized:
return {"status": "not_initialized"}
return {
"status": "initialized",
"total_chunks": len(self.chunks),
"total_vectors": self.faiss_index.ntotal,
"embedding_dimension": self.faiss_index.d,
"index_type": type(self.faiss_index).__name__,
"sources": list(set(meta.get("source", "unknown") for meta in self.metadata))
}
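

if __name__ == "__main__":
    # Minimal usage sketch, assuming ./vector_store holds the three files
    # described at the top of this module; the query string is illustrative.
    kb = OptimizedGazaKnowledgeBase("./vector_store")
    kb.initialize()
    print(kb.get_stats())
    for hit in kb.search("how to treat a burn wound", k=3):
        print(f"{hit['score']:.3f}  {hit['source']}: {hit['text'][:80]}")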