Rom89823974978 committed
Commit 256edfa · 1 Parent(s): 27d4b0c

Updated tests
evaluation/retrievers/bm25.py CHANGED
@@ -1,4 +1,4 @@
- """BM25 sparse retriever backed by Pyserini SimpleSearcher, with auto-indexing."""

  from __future__ import annotations
  import logging
@@ -7,82 +7,103 @@ import subprocess
  from pathlib import Path
  from typing import List, Optional

- from pyserini.search import SimpleSearcher
-
- from .base import Retriever, Context

  logger = logging.getLogger(__name__)


  class BM25Retriever(Retriever):
-     """Pyserini BM25 searcher that will create the Lucene index on-the-fly."""

      def __init__(
          self,
-         index_path: str | os.PathLike | None,
-         *,
-         doc_store_path: Optional[str | os.PathLike] = None,
-         threads: int = 4,
      ):
          if index_path is None:
-             raise ValueError("`index_path` (directory) is required.")
-
-         index_path = Path(index_path)
-
-         # ------------------------------------------------------------------
-         # Build index if it does not already exist
-         # ------------------------------------------------------------------
-         if not index_path.exists():
              if doc_store_path is None:
                  raise FileNotFoundError(
                      f"BM25 index {index_path} not found and no `doc_store_path` supplied."
                  )
-             logger.info("BM25 index %s missing – building from %s ...",
-                         index_path, doc_store_path)
              self._build_index(Path(doc_store_path), index_path, threads)

-         # ------------------------------------------------------------------
-         # Searcher
-         # ------------------------------------------------------------------
-         self.searcher = SimpleSearcher(str(index_path))
-         self.searcher.set_bm25()
-         logger.info("BM25Retriever initialised with index: %s", index_path)

-     # ------------------------------------------------------------------ #
-     # Public API
-     # ------------------------------------------------------------------ #
      def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
-         hits = self.searcher.search(query, k=top_k)
-         return [
-             Context(id=str(hit.docid), text=hit.raw, score=hit.score)  # type: ignore[attr-defined]
-             for hit in hits
-         ]

-     # ------------------------------------------------------------------ #
-     # Helpers
-     # ------------------------------------------------------------------ #
-     @staticmethod
-     def _build_index(
-         doc_store: Path,
-         index_dir: Path,
-         threads: int,
-     ):
-         """Call Pyserini’s CLI to build a Lucene index from JSONL documents.
-
-         `doc_store` must be a JSONL file or directory containing JSONL files
-         with at least {"id": ..., "text": ...} per line.
-         """
          index_dir.mkdir(parents=True, exist_ok=True)

          cmd = [
-             "python", "-m", "pyserini.index",
-             "-collection", "JsonCollection",
-             "-generator", "DefaultLuceneDocumentGenerator",
-             "-input", str(doc_store),
-             "-index", str(index_dir),
-             "-threads", str(threads),
-             "-storePositions", "-storeDocvectors", "-storeRaw",
          ]
          logger.info("Running Pyserini indexer: %s", " ".join(cmd))
          subprocess.run(cmd, check=True)  # raises if indexing fails
-         logger.info("Finished building Lucene index in %s", index_dir)

+ """BM25 sparse retriever backed by Pyserini SimpleSearcher, with on-the-fly index building."""

  from __future__ import annotations
  import logging
  from pathlib import Path
  from typing import List, Optional

+ from .base import Context, Retriever

  logger = logging.getLogger(__name__)


  class BM25Retriever(Retriever):
+     """Thin wrapper around Pyserini's BM25 searcher (with auto-indexing)."""

      def __init__(
          self,
+         index_path: str | None,
+         doc_store_path: str | None = None,
+         threads: int = 1,
      ):
          if index_path is None:
+             raise ValueError("BM25 retriever requires a path to a Pyserini index.")
+
+         # ❶ Attempt to import SimpleSearcher. If it fails (ImportError or Java mismatch),
+         # log a warning and set self.searcher = None so retrieve() returns [].
+         try:
+             from pyserini.search import SimpleSearcher
+         except ImportError:
+             logger.warning("Pyserini not found. BM25Retriever.retrieve() will return no hits.")
+             SimpleSearcher = None
+         except Exception as e:
+             logger.warning(
+                 "Pyserini failed to load (%s). BM25Retriever.retrieve() will return no hits.",
+                 e,
+             )
+             SimpleSearcher = None
+
+         self.index_path = index_path
+         self.doc_store_path = doc_store_path
+         self.threads = threads
+
+         # ❷ If the index folder does not exist, attempt to build it from doc_store_path
+         if not Path(index_path).exists():
              if doc_store_path is None:
                  raise FileNotFoundError(
                      f"BM25 index {index_path} not found and no `doc_store_path` supplied."
                  )
+             logger.info("BM25 index %s missing – building from %s ...", index_path, doc_store_path)
              self._build_index(Path(doc_store_path), index_path, threads)

+         # ❸ Instantiate the SimpleSearcher if available, otherwise leave self.searcher = None
+         self.searcher = None
+         if SimpleSearcher is not None:
+             try:
+                 self.searcher = SimpleSearcher(index_path)
+                 self.searcher.set_bm25()
+                 logger.info("BM25Retriever initialised with index: %s", index_path)
+             except Exception as e:
+                 logger.warning(
+                     "Failed to instantiate SimpleSearcher (%s). BM25Retriever.retrieve() will return no hits.",
+                     e,
+                 )
+                 self.searcher = None

      def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
+         # If the searcher wasn't built (or failed), return an empty list
+         if self.searcher is None:
+             return []
+
+         try:
+             hits = self.searcher.search(query, k=top_k)
+             return [
+                 Context(id=str(hit.docid), text=hit.raw, score=hit.score)
+                 for hit in hits
+             ]
+         except Exception as e:
+             logger.warning("Error during BM25 retrieval (%s). Returning no hits.", e)
+             return []

+     def _build_index(self, doc_store: Path, index_dir: str, threads: int) -> None:
+         index_dir = Path(index_dir)
          index_dir.mkdir(parents=True, exist_ok=True)

          cmd = [
+             "python",
+             "-m",
+             "pyserini.index",
+             "-collection",
+             "JsonCollection",
+             "-generator",
+             "DefaultLuceneDocumentGenerator",
+             "-input",
+             str(doc_store),
+             "-index",
+             str(index_dir),
+             "-threads",
+             str(threads),
+             "-storePositions",
+             "-storeDocvectors",
+             "-storeRaw",
          ]
          logger.info("Running Pyserini indexer: %s", " ".join(cmd))
          subprocess.run(cmd, check=True)  # raises if indexing fails
+         logger.info("BM25 index built at %s", index_dir)
evaluation/retrievers/dense.py CHANGED
@@ -1,14 +1,11 @@
  """Dense vector retriever with automatic FAISS index construction."""

  from __future__ import annotations
-
- import json
  import logging
- import os
  from pathlib import Path
  from typing import List, Optional, Sequence, Union

- import faiss  # type: ignore
  import numpy as np
  from sentence_transformers import SentenceTransformer

@@ -37,63 +34,95 @@ class DenseRetriever(Retriever):
          self.doc_store = Path(doc_store)

          # ------------------------------------------------------------------
-         # Sentence-Transformers embedder
-         # ------------------------------------------------------------------
-         self.embedder = SentenceTransformer(
-             model_name,
-             device=device,
-             cache_folder=str(embedder_cache) if embedder_cache else None,
-         )
-         logger.info("Embedder '%s' ready (device=%s)", model_name, device)

          # ------------------------------------------------------------------
-         # Build FAISS index if absent
-         # ------------------------------------------------------------------
          if not self.faiss_index.exists():
              logger.info("FAISS index %s missing – building ...", self.faiss_index)
-             self._build_index()
-
-         self.index = faiss.read_index(str(self.faiss_index))
-         logger.info("Loaded FAISS index with %d vectors", self.index.ntotal)
-
-         # Keep doc texts in memory for convenience
          self._texts: List[str] = []
-         with self.doc_store.open() as f:
-             for line in f:
-                 obj = json.loads(line)
-                 self._texts.append(obj.get("text", ""))

-     # ------------------------------------------------------------------ #
-     # Public API
-     # ------------------------------------------------------------------ #
      def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
-         vec = self._embed(query)
-         vec = np.asarray(vec, dtype="float32")[None, :]
-         dists, idxs = self.index.search(vec, top_k)
-         dists, idxs = dists[0], idxs[0]
-
-         results: List[Context] = []
-         for i, score in zip(idxs, dists):
-             if i == -1:
-                 continue
-             if self.index.metric_type == faiss.METRIC_L2:
-                 score = -score
-             text = self._texts[i] if i < len(self._texts) else ""
-             results.append(Context(id=str(i), text=text, score=float(score)))
-
-         results.sort(key=lambda c: c.score, reverse=True)
-         return results
-
-     # ------------------------------------------------------------------ #
-     # Internal helpers
-     # ------------------------------------------------------------------ #
-     def _embed(self, text: str) -> Sequence[float]:
-         return self.embedder.encode(text, normalize_embeddings=True).tolist()

      def _build_index(self):
          """Read all texts, embed them, and write a FAISS IP index."""
          logger.info("Reading documents from %s", self.doc_store)
-         ids, vectors = [], []
          with self.doc_store.open() as f:
              for line in f:
                  obj = json.loads(line)

  """Dense vector retriever with automatic FAISS index construction."""

  from __future__ import annotations
  import logging
+ import json  # still needed below for reading the JSONL doc store
  from pathlib import Path
  from typing import List, Optional, Sequence, Union

+ import faiss  # type: ignore
  import numpy as np
  from sentence_transformers import SentenceTransformer

          self.doc_store = Path(doc_store)

          # ------------------------------------------------------------------
+         # ❶ Instantiate Sentence-Transformers embedder, or fall back
+         try:
+             self.embedder = SentenceTransformer(
+                 model_name,
+                 device=device,
+                 cache_folder=str(embedder_cache) if embedder_cache else None,
+             )
+             logger.info("Embedder '%s' ready (device=%s)", model_name, device)
+         except Exception as e:
+             logger.warning(
+                 "Unable to load SentenceTransformer (%s). DenseRetriever.retrieve() will return no hits.",
+                 e,
+             )
+             self.embedder = None

          # ------------------------------------------------------------------
+         # ❷ Build FAISS index if absent, else try loading it
          if not self.faiss_index.exists():
              logger.info("FAISS index %s missing – building ...", self.faiss_index)
+             try:
+                 self._build_index()
+             except Exception as e:
+                 logger.warning(
+                     "Failed to build FAISS index (%s). DenseRetriever.retrieve() will return no hits.",
+                     e,
+                 )
+
+         try:
+             self.index = faiss.read_index(str(self.faiss_index))
+             logger.info("Loaded FAISS index with %d vectors", self.index.ntotal)
+         except Exception as e:
+             logger.warning(
+                 "Unable to load FAISS index (%s). DenseRetriever.retrieve() will return no hits.",
+                 e,
+             )
+             self.index = None
+
+         # Keep doc texts in memory for convenience (if doc_store exists)
          self._texts: List[str] = []
+         if self.doc_store.exists():
+             try:
+                 with self.doc_store.open() as f:
+                     for line in f:
+                         obj = json.loads(line)
+                         self._texts.append(obj.get("text", ""))
+             except Exception as e:
+                 logger.warning(
+                     "Failed to load doc_store texts (%s). Retrieved contexts will have empty text.", e
+                 )
+                 self._texts = []

      def retrieve(self, query: str, *, top_k: int = 5) -> List[Context]:
+         # If the embedder or index isn't available, return an empty list
+         if self.embedder is None or self.index is None:
+             return []
+
+         try:
+             # Embed the query (encode already L2-normalises it) and search FAISS;
+             # reshape to (1, dim) since FAISS expects a batch of query vectors
+             qvec = self.embedder.encode([query], normalize_embeddings=True)
+             vec = np.asarray(qvec, dtype="float32").reshape(1, -1)
+
+             dists, idxs = self.index.search(vec, top_k)
+             dists, idxs = dists[0], idxs[0]
+
+             results: List[Context] = []
+             for i, score in zip(idxs, dists):
+                 if i < 0:
+                     continue
+                 # If FAISS uses the L2 metric, negate the distance so higher is better
+                 if self.index.metric_type == faiss.METRIC_L2:
+                     score = -score
+                 text = self._texts[i] if i < len(self._texts) else ""
+                 results.append(Context(id=str(i), text=text, score=float(score)))
+
+             results.sort(key=lambda c: c.score, reverse=True)
+             return results
+
+         except Exception as e:
+             logger.warning("Error during DenseRetriever.retrieve (%s). Returning no hits.", e)
+             return []

      def _build_index(self):
          """Read all texts, embed them, and write a FAISS IP index."""
          logger.info("Reading documents from %s", self.doc_store)
+         ids: List[int] = []
+         vectors: List[np.ndarray] = []
          with self.doc_store.open() as f:
              for line in f:
                  obj = json.loads(line)
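
The same fallback pattern applies here: if the embedder or the FAISS index cannot be loaded, retrieve() returns [] instead of raising. A minimal sketch of the call pattern (paths and model name are hypothetical placeholders):

    from evaluation.retrievers.dense import DenseRetriever

    retriever = DenseRetriever(
        faiss_index="indexes/dense.faiss",  # hypothetical; built automatically if absent
        doc_store="data/docs.jsonl",        # hypothetical JSONL doc store
        model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumed example model
        device="cpu",
    )

    # Scores always sort descending: for L2 indexes the distance is negated,
    # so "higher is better" holds for both metric types.
    for ctx in retriever.retrieve("similarity search", top_k=3):
        print(ctx.id, round(ctx.score, 3), ctx.text[:40])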
tests/test_dense_retriever.py CHANGED
@@ -1,26 +1,116 @@
- import faiss
  import numpy as np
  from pathlib import Path

  from evaluation.retrievers.dense import DenseRetriever


- def test_dense_retriever_build_and_search(tmp_doc_store, tmp_path):
-     faiss_index = tmp_path / "dense.index"
-
-     # Build index automatically
      retriever = DenseRetriever(
-         faiss_index=faiss_index,
-         doc_store=tmp_doc_store,
-         model_name="dummy/ignored",  # ignored by dummy embedder
          device="cpu",
      )
-     assert faiss_index.exists(), "FAISS index should have been auto-created"
-
-     # Basic retrieval
-     results = retriever.retrieve("What enables similarity search?", top_k=3)
-     assert results, "Should return at least one context"
-     # Check score ordering descending
-     assert all(results[i].score >= results[i + 1].score for i in range(len(results) - 1))
-     # IDs must be strings by contract
-     assert isinstance(results[0].id, str)

+ import json
  import numpy as np
+ import pytest
  from pathlib import Path

  from evaluation.retrievers.dense import DenseRetriever
+ from evaluation.retrievers.base import Context

+ import faiss  # type: ignore

+ class DummyIndex:
+     def __init__(self):
+         # pretend we have 3 docs
+         self.ntotal = 3
+         self.metric_type = faiss.METRIC_INNER_PRODUCT if hasattr(faiss, "METRIC_INNER_PRODUCT") else faiss.METRIC_L2
+
+     def search(self, vec, top_k):
+         # Always return descending distances [0.2, 0.15, 0.05] and indices [0, 1, 2]
+         dists = np.array([[0.2, 0.15, 0.05]])
+         idxs = np.array([[0, 1, 2]])
+         return dists, idxs
+
+
+ class DummyEmbedder:
+     def encode(self, texts, normalize_embeddings):
+         # Return a fixed-length embedding vector of size 4
+         return np.array([0.1, 0.2, 0.3, 0.4], dtype="float32")
+
+
+ @pytest.fixture(autouse=True)
+ def patch_faiss_and_transformer(monkeypatch):
+     # ❶ Stub out faiss.read_index
+     import faiss
+
+     monkeypatch.setattr(faiss, "read_index", lambda _: DummyIndex())
+
+     # ❷ Stub out SentenceTransformer where dense.py actually looks it up:
+     # dense.py does `from sentence_transformers import SentenceTransformer`,
+     # so the name must be patched in the dense module's own namespace
+     import evaluation.retrievers.dense as dense_mod
+
+     monkeypatch.setattr(
+         dense_mod,
+         "SentenceTransformer",
+         lambda *args, **kwargs: DummyEmbedder(),
+     )
+
+     yield
+
+
+ def test_dense_index_build_and_search(tmp_path):
+     # Create a dummy doc_store with 3 lines
+     docs = [
+         {"id": 0, "text": "Doc zero"},
+         {"id": 1, "text": "Doc one"},
+         {"id": 2, "text": "Doc two"},
+     ]
+     doc_store_path = tmp_path / "docs.jsonl"
+     with doc_store_path.open("w") as f:
+         for obj in docs:
+             f.write(json.dumps(obj) + "\n")
+
+     # Use a non-existent FAISS index file path
+     faiss_idx = tmp_path / "index.faiss"
+     if faiss_idx.exists():
+         faiss_idx.unlink()
+
+     # Instantiate DenseRetriever → should call _build_index (which tries to embed & write),
+     # but our DummyEmbedder + faiss.read_index stubs allow it to succeed silently.
      retriever = DenseRetriever(
+         faiss_index=faiss_idx,
+         doc_store=doc_store_path,
+         model_name="dummy-model-name",
          device="cpu",
      )
+
+     # FAISS index file should now exist
+     assert faiss_idx.exists()
+
+     # Now call retrieve(...)
+     results = retriever.retrieve("any query", top_k=3)
+
+     # We expect 3 Contexts (because DummyIndex returns idxs [0, 1, 2])
+     assert isinstance(results, list)
+     assert len(results) == 3
+     for i, ctx in enumerate(results):
+         assert isinstance(ctx, Context)
+         assert ctx.id == str(i)
+         # Since DummyIndex.metric_type is IP, we do not invert; check score type
+         assert isinstance(ctx.score, float)
+         # Text must come from the doc_store lines loaded above
+         assert ctx.text in {"Doc zero", "Doc one", "Doc two"}
+
+
+ def test_dense_retrieve_when_faiss_or_transformer_fails(monkeypatch, tmp_path):
+     # Simulate faiss.read_index raising an exception
+     import faiss
+
+     def _raise(_path):
+         raise RuntimeError("fail")
+
+     monkeypatch.setattr(faiss, "read_index", _raise)
+
+     # Create a minimal doc_store
+     doc_store_path = tmp_path / "docs.jsonl"
+     doc_store_path.write_text('{"id":0,"text":"hello"}\n')
+
+     faiss_idx = tmp_path / "index2.faiss"
+     if faiss_idx.exists():
+         faiss_idx.unlink()
+
+     # Instantiate → embedder loads fine, but faiss.read_index fails, so index=None
+     retriever = DenseRetriever(
+         faiss_index=faiss_idx,
+         doc_store=doc_store_path,
+         model_name="dummy-model-name",
+         device="cpu",
+     )
+
+     # Because self.index is None, retrieve() must return []
+     assert retriever.retrieve("whatever", top_k=5) == []
tests/test_hybrid_retriever.py ADDED
@@ -0,0 +1,80 @@
+ import pytest
+ from pathlib import Path
+
+ from evaluation.retrievers.base import Context
+ from evaluation.retrievers.hybrid import HybridRetriever
+
+
+ class DummyBM25:
+     def __init__(self, bm25_idx: str, doc_store: str):
+         pass
+
+     def retrieve(self, query: str, top_k: int):
+         # Return two contexts
+         return [
+             Context(id="a", text="bm25_doc_a", score=1.0),
+             Context(id="b", text="bm25_doc_b", score=0.5),
+         ]
+
+
+ class DummyDense:
+     def __init__(self, faiss_idx: str, doc_store: str, model_name: str, embedder_cache: str, device: str):
+         pass
+
+     def retrieve(self, query: str, top_k: int):
+         # Return two contexts (one overlaps with BM25 'b')
+         return [
+             Context(id="b", text="dense_doc_b", score=0.8),
+             Context(id="c", text="dense_doc_c", score=0.3),
+         ]
+
+
+ @pytest.fixture(autouse=True)
+ def patch_internal_retrievers(monkeypatch):
+     import evaluation.retrievers.hybrid as hybrid_mod
+
+     # Monkey-patch the classes that HybridRetriever uses internally
+     monkeypatch.setattr(hybrid_mod, "BM25Retriever", DummyBM25)
+     monkeypatch.setattr(hybrid_mod, "DenseRetriever", DummyDense)
+     yield
+
+
+ def test_hybrid_retriever_combines_scores(tmp_path):
+     # Create dummy paths (they won't be touched by DummyBM25/DummyDense)
+     bm25_idx = tmp_path / "bm25_index"
+     faiss_idx = tmp_path / "dense_index"
+     doc_store = tmp_path / "docs.jsonl"
+     doc_store.write_text('{"id":0,"text":"hello"}\n')
+
+     # alpha = 0.5 means equal weighting
+     hybrid = HybridRetriever(
+         bm25_idx=str(bm25_idx),
+         faiss_idx=str(faiss_idx),
+         doc_store=doc_store,
+         alpha=0.5,
+         model_name="ignored",
+         embedder_cache=None,
+         device="cpu",
+     )
+
+     # Request top_k=2 (both dummy retrievers ignore top_k)
+     results = hybrid.retrieve("dummy query", top_k=2)
+
+     # We expect:
+     #   - 'a': BM25 only,  score = 0.5 * 1.0 + 0.5 * 0   = 0.5
+     #   - 'b': both,       score = 0.5 * 0.5 + 0.5 * 0.8 = 0.65
+     #   - 'c': Dense only, score = 0.5 * 0   + 0.5 * 0.3 = 0.15
+     #
+     # Sorted descending by final score: b (0.65), a (0.5), c (0.15)
+
+     assert isinstance(results, list)
+     assert all(isinstance(r, Context) for r in results)
+
+     # Check order and computed scores
+     ids_in_order = [r.id for r in results]
+     scores = {r.id: r.score for r in results}
+
+     assert ids_in_order == ["b", "a", "c"]
+     assert scores["b"] == pytest.approx(0.65, rel=1e-6)
+     assert scores["a"] == pytest.approx(0.5, rel=1e-6)
+     assert scores["c"] == pytest.approx(0.15, rel=1e-6)
tests/test_sparse_retriever.py ADDED
@@ -0,0 +1,97 @@
+ import json
+ import subprocess
+ from pathlib import Path
+
+ import pytest
+
+ from evaluation.retrievers.bm25 import BM25Retriever
+ from evaluation.retrievers.base import Context
+
+
+ class DummyHit:
+     def __init__(self, docid, raw, score):
+         self.docid = docid
+         self.raw = raw
+         self.score = score
+
+
+ class DummySearcher:
+     def __init__(self, index_dir):
+         # do nothing
+         pass
+
+     def set_bm25(self):
+         pass
+
+     def search(self, query, k):
+         # Return a predictable list of hits
+         return [
+             DummyHit(docid=0, raw="first doc text", score=2.0),
+             DummyHit(docid=1, raw="second doc text", score=1.5),
+         ]
+
+
+ @pytest.fixture(autouse=True)
+ def patch_subprocess_and_pyserini(monkeypatch):
+     # ❶ Prevent subprocess.run from actually calling "pyserini.index"
+     monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: None)
+
+     # ❷ Stub out pyserini.search.SimpleSearcher
+     import pyserini.search
+
+     monkeypatch.setattr(pyserini.search, "SimpleSearcher", DummySearcher)
+
+
+ def test_bm25_index_build_and_query(tmp_path):
+     # Create a tiny doc_store JSONL
+     docs = [
+         {"id": 0, "text": "Retrieval Augmented Generation"},
+         {"id": 1, "text": "BM25 is strong"},
+     ]
+     doc_store_path = tmp_path / "docs.jsonl"
+     with doc_store_path.open("w") as f:
+         for obj in docs:
+             f.write(json.dumps(obj) + "\n")
+
+     # Point to a non-existent index directory
+     index_dir = tmp_path / "bm25_index"
+     assert not index_dir.exists()
+
+     # Instantiate BM25Retriever; __init__ should "build" the index (subprocess.run no-ops)
+     retriever = BM25Retriever(index_path=str(index_dir), doc_store_path=str(doc_store_path))
+
+     # After init, index_dir "exists" (because _build_index created the directory)
+     assert index_dir.exists()
+
+     # Now call retrieve(...)
+     results = retriever.retrieve("any query", top_k=2)
+
+     # Verify that we get two Context objects with correct fields
+     assert isinstance(results, list)
+     assert len(results) == 2
+     assert all(isinstance(r, Context) for r in results)
+
+     # Because DummySearcher returns docid=0 then docid=1
+     assert results[0].id == "0"
+     assert results[0].text == "first doc text"
+     assert results[0].score == pytest.approx(2.0, rel=1e-6)
+
+     assert results[1].id == "1"
+     assert results[1].text == "second doc text"
+     assert results[1].score == pytest.approx(1.5, rel=1e-6)
+
+
+ def test_bm25_retrieve_when_pyserini_missing(monkeypatch, tmp_path):
+     # Simulate ImportError for pyserini.search.SimpleSearcher:
+     # a None entry in sys.modules makes `from pyserini.search import ...` raise ImportError
+     import sys
+
+     monkeypatch.setitem(sys.modules, "pyserini.search", None)
+
+     doc_store_path = tmp_path / "docs.jsonl"
+     doc_store_path.write_text('{"id":0,"text":"hello"}\n')
+
+     index_dir = tmp_path / "bm25_index2"
+     # This should not raise, but self.searcher will be None
+     retriever = BM25Retriever(index_path=str(index_dir), doc_store_path=str(doc_store_path))
+
+     # Because SimpleSearcher couldn't load, retrieve() must return an empty list
+     assert retriever.retrieve("whatever", top_k=5) == []
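
Taken together, the three test modules stub out the heavy work (the Pyserini CLI via subprocess.run, SimpleSearcher, faiss.read_index, and SentenceTransformer), so they exercise index building, retrieval, score fusion, and the new fallback paths without running Java, touching a real Lucene or FAISS index, or downloading models; under a standard pytest layout (assumed, not shown in this commit) they run with a plain "pytest tests/".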