import pickle
from pathlib import Path
from typing import List, Dict, Any

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from core.utils.logger import logger
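
# Expected vector_store layout (inferred from the loaders below):
#   index.faiss  - serialized FAISS index (written with faiss.write_index)
#   chunks.txt   - formatted text chunks ("=== Chunk ..." blocks)
#   metadata.pkl - optional pickled dict with a "metadata" list
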
class OptimizedGazaKnowledgeBase:
    """Optimized knowledge base that loads a pre-made FAISS index and assets."""

    def __init__(self, vector_store_dir: str = "./vector_store"):
        self.vector_store_dir = Path(vector_store_dir)
        self.faiss_index = None
        self.embedding_model = None
        self.chunks: List[str] = []
        self.metadata: List[Dict[str, Any]] = []
        self.is_initialized = False

    def initialize(self):
        """Load the pre-made FAISS index and associated data."""
        try:
            logger.info("Loading pre-made FAISS index and assets...")

            # 1. Load the FAISS index
            index_path = self.vector_store_dir / "index.faiss"
            if not index_path.exists():
                raise FileNotFoundError(f"FAISS index not found at {index_path}")
            self.faiss_index = faiss.read_index(str(index_path))
            logger.info(
                f"Loaded FAISS index: {self.faiss_index.ntotal} vectors, "
                f"{self.faiss_index.d} dimensions"
            )

            # 2. Load the text chunks
            chunks_path = self.vector_store_dir / "chunks.txt"
            if not chunks_path.exists():
                raise FileNotFoundError(f"Chunks file not found at {chunks_path}")
            with open(chunks_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()

            # Parse chunks from the formatted file: "=== Chunk" headers
            # delimit chunks, "Source:" and "Length:" lines are per-chunk
            # annotations, and every other line is chunk text.
            current_chunk = ""
            for line in lines:
                line = line.strip()
                if line.startswith("=== Chunk") and current_chunk:
                    self.chunks.append(current_chunk.strip())
                    current_chunk = ""
                elif not line.startswith("===") and not line.startswith("Source:") and not line.startswith("Length:"):
                    current_chunk += line + " "
            # Add the last chunk
            if current_chunk:
                self.chunks.append(current_chunk.strip())
            logger.info(f"Loaded {len(self.chunks)} text chunks")
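
            # Illustrative chunks.txt layout accepted by the parser above
            # (only the "=== Chunk", "Source:" and "Length:" prefixes are
            # actually checked; the exact header wording is an assumption):
            #
            #   === Chunk 1 ===
            #   Source: report.pdf
            #   Length: 512
            #   Text of the first chunk, possibly over several lines...
            #   === Chunk 2 ===
            #   ...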

            # 3. Load metadata (optional; fall back to empty dicts)
            metadata_path = self.vector_store_dir / "metadata.pkl"
            if metadata_path.exists():
                with open(metadata_path, 'rb') as f:
                    metadata_dict = pickle.load(f)
                if isinstance(metadata_dict, dict) and 'metadata' in metadata_dict:
                    self.metadata = metadata_dict['metadata']
                    logger.info(f"Loaded {len(self.metadata)} metadata entries")
                else:
                    logger.warning("Metadata format not recognized, using empty metadata")
                    self.metadata = [{} for _ in self.chunks]
            else:
                logger.warning("No metadata file found, using empty metadata")
                self.metadata = [{} for _ in self.chunks]

            # 4. Initialize the embedding model for query encoding;
            # it must match the model used to build the index.
            logger.info("Loading embedding model for queries...")
            self.embedding_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
            logger.info("Embedding model loaded")

            # 5. Verify data consistency between chunks and index
            if len(self.chunks) != self.faiss_index.ntotal:
                logger.warning(
                    f"Mismatch: {len(self.chunks)} chunks vs {self.faiss_index.ntotal} vectors"
                )
                # Trim chunks and metadata to match the index size
                self.chunks = self.chunks[:self.faiss_index.ntotal]
                self.metadata = self.metadata[:self.faiss_index.ntotal]
                logger.info(f"Trimmed to {len(self.chunks)} chunks to match index")

            self.is_initialized = True
            logger.info("Knowledge base initialization complete!")
        except Exception as e:
            logger.error(f"Failed to initialize knowledge base: {e}")
            raise

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Search using the pre-made FAISS index."""
        if not self.is_initialized:
            raise RuntimeError("Knowledge base not initialized")
        try:
            # 1. Encode the query with the same model used to build the index
            query_embedding = self.embedding_model.encode([query])
            query_vector = np.array(query_embedding, dtype=np.float32)

            # 2. Search the FAISS index
            distances, indices = self.faiss_index.search(query_vector, k)

            # 3. Prepare results (FAISS returns -1 for missing neighbours)
            results = []
            for distance, idx in zip(distances[0], indices[0]):
                if 0 <= idx < len(self.chunks):  # Valid index
                    chunk_metadata = self.metadata[idx] if idx < len(self.metadata) else {}
                    results.append({
                        "text": self.chunks[idx],
                        # Convert distance to a similarity score in (0, 1]
                        "score": float(1.0 / (1.0 + distance)),
                        "source": chunk_metadata.get("source", "unknown"),
                        "chunk_index": int(idx),
                        "distance": float(distance),
                        "metadata": chunk_metadata,
                    })
            logger.info(f"Search for '{query}' returned {len(results)} results")
            return results
        except Exception as e:
            logger.error(f"Search error: {e}")
            return []

    def get_stats(self) -> Dict[str, Any]:
        """Get knowledge base statistics."""
        if not self.is_initialized:
            return {"status": "not_initialized"}
        return {
            "status": "initialized",
            "total_chunks": len(self.chunks),
            "total_vectors": self.faiss_index.ntotal,
            "embedding_dimension": self.faiss_index.d,
            "index_type": type(self.faiss_index).__name__,
            "sources": list(set(meta.get("source", "unknown") for meta in self.metadata)),
        }
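

# Minimal usage sketch, assuming a ./vector_store directory containing
# index.faiss, chunks.txt and metadata.pkl as described above; the query
# string is purely illustrative.
if __name__ == "__main__":
    kb = OptimizedGazaKnowledgeBase("./vector_store")
    kb.initialize()
    print(kb.get_stats())
    for hit in kb.search("humanitarian aid corridors", k=3):
        print(f"{hit['score']:.3f}  {hit['source']}  {hit['text'][:80]}")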