# doc-ai-api / Libraries / Faiss_Searching.py
# LongK171's picture
# Add all
# dbe2c62
import faiss
import numpy as np
from typing import Dict, List, Any, Optional
from sentence_transformers import SentenceTransformer, CrossEncoder
class SemanticSearchEngine:
    """Two-stage retrieval helper: FAISS vector search plus optional reranking.

    The engine does not load any model itself. It receives an already loaded
    ``SentenceTransformer`` (used to embed queries) and, optionally, an
    already loaded ``CrossEncoder`` (used by :meth:`rerank`).
    """

    def __init__(
        self,
        indexer: SentenceTransformer,
        reranker: Optional[CrossEncoder] = None,
        device: str = "cuda",
        normalize: bool = True,
        top_k: int = 20,
        rerank_k: int = 10,
        rerank_batch_size: int = 16,
    ):
        """Store the models and the default search parameters.

        Args:
            indexer: Loaded embedding model used to encode queries.
            reranker: Loaded cross-encoder, or ``None`` to disable reranking.
            device: Device string forwarded to ``indexer.encode``.
            normalize: Whether query vectors are L2-normalized before search.
            top_k: Default number of neighbours requested from FAISS.
            rerank_k: Default number of results kept after reranking.
            rerank_batch_size: Batch size forwarded to ``reranker.predict``.

        Raises:
            TypeError: If ``indexer`` is not a ``SentenceTransformer`` or
                ``reranker`` is neither ``None`` nor a ``CrossEncoder``.
        """
        self.device = device
        self.normalize = normalize
        self.top_k = int(top_k)
        self.rerank_k = int(rerank_k)
        self.rerank_batch_size = int(rerank_batch_size)

        # The model must be passed in already loaded.
        if not isinstance(indexer, SentenceTransformer):
            raise TypeError("indexer phải là SentenceTransformer đã load sẵn.")
        self._indexer = indexer

        # The reranker is optional. Compare against None explicitly so that a
        # falsy object of the wrong type cannot slip past the type check.
        if reranker is not None and not isinstance(reranker, CrossEncoder):
            raise TypeError("reranker phải là CrossEncoder hoặc None.")
        self.reranker = reranker

    # ---------------------------
    # Internal utilities
    # ---------------------------
    @staticmethod
    def _l2_normalize(x: np.ndarray, axis: int = 1, eps: float = 1e-12) -> np.ndarray:
        """Return ``x`` divided by its L2 norm along ``axis``.

        Norms smaller than ``eps`` are clamped to ``eps`` so that zero vectors
        do not cause a division by zero.
        """
        denom = np.linalg.norm(x, axis=axis, keepdims=True)
        denom = np.maximum(denom, eps)
        return x / denom

    @staticmethod
    def _build_idx_maps(Mapping: Dict[str, Any], MapData: Dict[str, Any]):
        """Build the ``index -> text`` and ``index -> key`` lookup tables.

        ``MapData["items"]`` is read as a list of dicts carrying an ``"index"``
        and an optional ``"text"``; ``Mapping["index_to_key"]`` is read as a
        dict keyed by index. All indices are converted to ``int`` so lookups
        work whether the source used string or integer keys.

        Returns:
            A ``(idx2text, idx2key)`` tuple of dicts keyed by ``int`` index.
        """
        items = MapData.get("items", [])
        idx2text = {int(item["index"]): item.get("text") for item in items}
        raw_i2k = Mapping.get("index_to_key", {})
        idx2key = {int(i): k for i, k in raw_i2k.items()}
        return idx2text, idx2key

    # ---------------------------
    # 1) SEARCH: FAISS vector search
    # ---------------------------
    def search(
        self,
        query: str,
        faissIndex: "faiss.Index",  # type: ignore
        Mapping: Dict[str, Any],
        MapData: Dict[str, Any],
        MapChunk: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
        query_embedding: Optional[np.ndarray] = None,
    ) -> List[Dict[str, Any]]:
        """Run a nearest-neighbour search for ``query`` on ``faissIndex``.

        Args:
            query: Query text; only encoded when ``query_embedding`` is None.
            faissIndex: Index object exposing ``search(vectors, k)``.
            Mapping: Dict whose ``"index_to_key"`` entry maps index -> key.
            MapData: Dict whose ``"items"`` entry lists ``{"index", "text"}``.
            MapChunk: Optional dict whose ``"index_to_chunk"`` entry maps the
                stringified index to a list of chunk ids.
            top_k: Number of neighbours to request; ``None`` (or 0) falls back
                to the instance default.
            query_embedding: Precomputed query vector (1-D or 2-D). Only the
                first row of the search output is used.

        Returns:
            A list of dicts shaped like
            ``{"index", "key", "text", "faiss_score", "chunk_ids"}`` in the
            order returned by the index. Padding entries (label ``-1``, which
            FAISS emits when fewer than ``k`` neighbours exist) are dropped.
        """
        k = int(top_k or self.top_k)

        # 1. Encode the query, or reuse the embedding supplied by the caller.
        if query_embedding is None:
            encoded = self._indexer.encode(
                [query], convert_to_tensor=True, device=str(self.device)
            )
            q = encoded.detach().cpu().numpy().astype("float32")
        else:
            q = np.asarray(query_embedding, dtype="float32")
        if q.ndim == 1:
            q = q[None, :]

        # 2. Normalize when requested (the original note says this is for
        #    cosine similarity; presumably the indexed vectors were normalized
        #    the same way -- confirm against the index builder).
        if self.normalize:
            q = self._l2_normalize(q)

        # FAISS expects a C-contiguous float32 matrix.
        q = np.ascontiguousarray(q, dtype="float32")

        # 3. Search FAISS.
        scores, ids = faissIndex.search(q, k)

        # 4. Map raw ids back to keys / texts / chunk ids.
        idx2text, idx2key = self._build_idx_maps(Mapping, MapData)
        chunk_map = MapChunk.get("index_to_chunk", {}) if MapChunk else {}

        results = []
        for score, raw_idx in zip(scores[0].tolist(), ids[0].tolist()):
            idx = int(raw_idx)
            if idx < 0:
                # -1 marks an empty slot, not a real document.
                continue
            results.append({
                "index": idx,
                "key": idx2key.get(idx),
                "text": idx2text.get(idx),
                "faiss_score": float(score),
                "chunk_ids": chunk_map.get(str(idx), []),
            })
        return results

    # ---------------------------
    # 2) RERANK: CrossEncoder rerank
    # ---------------------------
    def rerank(
        self,
        query: str,
        results: List[Dict[str, Any]],
        top_k: Optional[int] = None,
        show_progress: bool = False,
    ) -> List[Dict[str, Any]]:
        """Re-order ``results`` with the cross-encoder and keep the best ones.

        Only entries whose ``"text"`` is a non-blank string are scored. Each
        scored dict receives a ``"rerank_score"`` key in place (the input
        dicts are mutated), and entries that could not be scored are left out
        of the returned list even if they carry a score from an earlier call.

        Args:
            query: Query text paired with every candidate text.
            results: Candidates, typically the output of :meth:`search`.
            top_k: Number of results to return; ``None`` (or 0) falls back to
                the instance default ``rerank_k``.
            show_progress: Forwarded as ``show_progress_bar`` to ``predict``.

        Returns:
            The scored candidates sorted by descending ``"rerank_score"``,
            truncated to ``top_k``. Empty when there is nothing to score.

        Raises:
            ValueError: If ``results`` is non-empty and no reranker was given
                at construction time.
        """
        if not results:
            return []
        if self.reranker is None:
            raise ValueError("⚠️ Không có reranker được cung cấp khi khởi tạo.")

        k = int(top_k or self.rerank_k)

        scorable = []
        for candidate in results:
            text = candidate.get("text")
            if isinstance(text, str) and text.strip():
                scorable.append(candidate)
        if not scorable:
            return []

        pairs = [[query, candidate["text"]] for candidate in scorable]
        scores = self.reranker.predict(
            pairs, batch_size=self.rerank_batch_size, show_progress_bar=show_progress
        )
        for candidate, score in zip(scorable, scores):
            candidate["rerank_score"] = float(score)

        # Rank only what was scored in this call; a stale "rerank_score" on an
        # unscorable entry must not leak into the output.
        reranked = sorted(scorable, key=lambda r: r["rerank_score"], reverse=True)
        return reranked[:k]