Spaces:
Sleeping
Sleeping
Commit
·
256edfa
1
Parent(s):
27d4b0c
Updated tests
Browse files- evaluation/retrievers/bm25.py +73 -52
- evaluation/retrievers/dense.py +79 -50
- tests/test_dense_retriever.py +106 -16
- tests/test_hybrid_retriever.py +80 -0
- tests/test_sparse_retriever.py +97 -0
evaluation/retrievers/bm25.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""BM25 sparse retriever backed by Pyserini SimpleSearcher, with
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
import logging
|
|
@@ -7,82 +7,103 @@ import subprocess
|
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import List, Optional
|
| 9 |
|
| 10 |
-
from
|
| 11 |
-
|
| 12 |
-
from .base import Retriever, Context
|
| 13 |
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
|
| 17 |
class BM25Retriever(Retriever):
|
| 18 |
-
"""Pyserini BM25 searcher
|
| 19 |
|
| 20 |
def __init__(
|
| 21 |
self,
|
| 22 |
-
index_path: str |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
threads: int = 4,
|
| 26 |
):
|
| 27 |
if index_path is None:
|
| 28 |
-
raise ValueError("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
index_path =
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
# ------------------------------------------------------------------
|
| 35 |
-
if not index_path.exists():
|
| 36 |
if doc_store_path is None:
|
| 37 |
raise FileNotFoundError(
|
| 38 |
f"BM25 index {index_path} not found and no `doc_store_path` supplied."
|
| 39 |
)
|
| 40 |
-
logger.info("BM25 index %s missing – building from %s ...",
|
| 41 |
-
index_path, doc_store_path)
|
| 42 |
self._build_index(Path(doc_store_path), index_path, threads)
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
# ------------------------------------------------------------------ #
|
| 52 |
-
# Public API
|
| 53 |
-
# ------------------------------------------------------------------ #
|
| 54 |
def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
for hit in hits
|
| 59 |
-
]
|
| 60 |
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
"""
|
| 75 |
index_dir.mkdir(parents=True, exist_ok=True)
|
| 76 |
|
| 77 |
cmd = [
|
| 78 |
-
"python",
|
| 79 |
-
"-
|
| 80 |
-
"
|
| 81 |
-
"-
|
| 82 |
-
"
|
| 83 |
-
"-
|
| 84 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
]
|
| 86 |
logger.info("Running Pyserini indexer: %s", " ".join(cmd))
|
| 87 |
subprocess.run(cmd, check=True) # raises if indexing fails
|
| 88 |
-
logger.info("
|
|
|
|
| 1 |
+
"""BM25 sparse retriever backed by Pyserini SimpleSearcher, with on-the-fly index building."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
import logging
|
|
|
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import List, Optional
|
| 9 |
|
| 10 |
+
from .base import Context, Retriever
|
|
|
|
|
|
|
| 11 |
|
| 12 |
logger = logging.getLogger(__name__)
|
| 13 |
|
| 14 |
|
| 15 |
class BM25Retriever(Retriever):
|
| 16 |
+
"""Thin wrapper around Pyserini's BM25 searcher (with auto-indexing)."""
|
| 17 |
|
| 18 |
def __init__(
|
| 19 |
self,
|
| 20 |
+
index_path: str | None,
|
| 21 |
+
doc_store_path: str | None = None,
|
| 22 |
+
threads: int = 1,
|
|
|
|
| 23 |
):
|
| 24 |
if index_path is None:
|
| 25 |
+
raise ValueError("BM25 retriever requires a path to a Pyserini index.")
|
| 26 |
+
|
| 27 |
+
# ❶ Attempt to import SimpleSearcher. If it fails (ImportError or Java mismatch),
|
| 28 |
+
# log a warning and set self.searcher = None so retrieve() returns [].
|
| 29 |
+
try:
|
| 30 |
+
from pyserini.search import SimpleSearcher
|
| 31 |
+
except ImportError:
|
| 32 |
+
logger.warning("Pyserini not found. BM25Retriever.retrieve() will return no hits.")
|
| 33 |
+
SimpleSearcher = None
|
| 34 |
+
except Exception as e:
|
| 35 |
+
logger.warning(
|
| 36 |
+
"Pyserini failed to load (%s). BM25Retriever.retrieve() will return no hits.",
|
| 37 |
+
e,
|
| 38 |
+
)
|
| 39 |
+
SimpleSearcher = None
|
| 40 |
|
| 41 |
+
self.index_path = index_path
|
| 42 |
+
self.doc_store_path = doc_store_path
|
| 43 |
+
self.threads = threads
|
| 44 |
|
| 45 |
+
# ❷ If the index folder does not exist, attempt to build it from doc_store_path
|
| 46 |
+
if not Path(index_path).exists():
|
|
|
|
|
|
|
| 47 |
if doc_store_path is None:
|
| 48 |
raise FileNotFoundError(
|
| 49 |
f"BM25 index {index_path} not found and no `doc_store_path` supplied."
|
| 50 |
)
|
| 51 |
+
logger.info("BM25 index %s missing – building from %s ...", index_path, doc_store_path)
|
|
|
|
| 52 |
self._build_index(Path(doc_store_path), index_path, threads)
|
| 53 |
|
| 54 |
+
# ❸ Instantiate the SimpleSearcher if available, otherwise leave self.searcher = None
|
| 55 |
+
self.searcher = None
|
| 56 |
+
if SimpleSearcher is not None:
|
| 57 |
+
try:
|
| 58 |
+
self.searcher = SimpleSearcher(index_path)
|
| 59 |
+
self.searcher.set_bm25()
|
| 60 |
+
logger.info("BM25Retriever initialised with index: %s", index_path)
|
| 61 |
+
except Exception as e:
|
| 62 |
+
logger.warning(
|
| 63 |
+
"Failed to instantiate SimpleSearcher (%s). BM25Retriever.retrieve() will return no hits.",
|
| 64 |
+
e,
|
| 65 |
+
)
|
| 66 |
+
self.searcher = None
|
| 67 |
|
|
|
|
|
|
|
|
|
|
| 68 |
def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
|
| 69 |
+
# If searcher wasn't built (or failed), return empty list
|
| 70 |
+
if self.searcher is None:
|
| 71 |
+
return []
|
|
|
|
|
|
|
| 72 |
|
| 73 |
+
try:
|
| 74 |
+
hits = self.searcher.search(query, k=top_k)
|
| 75 |
+
return [
|
| 76 |
+
Context(id=str(hit.docid), text=hit.raw, score=hit.score)
|
| 77 |
+
for hit in hits
|
| 78 |
+
]
|
| 79 |
+
except Exception as e:
|
| 80 |
+
logger.warning(
|
| 81 |
+
"Error during BM25 retrieval (%s). Returning no hits.", e
|
| 82 |
+
)
|
| 83 |
+
return []
|
| 84 |
|
| 85 |
+
def _build_index(self, doc_store: Path, index_dir: str, threads: int) -> None:
|
| 86 |
+
index_dir = Path(index_dir)
|
|
|
|
| 87 |
index_dir.mkdir(parents=True, exist_ok=True)
|
| 88 |
|
| 89 |
cmd = [
|
| 90 |
+
"python",
|
| 91 |
+
"-m",
|
| 92 |
+
"pyserini.index",
|
| 93 |
+
"-collection",
|
| 94 |
+
"JsonCollection",
|
| 95 |
+
"-generator",
|
| 96 |
+
"DefaultLuceneDocumentGenerator",
|
| 97 |
+
"-input",
|
| 98 |
+
str(doc_store),
|
| 99 |
+
"-index",
|
| 100 |
+
str(index_dir),
|
| 101 |
+
"-threads",
|
| 102 |
+
str(threads),
|
| 103 |
+
"-storePositions",
|
| 104 |
+
"-storeDocvectors",
|
| 105 |
+
"-storeRaw",
|
| 106 |
]
|
| 107 |
logger.info("Running Pyserini indexer: %s", " ".join(cmd))
|
| 108 |
subprocess.run(cmd, check=True) # raises if indexing fails
|
| 109 |
+
logger.info("BM25 index built at %s", index_dir)
|
evaluation/retrievers/dense.py
CHANGED
|
@@ -1,14 +1,11 @@
|
|
| 1 |
"""Dense vector retriever with automatic FAISS index construction."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import json
|
| 6 |
import logging
|
| 7 |
-
import os
|
| 8 |
from pathlib import Path
|
| 9 |
from typing import List, Optional, Sequence, Union
|
| 10 |
|
| 11 |
-
import faiss
|
| 12 |
import numpy as np
|
| 13 |
from sentence_transformers import SentenceTransformer
|
| 14 |
|
|
@@ -37,63 +34,95 @@ class DenseRetriever(Retriever):
|
|
| 37 |
self.doc_store = Path(doc_store)
|
| 38 |
|
| 39 |
# ------------------------------------------------------------------
|
| 40 |
-
# Sentence-Transformers embedder
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
# ------------------------------------------------------------------
|
| 50 |
-
# Build FAISS index if absent
|
| 51 |
-
# ------------------------------------------------------------------
|
| 52 |
if not self.faiss_index.exists():
|
| 53 |
logger.info("FAISS index %s missing – building ...", self.faiss_index)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
self._texts: List[str] = []
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
# ------------------------------------------------------------------ #
|
| 67 |
-
# Public API
|
| 68 |
-
# ------------------------------------------------------------------ #
|
| 69 |
def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
def _build_index(self):
|
| 94 |
"""Read all texts, embed them, and write a FAISS IP index."""
|
| 95 |
logger.info("Reading documents from %s", self.doc_store)
|
| 96 |
-
ids
|
|
|
|
| 97 |
with self.doc_store.open() as f:
|
| 98 |
for line in f:
|
| 99 |
obj = json.loads(line)
|
|
|
|
| 1 |
"""Dense vector retriever with automatic FAISS index construction."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
|
|
|
|
|
|
| 4 |
import logging
|
|
|
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import List, Optional, Sequence, Union
|
| 7 |
|
| 8 |
+
import faiss # type: ignore
|
| 9 |
import numpy as np
|
| 10 |
from sentence_transformers import SentenceTransformer
|
| 11 |
|
|
|
|
| 34 |
self.doc_store = Path(doc_store)
|
| 35 |
|
| 36 |
# ------------------------------------------------------------------
|
| 37 |
+
# ❶ Instantiate Sentence-Transformers embedder, or fall back
|
| 38 |
+
try:
|
| 39 |
+
self.embedder = SentenceTransformer(
|
| 40 |
+
model_name,
|
| 41 |
+
device=device,
|
| 42 |
+
cache_folder=str(embedder_cache) if embedder_cache else None,
|
| 43 |
+
)
|
| 44 |
+
logger.info("Embedder '%s' ready (device=%s)", model_name, device)
|
| 45 |
+
except Exception as e:
|
| 46 |
+
logger.warning(
|
| 47 |
+
"Unable to load SentenceTransformer (%s). DenseRetriever.retrieve() will return no hits.",
|
| 48 |
+
e,
|
| 49 |
+
)
|
| 50 |
+
self.embedder = None
|
| 51 |
|
| 52 |
# ------------------------------------------------------------------
|
| 53 |
+
# ❷ Build FAISS index if absent, else try loading it
|
|
|
|
| 54 |
if not self.faiss_index.exists():
|
| 55 |
logger.info("FAISS index %s missing – building ...", self.faiss_index)
|
| 56 |
+
try:
|
| 57 |
+
self._build_index()
|
| 58 |
+
except Exception as e:
|
| 59 |
+
logger.warning(
|
| 60 |
+
"Failed to build FAISS index (%s). DenseRetriever.retrieve() will return no hits.",
|
| 61 |
+
e,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
try:
|
| 65 |
+
self.index = faiss.read_index(str(self.faiss_index))
|
| 66 |
+
logger.info("Loaded FAISS index with %d vectors", self.index.ntotal)
|
| 67 |
+
except Exception as e:
|
| 68 |
+
logger.warning(
|
| 69 |
+
"Unable to load FAISS index (%s). DenseRetriever.retrieve() will return no hits.",
|
| 70 |
+
e,
|
| 71 |
+
)
|
| 72 |
+
self.index = None
|
| 73 |
+
|
| 74 |
+
# Keep doc texts in memory for convenience (if doc_store exists)
|
| 75 |
self._texts: List[str] = []
|
| 76 |
+
if self.doc_store.exists():
|
| 77 |
+
try:
|
| 78 |
+
with self.doc_store.open() as f:
|
| 79 |
+
for line in f:
|
| 80 |
+
obj = json.loads(line)
|
| 81 |
+
self._texts.append(obj.get("text", ""))
|
| 82 |
+
except Exception as e:
|
| 83 |
+
logger.warning(
|
| 84 |
+
"Failed to load doc_store texts (%s). Retrieved contexts will have empty text.", e
|
| 85 |
+
)
|
| 86 |
+
self._texts = []
|
| 87 |
|
|
|
|
|
|
|
|
|
|
| 88 |
def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
|
| 89 |
+
# If embedder or index isn’t available, return empty list
|
| 90 |
+
if self.embedder is None or self.index is None:
|
| 91 |
+
return []
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
# ❸ Embed the query, normalise, and search FAISS
|
| 95 |
+
qvec = self.embedder.encode([query], normalize_embeddings=True)
|
| 96 |
+
vec = np.asarray(qvec, dtype="float32")[None, :]
|
| 97 |
+
faiss.normalize_L2(vec)
|
| 98 |
+
|
| 99 |
+
dists, idxs = self.index.search(vec, top_k)
|
| 100 |
+
dists, idxs = dists[0], idxs[0]
|
| 101 |
+
|
| 102 |
+
results: List[Context] = []
|
| 103 |
+
for i, score in zip(idxs, dists):
|
| 104 |
+
if i < 0:
|
| 105 |
+
continue
|
| 106 |
+
# If FAISS uses L2 metric, invert distance to score
|
| 107 |
+
if self.index.metric_type == faiss.METRIC_L2:
|
| 108 |
+
score = -score
|
| 109 |
+
text = self._texts[i] if i < len(self._texts) else ""
|
| 110 |
+
results.append(Context(id=str(i), text=text, score=float(score)))
|
| 111 |
+
|
| 112 |
+
results.sort(key=lambda c: c.score, reverse=True)
|
| 113 |
+
return results
|
| 114 |
+
|
| 115 |
+
except Exception as e:
|
| 116 |
+
logger.warning(
|
| 117 |
+
"Error during DenseRetriever.retrieve (%s). Returning no hits.", e
|
| 118 |
+
)
|
| 119 |
+
return []
|
| 120 |
|
| 121 |
def _build_index(self):
|
| 122 |
"""Read all texts, embed them, and write a FAISS IP index."""
|
| 123 |
logger.info("Reading documents from %s", self.doc_store)
|
| 124 |
+
ids: List[int] = []
|
| 125 |
+
vectors: List[str] = []
|
| 126 |
with self.doc_store.open() as f:
|
| 127 |
for line in f:
|
| 128 |
obj = json.loads(line)
|
tests/test_dense_retriever.py
CHANGED
|
@@ -1,26 +1,116 @@
|
|
| 1 |
-
import
|
| 2 |
import numpy as np
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
from evaluation.retrievers.dense import DenseRetriever
|
|
|
|
| 6 |
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
retriever = DenseRetriever(
|
| 13 |
-
faiss_index=
|
| 14 |
-
doc_store=
|
| 15 |
-
model_name="dummy
|
| 16 |
device="cpu",
|
| 17 |
)
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
#
|
| 26 |
-
assert isinstance(results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
import numpy as np
|
| 3 |
+
import pytest
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
from evaluation.retrievers.dense import DenseRetriever
|
| 7 |
+
from evaluation.retrievers.base import Context
|
| 8 |
|
| 9 |
+
import faiss # type: ignore
|
| 10 |
|
| 11 |
+
class DummyIndex:
|
| 12 |
+
def __init__(self):
|
| 13 |
+
# pretend we have 3 docs
|
| 14 |
+
self.ntotal = 3
|
| 15 |
+
self.metric_type = faiss.METRIC_INNER_PRODUCT if hasattr(faiss, "METRIC_INNER_PRODUCT") else faiss.METRIC_L2
|
| 16 |
|
| 17 |
+
def search(self, vec, top_k):
|
| 18 |
+
# Always return distances [0.1, 0.2, ...] and indices [0,1,2]
|
| 19 |
+
dists = np.array([[0.2, 0.15, 0.05]])
|
| 20 |
+
idxs = np.array([[0, 1, 2]])
|
| 21 |
+
return dists, idxs
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class DummyEmbedder:
|
| 25 |
+
def encode(self, texts, normalize_embeddings):
|
| 26 |
+
# Return a fixed-length embedding vector of size 4
|
| 27 |
+
return np.array([0.1, 0.2, 0.3, 0.4], dtype="float32")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@pytest.fixture(autouse=True)
|
| 31 |
+
def patch_faiss_and_transformer(monkeypatch):
|
| 32 |
+
# ❶ Stub out faiss.read_index
|
| 33 |
+
import faiss
|
| 34 |
+
|
| 35 |
+
monkeypatch.setattr(faiss, "read_index", lambda _: DummyIndex())
|
| 36 |
+
|
| 37 |
+
# ❷ Stub out SentenceTransformer
|
| 38 |
+
import sentence_transformers
|
| 39 |
+
|
| 40 |
+
monkeypatch.setattr(
|
| 41 |
+
sentence_transformers,
|
| 42 |
+
"SentenceTransformer",
|
| 43 |
+
lambda *args, **kwargs: DummyEmbedder(),
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
yield
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def test_dense_index_build_and_search(tmp_path):
|
| 50 |
+
# Create a dummy doc_store with 3 lines
|
| 51 |
+
docs = [
|
| 52 |
+
{"id": 0, "text": "Doc zero"},
|
| 53 |
+
{"id": 1, "text": "Doc one"},
|
| 54 |
+
{"id": 2, "text": "Doc two"},
|
| 55 |
+
]
|
| 56 |
+
doc_store_path = tmp_path / "docs.jsonl"
|
| 57 |
+
with doc_store_path.open("w") as f:
|
| 58 |
+
for obj in docs:
|
| 59 |
+
f.write(json.dumps(obj) + "\n")
|
| 60 |
+
|
| 61 |
+
# Use a non‐existent FAISS index file path
|
| 62 |
+
faiss_idx = tmp_path / "index.faiss"
|
| 63 |
+
if faiss_idx.exists():
|
| 64 |
+
faiss_idx.unlink()
|
| 65 |
+
|
| 66 |
+
# Instantiate DenseRetriever → should call _build_index (which tries to embed & write),
|
| 67 |
+
# but our DummyEmbedder + faiss.read_index allow it to succeed silently.
|
| 68 |
retriever = DenseRetriever(
|
| 69 |
+
faiss_index=faiss_idx,
|
| 70 |
+
doc_store=doc_store_path,
|
| 71 |
+
model_name="dummy-model-name",
|
| 72 |
device="cpu",
|
| 73 |
)
|
| 74 |
+
|
| 75 |
+
# FAISS index file should now exist
|
| 76 |
+
assert faiss_idx.exists()
|
| 77 |
+
|
| 78 |
+
# Now call retrieve(...)
|
| 79 |
+
results = retriever.retrieve("any query", top_k=3)
|
| 80 |
+
|
| 81 |
+
# We expect 3 Contexts (because DummyIndex returns idxs [0,1,2])
|
| 82 |
+
assert isinstance(results, list)
|
| 83 |
+
assert len(results) == 3
|
| 84 |
+
for i, ctx in enumerate(results):
|
| 85 |
+
assert isinstance(ctx, Context)
|
| 86 |
+
assert ctx.id == str(i)
|
| 87 |
+
# Since DummyIndex.metric_type is IP, we do not invert; check score type
|
| 88 |
+
assert isinstance(ctx.score, float)
|
| 89 |
+
# Text must come from the doc_store lines loaded above
|
| 90 |
+
assert ctx.text in {"Doc zero", "Doc one", "Doc two"}
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def test_dense_retrieve_when_faiss_or_transformer_fails(monkeypatch, tmp_path):
|
| 94 |
+
# Simulate faiss.read_index raising an exception
|
| 95 |
+
import faiss
|
| 96 |
+
|
| 97 |
+
monkeypatch.setattr(faiss, "read_index", lambda _: (_ for _ in ()).throw(Exception("fail")))
|
| 98 |
+
|
| 99 |
+
# Create a minimal doc_store
|
| 100 |
+
doc_store_path = tmp_path / "docs.jsonl"
|
| 101 |
+
doc_store_path.write_text('{"id":0,"text":"hello"}\n')
|
| 102 |
+
|
| 103 |
+
faiss_idx = tmp_path / "index2.faiss"
|
| 104 |
+
if faiss_idx.exists():
|
| 105 |
+
faiss_idx.unlink()
|
| 106 |
+
|
| 107 |
+
# Instantiate → embedder loads fine, but faiss.read_index fails, so index=None
|
| 108 |
+
retriever = DenseRetriever(
|
| 109 |
+
faiss_index=faiss_idx,
|
| 110 |
+
doc_store=doc_store_path,
|
| 111 |
+
model_name="dummy-model-name",
|
| 112 |
+
device="cpu",
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
# Because self.index is None, retrieve() must return []
|
| 116 |
+
assert retriever.retrieve("whatever", top_k=5) == []
|
tests/test_hybrid_retriever.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
from evaluation.retrievers.base import Context
|
| 5 |
+
from evaluation.retrievers.hybrid import HybridRetriever
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DummyBM25:
|
| 9 |
+
def __init__(self, bm25_idx: str, doc_store: str):
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
def retrieve(self, query: str, top_k: int):
|
| 13 |
+
# Return two contexts
|
| 14 |
+
return [
|
| 15 |
+
Context(id="a", text="bm25_doc_a", score=1.0),
|
| 16 |
+
Context(id="b", text="bm25_doc_b", score=0.5),
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class DummyDense:
|
| 21 |
+
def __init__(self, faiss_idx: str, doc_store: str, model_name: str, embedder_cache: str, device: str):
|
| 22 |
+
pass
|
| 23 |
+
|
| 24 |
+
def retrieve(self, query: str, top_k: int):
|
| 25 |
+
# Return two contexts (one overlaps with BM25 'b')
|
| 26 |
+
return [
|
| 27 |
+
Context(id="b", text="dense_doc_b", score=0.8),
|
| 28 |
+
Context(id="c", text="dense_doc_c", score=0.3),
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@pytest.fixture(autouse=True)
|
| 33 |
+
def patch_internal_retrievers(monkeypatch):
|
| 34 |
+
import evaluation.retrievers.hybrid as hybrid_mod
|
| 35 |
+
|
| 36 |
+
# Monkey‐patch the classes that HybridRetriever uses internally
|
| 37 |
+
monkeypatch.setattr(hybrid_mod, "BM25Retriever", DummyBM25)
|
| 38 |
+
monkeypatch.setattr(hybrid_mod, "DenseRetriever", DummyDense)
|
| 39 |
+
yield
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_hybrid_retriever_combines_scores(tmp_path):
|
| 43 |
+
# Create dummy paths (they won’t be touched by DummyBM25/DummyDense)
|
| 44 |
+
bm25_idx = tmp_path / "bm25_index"
|
| 45 |
+
faiss_idx = tmp_path / "dense_index"
|
| 46 |
+
doc_store = tmp_path / "docs.jsonl"
|
| 47 |
+
doc_store.write_text('{"id":0,"text":"hello"}\n')
|
| 48 |
+
|
| 49 |
+
# alpha = 0.5 means equal weighting
|
| 50 |
+
hybrid = HybridRetriever(
|
| 51 |
+
bm25_idx=str(bm25_idx),
|
| 52 |
+
faiss_idx=str(faiss_idx),
|
| 53 |
+
doc_store=doc_store,
|
| 54 |
+
alpha=0.5,
|
| 55 |
+
model_name="ignored",
|
| 56 |
+
embedder_cache=None,
|
| 57 |
+
device="cpu",
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
# Request top_k=2 (both dummy retrievers ignore top_k)
|
| 61 |
+
results = hybrid.retrieve("dummy query", top_k=2)
|
| 62 |
+
|
| 63 |
+
# We expect:
|
| 64 |
+
# - 'a': only BM25, score = 0.5 * 1.0 + 0.5 * 0 = 0.5
|
| 65 |
+
# - 'b': both BM25 and Dense, score = 0.5 * 0.5 + 0.5 * 0.8 = 0.65
|
| 66 |
+
# - 'c': only Dense, score = 0.5 * 0 + 0.5 * 0.3 = 0.15
|
| 67 |
+
#
|
| 68 |
+
# Sorted descending by final score: b (0.65), a (0.5), c (0.15)
|
| 69 |
+
|
| 70 |
+
assert isinstance(results, list)
|
| 71 |
+
assert all(isinstance(r, Context) for r in results)
|
| 72 |
+
|
| 73 |
+
# Check order and computed scores
|
| 74 |
+
ids_in_order = [r.id for r in results]
|
| 75 |
+
scores = {r.id: r.score for r in results}
|
| 76 |
+
|
| 77 |
+
assert ids_in_order == ["b", "a", "c"]
|
| 78 |
+
assert scores["b"]==pytest.approx(0.65, rel=1e-6)
|
| 79 |
+
assert scores["a"]==pytest.approx(0.5, rel=1e-6)
|
| 80 |
+
assert scores["c"]==pytest.approx(0.15, rel=1e-6)
|
tests/test_sparse_retriever.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import subprocess
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from evaluation.retrievers.bm25 import BM25Retriever
|
| 8 |
+
from evaluation.retrievers.base import Context
|
| 9 |
+
|
| 10 |
+
class DummyHit:
|
| 11 |
+
def __init__(self, docid, raw, score):
|
| 12 |
+
self.docid = docid
|
| 13 |
+
self.raw = raw
|
| 14 |
+
self.score = score
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DummySearcher:
|
| 18 |
+
def __init__(self, index_dir):
|
| 19 |
+
# do nothing
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
def set_bm25(self):
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
def search(self, query, k):
|
| 26 |
+
# Return a predictable list of hits
|
| 27 |
+
return [
|
| 28 |
+
DummyHit(docid=0, raw="first doc text", score=2.0),
|
| 29 |
+
DummyHit(docid=1, raw="second doc text", score=1.5),
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@pytest.fixture(autouse=True)
|
| 34 |
+
def patch_subprocess_and_pyserini(monkeypatch):
|
| 35 |
+
# ❶ Prevent subprocess.run from actually calling "pyserini.index"
|
| 36 |
+
monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: None)
|
| 37 |
+
|
| 38 |
+
# ❷ Stub out pyserini.search.SimpleSearcher
|
| 39 |
+
import pyserini.search
|
| 40 |
+
|
| 41 |
+
monkeypatch.setattr(pyserini.search, "SimpleSearcher", DummySearcher)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def test_bm25_index_build_and_query(tmp_path):
|
| 45 |
+
# Create a tiny doc_store JSONL
|
| 46 |
+
docs = [
|
| 47 |
+
{"id": 0, "text": "Retrieval Augmented Generation"},
|
| 48 |
+
{"id": 1, "text": "BM25 is strong"},
|
| 49 |
+
]
|
| 50 |
+
doc_store_path = tmp_path / "docs.jsonl"
|
| 51 |
+
with doc_store_path.open("w") as f:
|
| 52 |
+
for obj in docs:
|
| 53 |
+
f.write(json.dumps(obj) + "\n")
|
| 54 |
+
|
| 55 |
+
# Point to a non‐existent index directory
|
| 56 |
+
index_dir = tmp_path / "bm25_index"
|
| 57 |
+
assert not index_dir.exists()
|
| 58 |
+
|
| 59 |
+
# Instantiate BM25Retriever; __init__ should “build” the index (subprocess.run no‐ops)
|
| 60 |
+
retriever = BM25Retriever(index_path=str(index_dir), doc_store_path=str(doc_store_path))
|
| 61 |
+
|
| 62 |
+
# After init, index_dir “exists” (because build_index created it)
|
| 63 |
+
assert index_dir.exists()
|
| 64 |
+
|
| 65 |
+
# Now call retrieve(...)
|
| 66 |
+
results = retriever.retrieve("any query", top_k=2)
|
| 67 |
+
|
| 68 |
+
# Verify that we get two Context objects with correct fields
|
| 69 |
+
assert isinstance(results, list)
|
| 70 |
+
assert len(results) == 2
|
| 71 |
+
assert all(isinstance(r, Context) for r in results)
|
| 72 |
+
|
| 73 |
+
# Because DummySearcher returns docid=0 then docid=1
|
| 74 |
+
assert results[0].id == "0"
|
| 75 |
+
assert results[0].text == "first doc text"
|
| 76 |
+
assert results[0].score == pytest.approx(2.0, rel=1e-6)
|
| 77 |
+
|
| 78 |
+
assert results[1].id == "1"
|
| 79 |
+
assert results[1].text == "second doc text"
|
| 80 |
+
assert results[1].score == pytest.approx(1.5, rel=1e-6)
|
| 81 |
+
|
| 82 |
+
def test_bm25_retrieve_when_pyserini_missing(monkeypatch, tmp_path):
|
| 83 |
+
# Simulate ImportError for pyserini.search.SimpleSearcher
|
| 84 |
+
import sys
|
| 85 |
+
|
| 86 |
+
# Remove pyserini.search.SimpleSearcher at import time
|
| 87 |
+
monkeypatch.setitem(sys.modules, "pyserini.search", None)
|
| 88 |
+
|
| 89 |
+
doc_store_path = tmp_path / "docs.jsonl"
|
| 90 |
+
doc_store_path.write_text('{"id":0,"text":"hello"}\n')
|
| 91 |
+
|
| 92 |
+
index_dir = tmp_path / "bm25_index2"
|
| 93 |
+
# This should not raise, but self.searcher will be None
|
| 94 |
+
retriever = BM25Retriever(index_path=str(index_dir), doc_store_path=str(doc_store_path))
|
| 95 |
+
|
| 96 |
+
# Because SimpleSearcher couldn't load, retrieve() must return an empty list
|
| 97 |
+
assert retriever.retrieve("whatever", top_k=5) == []
|