# text_engine.py
import os
import pickle
import logging
from typing import List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
from rank_bm25 import BM25Okapi

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextSearchEngine:
    """Semantic + lexical search over a small corpus, backed by FAISS and BM25."""

    def __init__(
        self,
        base_folder: str = "vector_store",
        model_name: str = "sentence-transformers/LaBSE",
        index_type: str = "flat",
    ):
        self.base_folder = base_folder
        self.embeddings_folder = os.path.join(base_folder, "embeddings")
        self.docs_folder = os.path.join(base_folder, "documents")
        os.makedirs(self.embeddings_folder, exist_ok=True)
        os.makedirs(self.docs_folder, exist_ok=True)

        self.model = SentenceTransformer(model_name)
        self.index: Optional[faiss.Index] = None
        self.rows: List[dict] = []
        self.texts: List[str] = []
        self.bm25: Optional[BM25Okapi] = None
        self.index_type = index_type

    # -------------------------
    # Index creation utilities
    # -------------------------
    def _create_index(self, dimension: int, embeddings: np.ndarray):
        if self.index_type == "flat":
            self.index = faiss.IndexFlatL2(dimension)
        elif self.index_type == "ivf":
            nlist = max(1, min(256, len(embeddings) // 10))
            quantizer = faiss.IndexFlatL2(dimension)
            self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist, faiss.METRIC_L2)
            self.index.train(np.asarray(embeddings, dtype="float32"))
        elif self.index_type == "hnsw":
            self.index = faiss.IndexHNSWFlat(dimension, 32)
        else:
            raise ValueError(f"Unsupported index type: {self.index_type}")
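
    # Worked example of the IVF sizing heuristic above (hypothetical corpus
    # sizes): with 1,000 embeddings, nlist = max(1, min(256, 1000 // 10)) = 100
    # clusters; with 50 embeddings it degrades gracefully to nlist = 5. The
    # HNSW branch uses 32 links per node, a typical middle-ground setting.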

    def _persist(self):
        """Write the FAISS index and the row/text corpus to disk."""
        try:
            if self.index is not None:
                faiss.write_index(self.index, os.path.join(self.embeddings_folder, "multilingual.index"))
            # Persist texts alongside rows so load() does not depend on a
            # "_search_text" key that encode_store() never sets.
            with open(os.path.join(self.docs_folder, "rows.pkl"), "wb") as f:
                pickle.dump({"rows": self.rows, "texts": self.texts}, f)
            logger.info("Persisted index, rows, and texts to disk.")
        except Exception as e:
            logger.exception("Failed to persist index/rows: %s", e)

    # -------------------------
    # Core operations
    # -------------------------
    def encode_store(self, rows: List[dict], texts: List[str]):
        """Embed `texts`, build a fresh index over them, and persist everything."""
        try:
            embeddings = self.model.encode(texts, convert_to_numpy=True)
            dimension = embeddings.shape[1]
            self._create_index(dimension, embeddings)
            self.index.add(embeddings.astype("float32"))
            self.rows = rows
            self.texts = texts
            # BM25 needs a tokenized corpus; a simple lowercase/whitespace split.
            tokenized_corpus = [t.lower().split() for t in texts]
            self.bm25 = BM25Okapi(tokenized_corpus)
            self._persist()
            logger.info("Index built with %d rows (index_type=%s).", len(rows), self.index_type)
        except Exception as e:
            logger.exception("Error in encode_store: %s", e)
            raise
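
    # Usage sketch (hypothetical data), assuming a tiny two-document corpus:
    #   engine = TextSearchEngine(index_type="flat")
    #   engine.encode_store(
    #       rows=[{"id": 1}, {"id": 2}],
    #       texts=["hello world", "bonjour le monde"],
    #   )
    # rows carry arbitrary metadata; texts drive both the embeddings and BM25.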

    def load(self):
        """Restore the index, rows, and texts written by _persist()."""
        try:
            index_path = os.path.join(self.embeddings_folder, "multilingual.index")
            rows_path = os.path.join(self.docs_folder, "rows.pkl")
            if os.path.exists(index_path) and os.path.exists(rows_path):
                self.index = faiss.read_index(index_path)
                with open(rows_path, "rb") as f:
                    payload = pickle.load(f)
                self.rows = payload["rows"]
                self.texts = payload["texts"]
                tokenized_corpus = [t.lower().split() for t in self.texts]
                self.bm25 = BM25Okapi(tokenized_corpus)
                logger.info("Loaded index and %d rows from disk.", len(self.rows))
            else:
                logger.info("No persisted index/rows found.")
        except Exception as e:
            logger.exception("Error in load: %s", e)
            raise

    def add_rows(self, new_rows: List[dict], new_texts: List[str]):
        try:
            if not new_rows:
                return
            new_embeddings = self.model.encode(new_texts, convert_to_numpy=True).astype("float32")
            if self.index is None:
                self._create_index(new_embeddings.shape[1], new_embeddings)
                self.index.add(new_embeddings)
            elif isinstance(self.index, faiss.IndexIVFFlat) and not self.index.is_trained:
                # An untrained IVF index cannot hold vectors yet, so re-embed the
                # existing corpus, train on everything, and add it all at once.
                if self.texts:
                    old_embeddings = self.model.encode(self.texts, convert_to_numpy=True).astype("float32")
                    combined = np.vstack([old_embeddings, new_embeddings])
                else:
                    combined = new_embeddings
                self.index.train(combined)
                self.index.add(combined)
            else:
                self.index.add(new_embeddings)
            self.rows.extend(new_rows)
            self.texts.extend(new_texts)
            # Rebuild BM25 over the full corpus so lexical scores stay in sync.
            tokenized_corpus = [t.lower().split() for t in self.texts]
            self.bm25 = BM25Okapi(tokenized_corpus)
            self._persist()
            logger.info("Added %d new rows. Total rows: %d", len(new_rows), len(self.rows))
        except Exception as e:
            logger.exception("Error in add_rows: %s", e)
            raise
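
    # Incremental usage sketch (hypothetical), continuing the example above:
    #   engine.add_rows([{"id": 3}], ["hola mundo"])
    # Design note: BM25Okapi has no incremental update, so the whole corpus is
    # re-tokenized on every addition; this stays cheap for small corpora.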

    # -------------------------
    # Search methods
    # -------------------------
    def search(self, query: str, top_k: int = 3):
        """Pure semantic search: rank rows by L2 distance to the query embedding."""
        try:
            if self.index is None or not self.rows:
                return []
            query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")
            k = min(top_k, len(self.rows))
            distances, indices = self.index.search(query_emb, k=k)
            results = [
                {**self.rows[i], "distance": float(distances[0][j])}
                for j, i in enumerate(indices[0])
                if i >= 0  # FAISS pads with -1 when fewer than k hits exist
            ]
            return sorted(results, key=lambda x: x["distance"])
        except Exception as e:
            logger.exception("Error in search: %s", e)
            return []
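
    # Example (hypothetical): against the two-document corpus sketched above,
    # search("hello", top_k=1) should surface the {"id": 1} row, since
    # "hello world" embeds closest to the query (smallest L2 distance).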

    def hybrid_search(self, query: str, top_k: int = 3, alpha: float = 0.5):
        """Blend semantic (FAISS) and lexical (BM25) scores; alpha weights semantic."""
        try:
            if self.index is None or self.bm25 is None:
                return []

            # Step 1: Encode the query.
            query_emb = self.model.encode([query], convert_to_numpy=True).astype("float32")

            # Step 2: Retrieve a candidate pool larger than top_k.
            retrieve_k = min(20, len(self.texts))
            distances, indices = self.index.search(query_emb, k=retrieve_k)
            candidate_ids = [i for i in indices[0] if i >= 0]

            # Step 3: Semantic scores (convert L2 distance -> similarity).
            sem_scores = {}
            for j, i in enumerate(indices[0]):
                if i < 0:
                    continue
                sem_scores[i] = 1 / (1 + distances[0][j])

            # Step 4: BM25 scores, restricted to the candidates.
            tokenized_query = query.lower().split()
            bm25_scores = self.bm25.get_scores(tokenized_query)
            lex_scores = {i: bm25_scores[i] for i in candidate_ids}

            # Step 5: Min-max normalization so the two score scales are comparable.
            def normalize(scores_dict):
                vals = list(scores_dict.values())
                if not vals:
                    return scores_dict
                min_v, max_v = min(vals), max(vals)
                if max_v - min_v == 0:
                    return {k: 0.0 for k in scores_dict}
                return {k: (v - min_v) / (max_v - min_v) for k, v in scores_dict.items()}

            sem_scores = normalize(sem_scores)
            lex_scores = normalize(lex_scores)

            # Step 6: Combine the two signals with the alpha weight.
            combined = []
            for i in candidate_ids:
                sem = sem_scores.get(i, 0.0)
                lex = lex_scores.get(i, 0.0)
                score = alpha * sem + (1 - alpha) * lex
                combined.append({**self.rows[i], "score": float(score)})

            # Step 7: Sort by blended score and return the top_k.
            combined.sort(key=lambda x: x["score"], reverse=True)
            return combined[:top_k]
        except Exception as e:
            logger.exception("Error in hybrid_search: %s", e)
            return []
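
    # Worked example of the blend (hypothetical numbers): with alpha=0.5, a
    # candidate whose normalized semantic score is 0.8 and normalized BM25
    # score is 0.4 receives 0.5 * 0.8 + 0.5 * 0.4 = 0.6. Min-max normalization
    # maps e.g. raw scores {2, 4, 6} to {0.0, 0.5, 1.0} within each signal.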

    # -------------------------
    # Utilities
    # -------------------------
    def clear_vdb(self):
        """Drop all in-memory state and delete the persisted files."""
        try:
            if self.index is not None:
                try:
                    self.index.reset()
                except Exception:
                    pass
                self.index = None
            self.rows = []
            self.texts = []
            self.bm25 = None
            index_path = os.path.join(self.embeddings_folder, "multilingual.index")
            rows_path = os.path.join(self.docs_folder, "rows.pkl")
            if os.path.exists(index_path):
                os.remove(index_path)
            if os.path.exists(rows_path):
                os.remove(rows_path)
            logger.info("Cleared vector DB and persisted files.")
        except Exception as e:
            logger.exception("Error in clear_vdb: %s", e)
            raise
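

# Minimal end-to-end sketch, assuming the LaBSE model download succeeds and
# using hypothetical sample data; not part of the engine itself.
if __name__ == "__main__":
    engine = TextSearchEngine(index_type="flat")
    engine.encode_store(
        rows=[{"id": 1, "lang": "en"}, {"id": 2, "lang": "fr"}],
        texts=["hello world", "bonjour le monde"],
    )
    print(engine.search("hello", top_k=1))
    print(engine.hybrid_search("hello world", top_k=2, alpha=0.5))

    # A fresh process would instead call engine.load() to restore state.
    engine.clear_vdb()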