| """Retrieval: BM25 + Dense (FAISS) + RRF fusion + cross-encoder reranking.""" |
| from __future__ import annotations |
|
|
| import os |
| import pickle |
| import re |
| from typing import Optional |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| from src.citations import Citation |
| from src.config import ( |
| BM25_FILE, CHUNKS_FILE, EMBED_MODEL, FAISS_FILE, |
| RRF_K, RERANK_MODEL, TOP_K_BM25, TOP_K_DENSE, TOP_K_FUSED, TOP_N_FINAL, |
| ) |
|
|
| |
|
|
_TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_.:]*|\d+")
_CAMEL_RE = re.compile(r"(?<!^)(?=[A-Z])")
_STOP = {
    "the", "a", "an", "of", "to", "in", "is", "are",
    "and", "or", "this", "that", "it", "be",
}


def _tokenize(text: str) -> list[str]:
    """Split *text* into lower-cased search tokens.

    Raw tokens are identifier-like runs (letters, digits, ``_``, ``.``,
    ``:``) or runs of digits. For every raw token that is not a stop word
    the output contains, in this order:

    1. the whole token, lower-cased;
    2. its pieces split before each uppercase letter, when that split
       yields more than one piece;
    3. its pieces split on runs of ``.``, ``_`` and ``:``, skipping any
       piece equal to the whole lower-cased token.

    Stop words are dropped at every step. Duplicates are not removed.
    """
    result: list[str] = []
    for raw in _TOKEN_RE.findall(text):
        lowered = raw.lower()
        if lowered in _STOP:
            continue
        result.append(lowered)

        # camelCase / PascalCase pieces, e.g. "fooBar" -> "foo", "bar".
        camel_pieces = _CAMEL_RE.split(raw)
        if len(camel_pieces) > 1:
            for piece in camel_pieces:
                piece_lower = piece.lower()
                if piece and piece_lower not in _STOP:
                    result.append(piece_lower)

        # Dotted / snake_case / namespaced pieces, e.g. "np.array" -> "np", "array".
        for fragment in re.split(r"[._:]+", raw):
            fragment_lower = fragment.lower()
            if not fragment or fragment_lower in _STOP:
                continue
            if fragment_lower != lowered:
                result.append(fragment_lower)
    return result
|
|
|
|
| |
|
|
# Module-level caches. Each starts as None and is assigned on the first call
# to the matching _load_* function, which then returns the cached object on
# every later call.
_chunks_df: Optional[pd.DataFrame] = None  # filled by _load_chunks() from CHUNKS_FILE
_bm25_index = None  # filled by _load_bm25() from BM25_FILE
_faiss_index = None  # filled by _load_faiss() from FAISS_FILE
_embed_model = None  # filled by _load_embed() using EMBED_MODEL
_rerank_model = None  # filled by _load_reranker() using RERANK_MODEL
|
|
|
|
def _load_chunks() -> pd.DataFrame:
    """Return the chunk table, reading it from ``CHUNKS_FILE`` on first use.

    The DataFrame is cached in the module-level ``_chunks_df`` so the
    parquet file is read at most once.

    Raises:
        FileNotFoundError: if ``CHUNKS_FILE`` does not exist.
    """
    global _chunks_df
    if _chunks_df is not None:
        return _chunks_df

    chunks_present = os.path.exists(CHUNKS_FILE)
    if not chunks_present:
        raise FileNotFoundError(
            f"{CHUNKS_FILE} not found. Run `python build_index.py` first."
        )

    _chunks_df = pd.read_parquet(CHUNKS_FILE)
    return _chunks_df
|
|
|
|
def _load_bm25():
    """Return the BM25 index, unpickling it from ``BM25_FILE`` on first use.

    The loaded object is cached in the module-level ``_bm25_index``.

    Raises:
        FileNotFoundError: if ``BM25_FILE`` does not exist. The message
            carries the same rebuild hint as the one for the chunk table.
    """
    global _bm25_index
    if _bm25_index is None:
        if not os.path.exists(BM25_FILE):
            raise FileNotFoundError(
                f"{BM25_FILE} not found. Run `python build_index.py` first."
            )
        with open(BM25_FILE, "rb") as f:
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file. Only load an index file produced by a trusted build step;
            # never point BM25_FILE at user-supplied or downloaded data.
            _bm25_index = pickle.load(f)
    return _bm25_index
|
|
|
|
def _load_faiss():
    """Return the FAISS index, reading it from ``FAISS_FILE`` on first use.

    The loaded index is cached in the module-level ``_faiss_index``.

    Raises:
        FileNotFoundError: if ``FAISS_FILE`` does not exist.
    """
    global _faiss_index
    if _faiss_index is None:
        # Check for the file before importing faiss, so a missing index is
        # reported as such without paying for (or failing on) the import.
        if not os.path.exists(FAISS_FILE):
            raise FileNotFoundError(
                f"{FAISS_FILE} not found. Run `python build_index.py` first."
            )
        import faiss
        _faiss_index = faiss.read_index(FAISS_FILE)
    return _faiss_index
|
|
|
|
def _load_embed():
    """Return the sentence-embedding model, creating it on first use.

    The model is built as ``SentenceTransformer(EMBED_MODEL)`` and cached in
    the module-level ``_embed_model``. ``sentence_transformers`` is imported
    only when the model actually has to be created.
    """
    global _embed_model
    if _embed_model is not None:
        return _embed_model

    from sentence_transformers import SentenceTransformer

    _embed_model = SentenceTransformer(EMBED_MODEL)
    return _embed_model
|
|
|
|
def _load_reranker():
    """Return the cross-encoder reranker, creating it on first use.

    The model is built as ``CrossEncoder(RERANK_MODEL)`` and cached in the
    module-level ``_rerank_model``. ``sentence_transformers`` is imported
    only when the model actually has to be created.
    """
    global _rerank_model
    if _rerank_model is not None:
        return _rerank_model

    from sentence_transformers import CrossEncoder

    _rerank_model = CrossEncoder(RERANK_MODEL)
    return _rerank_model
|
|
|
|
def indices_ready() -> bool:
    """Return True only if the chunk, BM25 and FAISS files all exist on disk.

    Only existence is checked; the files are not opened or validated.
    """
    for required_path in (CHUNKS_FILE, BM25_FILE, FAISS_FILE):
        if not os.path.exists(required_path):
            return False
    return True
|
|
|
|
| |
|
|
def _bm25_search(query: str, top_k: int) -> list[tuple[int, float]]:
    """Run a lexical BM25 search and return ``[(chunk_id, score), ...]``.

    The query is first expanded with ``_tokenize`` (camelCase / dotted
    sub-tokens), re-joined with spaces and passed through ``bm25s.tokenize``
    before being sent to the loaded index.

    Args:
        query: free-text query.
        top_k: maximum number of hits to request from the index.

    Returns:
        Up to ``top_k`` ``(chunk_id, score)`` pairs in the order the index
        returns them, keeping only hits with a positive score. An empty list
        is returned when the query yields no tokens or ``top_k`` is not
        positive.
    """
    tokens = _tokenize(query)
    # A query made only of stop words / punctuation cannot match anything;
    # skip the index entirely instead of asking it for arbitrary documents.
    if not tokens or top_k <= 0:
        return []

    import bm25s

    bm25 = _load_bm25()
    query_tokens_arr = bm25s.tokenize([" ".join(tokens)])
    results, scores = bm25.retrieve(query_tokens_arr, k=top_k)

    # retrieve() fills all k slots even when fewer documents match, so
    # non-matching documents come back with a score of 0. Drop them: they
    # carry no lexical evidence and would otherwise earn rank-based credit
    # when fused with the dense results.
    return [
        (int(chunk_id), float(score))
        for chunk_id, score in zip(results[0].tolist(), scores[0].tolist())
        if score > 0.0
    ]
|
|
|
|
def _dense_search(query: str, top_k: int) -> list[tuple[int, float]]:
    """Run a dense (embedding) search and return ``[(chunk_id, score), ...]``.

    The query is wrapped in a fixed retrieval instruction, encoded with the
    embedding model using normalized embeddings, and searched against the
    FAISS index as a single float32 row vector. Hits with a negative id are
    discarded (presumably FAISS's -1 placeholder for unfilled result slots).
    """
    encoder = _load_embed()
    faiss_index = _load_faiss()

    prompt = f"Represent this sentence for searching relevant passages: {query}"
    embedding = encoder.encode(prompt, normalize_embeddings=True)
    query_matrix = embedding.reshape(1, -1).astype("float32")

    scores, ids = faiss_index.search(query_matrix, top_k)

    hits: list[tuple[int, float]] = []
    for chunk_id, score in zip(ids[0], scores[0]):
        if chunk_id >= 0:
            hits.append((int(chunk_id), float(score)))
    return hits
|
|
|
|
def _rrf_fuse(
    bm25_hits: list[tuple[int, float]],
    dense_hits: list[tuple[int, float]],
    k: int = RRF_K,
    top_n: int = TOP_K_FUSED,
) -> list[tuple[int, float]]:
    """Merge two ranked hit lists with reciprocal rank fusion.

    Each chunk receives ``1 / (k + position)`` from every list it appears
    in, where ``position`` is its 1-based rank in that list; the original
    scores of the hits are ignored. Contributions are summed per chunk.

    Returns:
        The ``top_n`` ``(chunk_id, fused_score)`` pairs, highest score first.
        Ties keep first-seen order (BM25 hits before dense hits).
    """
    fused: dict[int, float] = {}
    for hit_list in (bm25_hits, dense_hits):
        for position, (chunk_id, _unused_score) in enumerate(hit_list, start=1):
            fused[chunk_id] = fused.get(chunk_id, 0.0) + 1.0 / (k + position)

    ordered = sorted(fused.items(), key=lambda item: item[1], reverse=True)
    return ordered[:top_n]
|
|
|
|
def _rerank(query: str, hits: list[tuple[int, float]], top_n: int, df: pd.DataFrame) -> list[tuple[int, float]]:
    """Re-score candidate chunks with the cross-encoder and keep the best.

    Args:
        query: the user query.
        hits: candidate ``(chunk_id, score)`` pairs; the incoming scores are
            ignored and replaced by the cross-encoder scores.
        top_n: maximum number of results to return.
        df: chunk table; ``df.loc[chunk_id, "text"]`` must yield the chunk
            text for every id in ``hits``.

    Returns:
        Up to ``top_n`` ``(chunk_id, rerank_score)`` pairs, highest score
        first. An empty ``hits`` list yields an empty result.
    """
    # Nothing to score: return before loading the cross-encoder or handing
    # it an empty batch.
    if not hits:
        return []

    reranker = _load_reranker()
    chunk_ids = [cid for cid, _ in hits]
    pairs = [(query, df.loc[cid, "text"]) for cid in chunk_ids]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(chunk_ids, scores), key=lambda x: x[1], reverse=True)
    return [(int(cid), float(s)) for cid, s in ranked[:top_n]]
|
|
|
|
| |
|
|
class HybridRetriever:
    """Configurable retrieval pipeline producing ``Citation`` objects.

    Depending on the flags, a query is searched with BM25, the dense FAISS
    index, or both (fused with reciprocal rank fusion), and the candidates
    are optionally re-scored by the cross-encoder.
    """

    def __init__(
        self,
        use_bm25: bool = True,
        use_dense: bool = True,
        use_rerank: bool = True,
        top_n: int = TOP_N_FINAL,
    ):
        """Store the pipeline switches.

        Args:
            use_bm25: include the BM25 lexical search.
            use_dense: include the dense embedding search.
            use_rerank: re-score candidates with the cross-encoder.
            top_n: maximum number of citations returned by ``retrieve``.
        """
        self.use_bm25 = use_bm25
        self.use_dense = use_dense
        self.use_rerank = use_rerank
        self.top_n = top_n

    def retrieve(self, query: str) -> list[Citation]:
        """Return up to ``top_n`` citations for *query*, best first.

        The chunk table is loaded before anything else, so a missing chunk
        file raises even when both searches are disabled. With both
        searches disabled the result is an empty list.
        """
        chunks = _load_chunks()

        if not (self.use_bm25 or self.use_dense):
            return []

        lexical_hits: list[tuple[int, float]] = (
            _bm25_search(query, TOP_K_BM25) if self.use_bm25 else []
        )
        dense_hits: list[tuple[int, float]] = (
            _dense_search(query, TOP_K_DENSE) if self.use_dense else []
        )

        # Fuse only when both searches ran; otherwise truncate the one list.
        if self.use_bm25 and self.use_dense:
            candidates = _rrf_fuse(lexical_hits, dense_hits)
        else:
            single_source = lexical_hits if self.use_bm25 else dense_hits
            candidates = single_source[:TOP_K_FUSED]

        if self.use_rerank and candidates:
            ranked = _rerank(query, candidates, self.top_n, chunks)
        else:
            ranked = candidates[: self.top_n]

        def to_citation(position: int, chunk_id: int, score: float) -> Citation:
            # Citation ids are the 1-based position in the final ranking.
            record = chunks.loc[chunk_id]
            return Citation(
                id=position,
                chunk_id=int(chunk_id),
                source_url=str(record["source_url"]),
                page_title=str(record["page_title"]),
                section=str(record.get("section", "")),
                snippet=str(record["text"])[:600],
                score=float(score),
            )

        return [
            to_citation(position, chunk_id, score)
            for position, (chunk_id, score) in enumerate(ranked, start=1)
        ]
|
|