# CortexSearch / text_engine.py
# junaid17's picture
# Update text_engine.py
# 717901c verified
# text_engine.py
import os
import pickle
import logging
from typing import List, Optional
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
from rank_bm25 import BM25Okapi
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Text_Search_Engine:
    """Hybrid (semantic + lexical) search over short text rows.

    Texts are embedded with a SentenceTransformer model and stored in a
    FAISS index for semantic retrieval; a parallel BM25 index provides
    lexical scores.  The FAISS index and the row metadata are persisted
    under ``base_folder`` so the store survives process restarts.
    """

    def __init__(
        self,
        base_folder: str = "vector_store",
        model_name: str = "sentence-transformers/LaBSE",
        index_type: str = "flat",
    ):
        """
        Args:
            base_folder: Root directory for persisted index/rows files.
            model_name: SentenceTransformer model id used for embeddings.
            index_type: FAISS index kind — "flat", "ivf", or "hnsw".
        """
        self.base_folder = base_folder
        self.embeddings_folder = os.path.join(base_folder, "embeddings")
        self.docs_folder = os.path.join(base_folder, "documents")
        os.makedirs(self.embeddings_folder, exist_ok=True)
        os.makedirs(self.docs_folder, exist_ok=True)
        self.model = SentenceTransformer(model_name)
        self.index: Optional[faiss.Index] = None
        self.rows: List[dict] = []    # per-document metadata, parallel to texts
        self.texts: List[str] = []    # raw search texts, parallel to rows
        self.bm25: Optional[BM25Okapi] = None
        self.index_type = index_type

    # -------------------------
    # Index creation utilities
    # -------------------------
    def _create_index(self, dimension: int, embeddings: np.ndarray):
        """Create (and, for IVF, train) a FAISS index of the configured type.

        Raises:
            ValueError: If ``self.index_type`` is not "flat"/"ivf"/"hnsw".
        """
        if self.index_type == "flat":
            self.index = faiss.IndexFlatL2(dimension)
        elif self.index_type == "ivf":
            # Heuristic: ~1 cluster per 10 vectors, clamped to [1, 256].
            nlist = max(1, min(256, len(embeddings) // 10))
            quantizer = faiss.IndexFlatL2(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
            # IVF indexes must be trained before any vectors can be added.
            self.index.train(np.asarray(embeddings, dtype="float32"))
        elif self.index_type == "hnsw":
            self.index = faiss.IndexHNSWFlat(dimension, 32)  # 32 links per node
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")

    def _persist(self):
        """Best-effort write of the FAISS index and rows to disk.

        Failures are logged, never raised, so persistence problems do not
        break an otherwise usable in-memory store.
        """
        try:
            if self.index is not None:
                faiss.write_index(self.index, os.path.join(self.embeddings_folder, "multilingual.index"))
            with open(os.path.join(self.docs_folder, "rows.pkl"), "wb") as f:
                pickle.dump(self.rows, f)
            logger.info("Persisted index and rows to disk.")
        except Exception as e:
            logger.exception("Failed to persist index/rows: %s", e)

    # -------------------------
    # Core operations
    # -------------------------
    def encode_store(self, rows: List[dict], texts: List[str]):
        """Build a fresh store from scratch, replacing any existing state.

        Args:
            rows: Metadata dict per document.
            texts: Search text per document; must align 1:1 with ``rows``.

        Raises:
            ValueError: On empty input or a rows/texts length mismatch
                (the old code silently corrupted state in that case).
        """
        try:
            if not rows or not texts:
                raise ValueError("encode_store requires non-empty rows and texts")
            if len(rows) != len(texts):
                raise ValueError("rows and texts must have the same length")
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            dimension = embeddings.shape[1]
            self._create_index(dimension, embeddings)
            self.index.add(np.asarray(embeddings, dtype="float32"))
            self.rows = rows
            self.texts = texts
            # BM25 works on a whitespace-tokenized, lowercased corpus.
            tokenized_corpus = [t.lower().split() for t in texts]
            self.bm25 = BM25Okapi(tokenized_corpus)
            self._persist()
            logger.info("Index built with %d rows (index_type=%s).", len(rows), self.index_type)
        except Exception as e:
            logger.exception("Error in encode_store: %s", e)
            raise

    def load(self):
        """Load a previously persisted index and rows, rebuilding BM25.

        Logs and returns quietly when nothing has been persisted yet.
        Assumes every persisted row stores its text under "_search_text"
        (the writer side is outside this file — confirm against the caller).
        """
        try:
            index_path = os.path.join(self.embeddings_folder, "multilingual.index")
            rows_path = os.path.join(self.docs_folder, "rows.pkl")
            if os.path.exists(index_path) and os.path.exists(rows_path):
                self.index = faiss.read_index(index_path)
                # NOTE(review): pickle is only safe because this file was
                # written by _persist(); never point docs_folder at
                # untrusted data — pickle.load executes arbitrary code.
                with open(rows_path, "rb") as f:
                    self.rows = pickle.load(f)
                self.texts = [r["_search_text"] for r in self.rows]
                tokenized_corpus = [t.lower().split() for t in self.texts]
                self.bm25 = BM25Okapi(tokenized_corpus)
                logger.info("Loaded index and %d rows from disk.", len(self.rows))
            else:
                logger.info("No persisted index/rows found.")
        except Exception as e:
            logger.exception("Error in load: %s", e)
            raise

    def add_rows(self, new_rows: List[dict], new_texts: List[str]):
        """Append documents to the existing store (or create one lazily).

        Raises:
            ValueError: If ``new_rows`` and ``new_texts`` lengths differ.
        """
        try:
            if not new_rows:
                return
            if len(new_rows) != len(new_texts):
                raise ValueError("new_rows and new_texts must have the same length")
            new_embeddings = self.model.encode(new_texts, convert_to_numpy=True).astype("float32")
            if self.index is None:
                self._create_index(new_embeddings.shape[1], new_embeddings)
                self.index.add(new_embeddings)
            else:
                # A loaded-but-untrained IVF index cannot accept vectors;
                # retrain on everything we can reconstruct before adding.
                if isinstance(self.index, faiss.IndexIVFFlat) and not self.index.is_trained:
                    if self.texts:
                        old = self.model.encode(self.texts, convert_to_numpy=True).astype("float32")
                        combined = np.vstack([old, new_embeddings])
                    else:
                        combined = new_embeddings
                    self.index.train(combined)
                self.index.add(new_embeddings)
            self.rows.extend(new_rows)
            self.texts.extend(new_texts)
            # BM25 has no incremental API; rebuild over the full corpus.
            tokenized_corpus = [t.lower().split() for t in self.texts]
            self.bm25 = BM25Okapi(tokenized_corpus)
            self._persist()
            logger.info("Added %d new rows. Total rows: %d", len(new_rows), len(self.rows))
        except Exception as e:
            logger.exception("Error in add_rows: %s", e)
            raise

    # -------------------------
    # Search methods
    # -------------------------
    def search(self, query: str, top_k: int = 3):
        """Pure semantic search.

        Returns:
            Up to ``top_k`` row dicts, each with an added L2 "distance"
            key, sorted ascending by distance.  Empty list on any error
            or when the store is empty.
        """
        try:
            if self.index is None:
                return []
            k = min(top_k, len(self.rows))
            if k <= 0:
                # FAISS rejects k == 0; an empty store simply has no hits.
                return []
            query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
            distances, indices = self.index.search(query_emb, k=k)
            # FAISS pads with -1 when fewer than k neighbours are found
            # (e.g. IVF cells with no vectors); the old code let -1 index
            # self.rows and silently returned the LAST row — skip them.
            results = [
                {**self.rows[i], "distance": float(distances[0][j])}
                for j, i in enumerate(indices[0])
                if i >= 0
            ]
            return sorted(results, key=lambda x: x["distance"])
        except Exception as e:
            logger.exception("Error in search: %s", e)
            return []

    def hybrid_search(self, query: str, top_k: int = 3, alpha: float = 0.5):
        """Blend semantic (FAISS) and lexical (BM25) relevance.

        Args:
            query: Search string.
            top_k: Number of results to return.
            alpha: Weight of the semantic score; (1 - alpha) weights BM25.

        Returns:
            Up to ``top_k`` row dicts with a combined "score" key, best
            first.  Empty list on any error or when the store is empty.
        """
        try:
            if self.index is None or self.bm25 is None:
                return []
            # Step 1: retrieve a candidate pool larger than top_k so BM25
            # re-ranking has material to work with.
            retrieve_k = min(20, len(self.texts))
            if retrieve_k <= 0:
                return []
            query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
            distances, indices = self.index.search(query_emb, k=retrieve_k)
            # Step 2: drop FAISS's -1 padding — a -1 would otherwise index
            # the last row and skew both score dictionaries.
            sem_scores = {}
            for j, i in enumerate(indices[0]):
                if i >= 0:
                    # Convert L2 distance to a (0, 1] similarity.
                    sem_scores[int(i)] = 1 / (1 + distances[0][j])
            candidate_ids = list(sem_scores.keys())
            # Step 3: BM25 lexical scores, restricted to the candidates.
            tokenized_query = query.lower().split()
            bm25_scores = self.bm25.get_scores(tokenized_query)
            lex_scores = {i: bm25_scores[i] for i in candidate_ids}

            # Step 4: min-max normalize each score family so the two
            # scales are comparable before mixing.
            def normalize(scores_dict):
                vals = list(scores_dict.values())
                if not vals:
                    return scores_dict
                min_v, max_v = min(vals), max(vals)
                if max_v - min_v == 0:
                    return {k: 0.0 for k in scores_dict}
                return {k: (v - min_v) / (max_v - min_v) for k, v in scores_dict.items()}

            sem_scores = normalize(sem_scores)
            lex_scores = normalize(lex_scores)
            # Step 5: weighted combination, best score first.
            combined = []
            for i in candidate_ids:
                sem = sem_scores.get(i, 0.0)
                lex = lex_scores.get(i, 0.0)
                score = alpha * sem + (1 - alpha) * lex
                combined.append({**self.rows[i], "score": float(score)})
            combined.sort(key=lambda x: x["score"], reverse=True)
            return combined[:top_k]
        except Exception as e:
            logger.exception("Error in hybrid_search: %s", e)
            return []

    # -------------------------
    # Utilities
    # -------------------------
    def clear_vdb(self):
        """Drop all in-memory state and delete the persisted files."""
        try:
            # Drop the index outright rather than reset()-ing it: every
            # emptiness check in this class tests `self.index is None`,
            # and a reset-but-live index would defeat those guards.
            self.index = None
            self.rows = []
            self.texts = []
            self.bm25 = None
            index_path = os.path.join(self.embeddings_folder, "multilingual.index")
            docs_path = os.path.join(self.docs_folder, "rows.pkl")
            if os.path.exists(index_path):
                os.remove(index_path)
            if os.path.exists(docs_path):
                os.remove(docs_path)
            logger.info("Cleared vector DB and persisted files.")
        except Exception as e:
            logger.exception("Error in clear_vdb: %s", e)
            raise