"""Retrieval: BM25 + Dense (FAISS) + RRF fusion + cross-encoder reranking.""" from __future__ import annotations import os import pickle import re from typing import Optional import numpy as np import pandas as pd from src.citations import Citation from src.config import ( BM25_FILE, CHUNKS_FILE, EMBED_MODEL, FAISS_FILE, RRF_K, RERANK_MODEL, TOP_K_BM25, TOP_K_DENSE, TOP_K_FUSED, TOP_N_FINAL, ) # ── tokeniser ──────────────────────────────────────────────────────────────── _TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_.:]*|\d+") _CAMEL_RE = re.compile(r"(? list[str]: tokens = _TOKEN_RE.findall(text) out: list[str] = [] for t in tokens: tl = t.lower() if tl in _STOP: continue out.append(tl) parts = _CAMEL_RE.split(t) if len(parts) > 1: out.extend(p.lower() for p in parts if p and p.lower() not in _STOP) for sub in re.split(r"[._:]+", t): if sub and sub.lower() not in _STOP and sub.lower() != tl: out.append(sub.lower()) return out # ── lazy singletons ─────────────────────────────────────────────────────────── _chunks_df: Optional[pd.DataFrame] = None _bm25_index = None _faiss_index = None _embed_model = None _rerank_model = None def _load_chunks() -> pd.DataFrame: global _chunks_df if _chunks_df is None: if not os.path.exists(CHUNKS_FILE): raise FileNotFoundError( f"{CHUNKS_FILE} not found. Run `python build_index.py` first." ) _chunks_df = pd.read_parquet(CHUNKS_FILE) return _chunks_df def _load_bm25(): global _bm25_index if _bm25_index is None: if not os.path.exists(BM25_FILE): raise FileNotFoundError(f"{BM25_FILE} not found.") with open(BM25_FILE, "rb") as f: _bm25_index = pickle.load(f) return _bm25_index def _load_faiss(): global _faiss_index if _faiss_index is None: import faiss # noqa: PLC0415 if not os.path.exists(FAISS_FILE): raise FileNotFoundError(f"{FAISS_FILE} not found.") _faiss_index = faiss.read_index(FAISS_FILE) return _faiss_index def _load_embed(): global _embed_model if _embed_model is None: from sentence_transformers import SentenceTransformer # noqa: PLC0415 _embed_model = SentenceTransformer(EMBED_MODEL) return _embed_model def _load_reranker(): global _rerank_model if _rerank_model is None: from sentence_transformers import CrossEncoder # noqa: PLC0415 _rerank_model = CrossEncoder(RERANK_MODEL) return _rerank_model def indices_ready() -> bool: return all(os.path.exists(p) for p in (CHUNKS_FILE, BM25_FILE, FAISS_FILE)) # ── retrieval methods ───────────────────────────────────────────────────────── def _bm25_search(query: str, top_k: int) -> list[tuple[int, float]]: """Returns [(chunk_id, score), ...].""" import bm25s # noqa: PLC0415 bm25 = _load_bm25() query_tokens_arr = bm25s.tokenize([" ".join(_tokenize(query))]) results, scores = bm25.retrieve(query_tokens_arr, k=top_k) return list(zip(results[0].tolist(), scores[0].tolist())) def _dense_search(query: str, top_k: int) -> list[tuple[int, float]]: """Returns [(chunk_id, score), ...].""" model = _load_embed() index = _load_faiss() # BGE models expect a query prefix vec = model.encode(f"Represent this sentence for searching relevant passages: {query}", normalize_embeddings=True).reshape(1, -1).astype("float32") scores, ids = index.search(vec, top_k) return [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0] def _rrf_fuse( bm25_hits: list[tuple[int, float]], dense_hits: list[tuple[int, float]], k: int = RRF_K, top_n: int = TOP_K_FUSED, ) -> list[tuple[int, float]]: scores: dict[int, float] = {} for rank, (cid, _) in enumerate(bm25_hits): scores[cid] = scores.get(cid, 0.0) + 1.0 / (k + rank + 1) for rank, (cid, _) in enumerate(dense_hits): scores[cid] = scores.get(cid, 0.0) + 1.0 / (k + rank + 1) ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True) return ranked[:top_n] def _rerank(query: str, hits: list[tuple[int, float]], top_n: int, df: pd.DataFrame) -> list[tuple[int, float]]: reranker = _load_reranker() pairs = [(query, df.loc[cid, "text"]) for cid, _ in hits] scores = reranker.predict(pairs) ranked = sorted(zip([cid for cid, _ in hits], scores), key=lambda x: x[1], reverse=True) return [(int(cid), float(s)) for cid, s in ranked[:top_n]] # ── public API ──────────────────────────────────────────────────────────────── class HybridRetriever: def __init__( self, use_bm25: bool = True, use_dense: bool = True, use_rerank: bool = True, top_n: int = TOP_N_FINAL, ): self.use_bm25 = use_bm25 self.use_dense = use_dense self.use_rerank = use_rerank self.top_n = top_n def retrieve(self, query: str) -> list[Citation]: df = _load_chunks() bm25_hits: list[tuple[int, float]] = [] dense_hits: list[tuple[int, float]] = [] if self.use_bm25: bm25_hits = _bm25_search(query, TOP_K_BM25) if self.use_dense: dense_hits = _dense_search(query, TOP_K_DENSE) if self.use_bm25 and self.use_dense: fused = _rrf_fuse(bm25_hits, dense_hits) elif self.use_bm25: fused = bm25_hits[:TOP_K_FUSED] elif self.use_dense: fused = dense_hits[:TOP_K_FUSED] else: return [] if self.use_rerank and len(fused) > 0: final = _rerank(query, fused, self.top_n, df) else: final = fused[:self.top_n] citations: list[Citation] = [] for rank, (cid, score) in enumerate(final, start=1): row = df.loc[cid] citations.append(Citation( id=rank, chunk_id=int(cid), source_url=str(row["source_url"]), page_title=str(row["page_title"]), section=str(row.get("section", "")), snippet=str(row["text"])[:600], score=float(score), )) return citations