EureCA / dsp /primitives /search.py
tonneli's picture
Delete history
f5776d3
import numpy as np
import dsp
def retrieve(query: str, k: int, **kwargs) -> tuple[list[str], list[str]]:
    """Retrieve passages for ``query`` from the configured retrieval model (RM).

    Args:
        query: Search string forwarded to ``dsp.settings.rm``.
        k: Number of passages requested from the RM.
        **kwargs: Extra keyword arguments forwarded unchanged to the RM.

    Returns:
        A pair ``(texts, source_names)`` of parallel lists holding the
        ``long_text`` and the ``source_name`` of every passage the RM
        returned, in the order the RM returned them.

    Raises:
        AssertionError: If no retrieval model is configured.
    """
    if not dsp.settings.rm:
        raise AssertionError("No RM is loaded.")

    texts: list[str] = []
    source_names: list[str] = []
    # Walk the RM result exactly once so both lists stay aligned even if the
    # RM hands back a one-shot iterable rather than a list.
    for psg in dsp.settings.rm(query, k=k, **kwargs):
        texts.append(psg.long_text)
        source_names.append(psg.source_name)

    # NOTE: no reranking is applied here, even when dsp.settings.reranker is set.
    return texts, source_names
def retrieveRerankEnsemble(queries: list[str], k: int) -> list[str]:
    """Retrieve for each query, re-rank the hits, and return the k best texts.

    For every non-empty query, ``k * 3`` candidates are requested from the RM
    and scored by the reranker. Scores are grouped by passage text across all
    queries and averaged; the texts with the highest average score win.

    Args:
        queries: Query strings; falsy entries are ignored.
        k: Maximum number of passage texts to return.

    Returns:
        Up to ``k`` passage texts ordered by descending average reranker
        score (equal scores fall back to descending text order, because
        ``(score, text)`` tuples are sorted in reverse).

    Raises:
        AssertionError: If either the RM or the reranker is not configured.
    """
    if not dsp.settings.rm or not dsp.settings.reranker:
        raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.")

    non_empty_queries = [query for query in queries if query]

    # passage text -> every reranker score that text received
    scores_by_text = {}
    for query in non_empty_queries:
        candidates = dsp.settings.rm(query, k=k * 3)
        candidate_texts = [candidate.long_text for candidate in candidates]
        reranker_scores = dsp.settings.reranker(query, candidate_texts)

        # Visit candidates from best to worst score for this query.
        for position in np.argsort(reranker_scores)[::-1]:
            text = candidates[position].long_text
            scores_by_text.setdefault(text, []).append(reranker_scores[position])

    ranked = sorted(
        ((np.average(scores), text) for text, scores in scores_by_text.items()),
        reverse=True,
    )
    return [text for _, text in ranked[:k]]
def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True) -> list[str]:
    """Retrieve passages for several queries and return the k best overall.

    Dispatch order:
      1. If a reranker is configured, delegate to ``retrieveRerankEnsemble``.
      2. If exactly one non-empty query remains, delegate to ``retrieve``.
      3. Otherwise request ``k * 3`` passages per query, sum each passage
         text's ``prob`` (or ``score`` when ``by_prob`` is false) across
         queries, and keep the ``k`` texts with the largest totals.

    NOTE(review): in case 2 the value of ``retrieve`` is returned as-is, so its
    shape may differ from the plain list of texts produced by cases 1 and 3 --
    confirm what callers expect.

    Args:
        queries: Query strings; falsy entries are ignored.
        k: Maximum number of passages to return.
        by_prob: Accumulate ``psg.prob`` when true, ``psg.score`` otherwise.

    Raises:
        AssertionError: If no retrieval model is configured.
    """
    if not dsp.settings.rm:
        raise AssertionError("No RM is loaded.")

    if dsp.settings.reranker:
        return retrieveRerankEnsemble(queries, k)

    non_empty_queries = [query for query in queries if query]
    if len(non_empty_queries) == 1:
        return retrieve(non_empty_queries[0], k)

    # Which passage attribute contributes to the running total.
    weight_attr = "prob" if by_prob else "score"

    totals = {}
    for query in non_empty_queries:
        for hit in dsp.settings.rm(query, k=k * 3):
            text = hit.long_text
            totals[text] = totals.get(text, 0.0) + getattr(hit, weight_attr)

    ranked = sorted(((total, text) for text, total in totals.items()), reverse=True)
    return [text for _, text in ranked[:k]]