# NLP_Lab / src / retrieve.py
# apytel
# Redesigns UI for FreeCAD RAG Python script generator
# 11ba2bd
"""Retrieval: BM25 + Dense (FAISS) + RRF fusion + cross-encoder reranking."""
from __future__ import annotations
import os
import pickle
import re
from typing import Optional
import numpy as np
import pandas as pd
from src.citations import Citation
from src.config import (
BM25_FILE, CHUNKS_FILE, EMBED_MODEL, FAISS_FILE,
RRF_K, RERANK_MODEL, TOP_K_BM25, TOP_K_DENSE, TOP_K_FUSED, TOP_N_FINAL,
)
# ── tokeniser ────────────────────────────────────────────────────────────────
# Matches either an identifier-like run (a letter or underscore followed by
# any mix of letters, digits, "_", "." and ":") or a run of digits.
_TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_.:]*|\d+")
# Zero-width match in front of every uppercase letter that is not at the very
# start of the string; _tokenize splits on it to break up camelCase tokens.
_CAMEL_RE = re.compile(r"(?<!^)(?=[A-Z])")
# Lower-case words that _tokenize drops from its output.
_STOP = {"the","a","an","of","to","in","is","are","and","or","this","that","it","be"}
def _tokenize(text: str) -> list[str]:
    """Split *text* into lower-cased search terms.

    For every match of ``_TOKEN_RE`` that is not a stop word, the output
    receives, in this order:

    1. the whole token, lower-cased;
    2. its camelCase pieces (only when ``_CAMEL_RE`` yields more than one),
       minus stop words;
    3. its pieces between runs of ``.``, ``_`` and ``:``, minus stop words
       and minus any piece equal to the whole lower-cased token.

    Duplicates are not removed, so a term can appear more than once.
    """
    collected: list[str] = []
    for raw in _TOKEN_RE.findall(text):
        lowered = raw.lower()
        if lowered in _STOP:
            # A stop-word token contributes nothing, not even sub-pieces.
            continue
        collected.append(lowered)

        camel_pieces = _CAMEL_RE.split(raw)
        if len(camel_pieces) > 1:
            for piece in camel_pieces:
                piece_lower = piece.lower()
                if piece and piece_lower not in _STOP:
                    collected.append(piece_lower)

        for fragment in re.split(r"[._:]+", raw):
            fragment_lower = fragment.lower()
            # Skip empty fragments, stop words, and the unsplit token itself
            # (already emitted above).
            if not fragment or fragment_lower in _STOP or fragment_lower == lowered:
                continue
            collected.append(fragment_lower)
    return collected
# ── lazy singletons ───────────────────────────────────────────────────────────
# Module-level caches filled on first use by the matching _load_* function.
# None means "not loaded yet".
_chunks_df: Optional[pd.DataFrame] = None  # table read from CHUNKS_FILE
_bm25_index = None  # object unpickled from BM25_FILE
_faiss_index = None  # index read from FAISS_FILE
_embed_model = None  # SentenceTransformer built from EMBED_MODEL
_rerank_model = None  # CrossEncoder built from RERANK_MODEL
def _load_chunks() -> pd.DataFrame:
    """Return the chunk table, reading ``CHUNKS_FILE`` (parquet) on first use.

    The result is cached in the module-level ``_chunks_df``.

    Raises:
        FileNotFoundError: if ``CHUNKS_FILE`` does not exist.
    """
    global _chunks_df
    if _chunks_df is not None:
        return _chunks_df
    if not os.path.exists(CHUNKS_FILE):
        raise FileNotFoundError(
            f"{CHUNKS_FILE} not found. Run `python build_index.py` first."
        )
    _chunks_df = pd.read_parquet(CHUNKS_FILE)
    return _chunks_df
def _load_bm25():
    """Return the BM25 index, unpickling ``BM25_FILE`` on first use.

    The result is cached in the module-level ``_bm25_index``.

    Raises:
        FileNotFoundError: if ``BM25_FILE`` does not exist.
    """
    global _bm25_index
    if _bm25_index is not None:
        return _bm25_index
    if not os.path.exists(BM25_FILE):
        raise FileNotFoundError(f"{BM25_FILE} not found.")
    # NOTE(review): unpickling runs arbitrary code from the file; this is only
    # safe while BM25_FILE is an artifact produced by this project itself.
    with open(BM25_FILE, "rb") as handle:
        _bm25_index = pickle.load(handle)
    return _bm25_index
def _load_faiss():
    """Return the FAISS index, reading ``FAISS_FILE`` on first use.

    ``faiss`` is imported here rather than at module import time, and the
    result is cached in the module-level ``_faiss_index``.

    Raises:
        FileNotFoundError: if ``FAISS_FILE`` does not exist.
    """
    global _faiss_index
    if _faiss_index is not None:
        return _faiss_index
    # Import precedes the existence check, so a missing faiss install is
    # reported before a missing index file.
    import faiss  # noqa: PLC0415

    if not os.path.exists(FAISS_FILE):
        raise FileNotFoundError(f"{FAISS_FILE} not found.")
    _faiss_index = faiss.read_index(FAISS_FILE)
    return _faiss_index
def _load_embed():
    """Return the sentence-embedding model, constructing it on first use.

    Builds ``SentenceTransformer(EMBED_MODEL)`` once and caches it in the
    module-level ``_embed_model``; ``sentence_transformers`` is imported
    only when the model is first needed.
    """
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    from sentence_transformers import SentenceTransformer  # noqa: PLC0415

    _embed_model = SentenceTransformer(EMBED_MODEL)
    return _embed_model
def _load_reranker():
    """Return the cross-encoder reranker, constructing it on first use.

    Builds ``CrossEncoder(RERANK_MODEL)`` once and caches it in the
    module-level ``_rerank_model``; ``sentence_transformers`` is imported
    only when the model is first needed.
    """
    global _rerank_model
    if _rerank_model is not None:
        return _rerank_model
    from sentence_transformers import CrossEncoder  # noqa: PLC0415

    _rerank_model = CrossEncoder(RERANK_MODEL)
    return _rerank_model
def indices_ready() -> bool:
    """Return True when the chunk, BM25 and FAISS files all exist on disk.

    Only existence is checked; the files are not opened or validated.
    """
    for artifact in (CHUNKS_FILE, BM25_FILE, FAISS_FILE):
        if not os.path.exists(artifact):
            return False
    return True
# ── retrieval methods ─────────────────────────────────────────────────────────
def _bm25_search(query: str, top_k: int) -> list[tuple[int, float]]:
    """Run a lexical BM25 search and return ``[(chunk_id, score), ...]``.

    The query is tokenized with ``_tokenize``, re-joined with spaces and
    handed to ``bm25s.tokenize`` before retrieval. Hits are returned in the
    order ``bm25s`` produces them.

    Args:
        query: free-text user query.
        top_k: number of hits to request from the index.

    Returns:
        ``(chunk_id, score)`` pairs, or an empty list when the query contains
        no searchable term (for example only stop words or punctuation).

    Raises:
        FileNotFoundError: if the BM25 index file is missing.
    """
    import bm25s  # noqa: PLC0415

    # Load first so a missing index is still reported for every query.
    bm25 = _load_bm25()

    terms = _tokenize(query)
    if not terms:
        # An empty query cannot match anything lexically. Returning no hits
        # keeps documents that matched nothing out of the rank-based fusion
        # step, which would otherwise give them credit for position alone.
        return []

    # NOTE(review): bm25s is expected to reject k larger than the indexed
    # corpus -- confirm TOP_K_BM25 stays below the corpus size.
    query_tokens_arr = bm25s.tokenize([" ".join(terms)])
    results, scores = bm25.retrieve(query_tokens_arr, k=top_k)
    return list(zip(results[0].tolist(), scores[0].tolist()))
def _dense_search(query: str, top_k: int) -> list[tuple[int, float]]:
    """Run an embedding search and return ``[(chunk_id, score), ...]``.

    The query is embedded (with normalized output) and looked up in the FAISS
    index. Entries with a negative id -- FAISS's placeholder when fewer than
    ``top_k`` neighbours exist -- are dropped.
    """
    encoder = _load_embed()
    faiss_index = _load_faiss()

    # BGE models expect this instruction in front of the query text.
    prompt = f"Represent this sentence for searching relevant passages: {query}"
    embedding = encoder.encode(prompt, normalize_embeddings=True)
    # FAISS searches a 2-D float32 matrix: one row per query.
    query_matrix = embedding.reshape(1, -1).astype("float32")

    similarities, neighbour_ids = faiss_index.search(query_matrix, top_k)
    found: list[tuple[int, float]] = []
    for neighbour, similarity in zip(neighbour_ids[0], similarities[0]):
        if neighbour < 0:
            continue
        found.append((int(neighbour), float(similarity)))
    return found
def _rrf_fuse(
    bm25_hits: list[tuple[int, float]],
    dense_hits: list[tuple[int, float]],
    k: int = RRF_K,
    top_n: int = TOP_K_FUSED,
) -> list[tuple[int, float]]:
    """Merge two ranked hit lists with Reciprocal Rank Fusion.

    Each chunk receives ``1 / (k + rank + 1)`` (rank is 0-based) from every
    list it appears in; the raw retrieval scores are ignored. The summed
    values are sorted descending and the first ``top_n`` are returned as
    ``(chunk_id, fused_score)`` pairs.
    """
    fused: dict[int, float] = {}
    # BM25 is processed first, so on equal fused scores the stable sort keeps
    # chunks in BM25-then-dense insertion order.
    for hit_list in (bm25_hits, dense_hits):
        for rank, (chunk_id, _raw_score) in enumerate(hit_list):
            fused[chunk_id] = fused.get(chunk_id, 0.0) + 1.0 / (k + rank + 1)

    ordered = sorted(fused.items(), key=lambda entry: entry[1], reverse=True)
    return ordered[:top_n]
def _rerank(query: str, hits: list[tuple[int, float]], top_n: int, df: pd.DataFrame) -> list[tuple[int, float]]:
    """Re-score candidate chunks with the cross-encoder and keep the best.

    Args:
        query: the user query.
        hits: candidate ``(chunk_id, score)`` pairs; the incoming scores are
            ignored and replaced by the reranker's scores.
        top_n: maximum number of results to return.
        df: chunk table; each chunk's text is read via ``df.loc[chunk_id, "text"]``,
            so chunk ids must be valid index labels of ``df``.

    Returns:
        Up to ``top_n`` ``(chunk_id, reranker_score)`` pairs, highest score
        first. An empty ``hits`` list yields an empty result.
    """
    if not hits:
        # Nothing to score: skip loading the model and avoid calling
        # predict() on an empty batch.
        return []

    reranker = _load_reranker()
    chunk_ids = [cid for cid, _ in hits]
    pairs = [(query, df.loc[cid, "text"]) for cid in chunk_ids]
    scores = reranker.predict(pairs)
    ranked = sorted(zip(chunk_ids, scores), key=lambda pair: pair[1], reverse=True)
    return [(int(cid), float(score)) for cid, score in ranked[:top_n]]
# ── public API ────────────────────────────────────────────────────────────────
class HybridRetriever:
    """Query-to-citations pipeline with individually switchable stages.

    With both searches enabled their hit lists are merged by ``_rrf_fuse``;
    with only one enabled its hits are used directly, capped at
    ``TOP_K_FUSED``; with neither enabled ``retrieve`` returns an empty list.
    Reranking, when enabled, is applied to the resulting candidates.
    """

    def __init__(
        self,
        use_bm25: bool = True,
        use_dense: bool = True,
        use_rerank: bool = True,
        top_n: int = TOP_N_FINAL,
    ):
        """Store the stage switches and the number of citations to return."""
        self.use_bm25 = use_bm25
        self.use_dense = use_dense
        self.use_rerank = use_rerank
        self.top_n = top_n

    @staticmethod
    def _as_citation(position: int, chunk_id: int, score: float, record) -> Citation:
        """Build one ``Citation`` from a chunk-table row and its final score."""
        return Citation(
            id=position,
            chunk_id=int(chunk_id),
            source_url=str(record["source_url"]),
            page_title=str(record["page_title"]),
            section=str(record.get("section", "")),
            snippet=str(record["text"])[:600],
            score=float(score),
        )

    def retrieve(self, query: str) -> list[Citation]:
        """Return up to ``top_n`` citations for *query*, best first.

        Citation ids are 1-based positions in the final ranking; each snippet
        is the first 600 characters of the chunk text.
        """
        # Loaded up front, so a missing chunk file is reported before any
        # search work (and even when both searches are disabled).
        chunks = _load_chunks()

        if not (self.use_bm25 or self.use_dense):
            return []

        lexical_hits = _bm25_search(query, TOP_K_BM25) if self.use_bm25 else []
        semantic_hits = _dense_search(query, TOP_K_DENSE) if self.use_dense else []

        if self.use_bm25 and self.use_dense:
            candidates = _rrf_fuse(lexical_hits, semantic_hits)
        elif self.use_bm25:
            candidates = lexical_hits[:TOP_K_FUSED]
        else:
            candidates = semantic_hits[:TOP_K_FUSED]

        if self.use_rerank and candidates:
            selected = _rerank(query, candidates, self.top_n, chunks)
        else:
            selected = candidates[: self.top_n]

        return [
            self._as_citation(position, chunk_id, score, chunks.loc[chunk_id])
            for position, (chunk_id, score) in enumerate(selected, start=1)
        ]