#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
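"""Search a Faiss index with Pyserini: encode queries (or load pre-encoded
queries), optionally apply pseudo-relevance feedback (PRF), and write the
results as a run file.

Example invocation (a sketch: the prebuilt index and encoder names below are
illustrative, and the command assumes this file is saved as search_faiss.py):

    python search_faiss.py --topics msmarco-passage-dev-subset \
        --index msmarco-passage-tct_colbert-hnsw \
        --encoder castorini/tct_colbert-msmarco \
        --output run.msmarco-passage.tct_colbert.txt
"""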
import argparse
import os

import numpy as np
from tqdm import tqdm

from pyserini.encode import PcaEncoder
from pyserini.output_writer import get_output_writer, OutputFormat
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.search import FaissSearcher, BinaryDenseSearcher, TctColBertQueryEncoder, QueryEncoder, \
    DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, AnceQueryEncoder, AutoQueryEncoder, DenseVectorAveragePrf, \
    DenseVectorRocchioPrf, DenseVectorAncePrf
from pyserini.search.lucene import LuceneSearcher

# Works around "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized."
# https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def define_dsearch_args(parser):
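    """Register dense (Faiss) retrieval arguments on the given parser."""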
    parser.add_argument('--index', type=str, metavar='path to index or index name', required=True,
                        help="Path to Faiss index or name of prebuilt index.")
    parser.add_argument('--encoder-class', type=str, metavar='class', required=False,
                        choices=["dkrr", "dpr", "bpr", "tct_colbert", "ance", "sentence", "contriever", "auto"],
                        default=None,
                        help="Query encoder class to use. If unspecified, the class is inferred from --encoder.")
    parser.add_argument('--encoder', type=str, metavar='path to query encoder checkpoint or encoder name',
                        required=False,
                        help="Path to query encoder PyTorch checkpoint or Hugging Face encoder model name.")
    parser.add_argument('--tokenizer', type=str, metavar='name or path', required=False,
                        help="Hugging Face tokenizer name or path.")
    parser.add_argument('--encoded-queries', type=str, metavar='path to encoded queries dir or queries name',
                        required=False,
                        help="Path to a directory of pre-encoded queries or name of a prebuilt encoded-queries set.")
    parser.add_argument('--pca-model', type=str, metavar='path', required=False, default=None,
                        help="Path to a Faiss PCA model.")
    parser.add_argument('--device', type=str, metavar='device to run query encoder', required=False, default='cpu',
                        help="Device to run the query encoder on: cpu or [cuda:0, cuda:1, ...].")
    parser.add_argument('--query-prefix', type=str, metavar='str', required=False, default=None,
                        help="Prefix prepended to each query, if any.")
    parser.add_argument('--searcher', type=str, metavar='str', required=False, default='simple',
                        help="Dense searcher type.")
    parser.add_argument('--prf-depth', type=int, metavar='num of passages used for PRF', required=False, default=0,
                        help="Number of passages used for PRF; 0: simple retrieval with no PRF, > 0: perform PRF.")
    parser.add_argument('--prf-method', type=str, metavar='avg or rocchio', required=False, default='avg',
                        help="PRF method: avg or rocchio.")
    parser.add_argument('--rocchio-alpha', type=float, metavar='alpha parameter for rocchio', required=False,
                        default=0.9,
                        help="Alpha parameter controlling the contribution from the query vector.")
    parser.add_argument('--rocchio-beta', type=float, metavar='beta parameter for rocchio', required=False, default=0.1,
                        help="Beta parameter controlling the contribution from the average vector of the positive PRF passages.")
    parser.add_argument('--rocchio-gamma', type=float, metavar='gamma parameter for rocchio', required=False, default=0.1,
                        help="Gamma parameter controlling the contribution from the average vector of the negative PRF passages.")
    parser.add_argument('--rocchio-topk', type=int, metavar='topk passages as positive for rocchio', required=False, default=3,
                        help="Use the top-k passages as positive PRF passages for Rocchio.")
    parser.add_argument('--rocchio-bottomk', type=int, metavar='bottomk passages as negative for rocchio', required=False, default=0,
                        help="Use the bottom-k passages as negative PRF passages for Rocchio; 0: do not use negative PRF passages.")
    parser.add_argument('--sparse-index', type=str, metavar='sparse lucene index containing contents', required=False,
                        help="Path to a sparse Lucene index containing the passage contents.")
    parser.add_argument('--ance-prf-encoder', type=str, metavar='query encoder path for ANCE-PRF', required=False,
                        help="Path or name of the ANCE-PRF model checkpoint.")
    parser.add_argument('--ef-search', type=int, metavar='efSearch for HNSW index', required=False, default=None,
                        help="efSearch value for an HNSW index.")

def init_query_encoder(encoder, encoder_class, tokenizer_name, topics_name, encoded_queries, device, prefix):
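    """Resolve a query encoder, trying in order: an encoder checkpoint or
    model name, a directory or prebuilt set of pre-encoded queries, and
    finally a prebuilt encoded-queries set keyed by the topic name.
    """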
    encoded_queries_map = {
        'msmarco-passage-dev-subset': 'tct_colbert-msmarco-passage-dev-subset',
        'dpr-nq-dev': 'dpr_multi-nq-dev',
        'dpr-nq-test': 'dpr_multi-nq-test',
        'dpr-trivia-dev': 'dpr_multi-trivia-dev',
        'dpr-trivia-test': 'dpr_multi-trivia-test',
        'dpr-wq-test': 'dpr_multi-wq-test',
        'dpr-squad-test': 'dpr_multi-squad-test',
        'dpr-curated-test': 'dpr_multi-curated-test'
    }
    encoder_class_map = {
        "dkrr": DkrrDprQueryEncoder,
        "dpr": DprQueryEncoder,
        "bpr": BprQueryEncoder,
        "tct_colbert": TctColBertQueryEncoder,
        "ance": AnceQueryEncoder,
        "sentence": AutoQueryEncoder,
        "contriever": AutoQueryEncoder,
        "auto": AutoQueryEncoder,
    }

    if encoder:
        _encoder_class = encoder_class

        # determine the encoder class
        if encoder_class is not None:
            encoder_class = encoder_class_map[encoder_class]
        else:
            # if any class keyword matches the given encoder name, use that encoder class
            for class_keyword in encoder_class_map:
                if class_keyword in encoder.lower():
                    encoder_class = encoder_class_map[class_keyword]
                    break

            # if no class keyword matched, fall back to AutoQueryEncoder
            if encoder_class is None:
                encoder_class = AutoQueryEncoder

        # prepare arguments for the encoder class
        kwargs = dict(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device, prefix=prefix)
        if (_encoder_class == "sentence") or ("sentence" in encoder):
            kwargs.update(dict(pooling='mean', l2_norm=True))
        if (_encoder_class == "contriever") or ("contriever" in encoder):
            kwargs.update(dict(pooling='mean', l2_norm=False))
        return encoder_class(**kwargs)

    if encoded_queries:
        if os.path.exists(encoded_queries):
            if 'bpr' in encoded_queries:
                return BprQueryEncoder(encoded_query_dir=encoded_queries)
            else:
                return QueryEncoder(encoded_queries)
        else:
            if 'bpr' in encoded_queries:
                return BprQueryEncoder.load_encoded_queries(encoded_queries)
            else:
                return QueryEncoder.load_encoded_queries(encoded_queries)

    if topics_name in encoded_queries_map:
        return QueryEncoder.load_encoded_queries(encoded_queries_map[topics_name])

    raise ValueError(f'No encoded queries for topic {topics_name}')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Search a Faiss index.')
    parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
                        help="Name of topics. Available: msmarco-passage-dev-subset.")
    parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.")
    parser.add_argument('--binary-hits', type=int, metavar='num', required=False, default=1000,
                        help="Number of binary hits.")
    parser.add_argument("--rerank", action="store_true", help="Whether to rerank BPR sparse results.")
    parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
                        help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
    parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
                        help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
    parser.add_argument('--output', type=str, metavar='path', required=True, help="Path to output file.")
    parser.add_argument('--max-passage', action='store_true',
                        default=False, help="Select only the max passage from each document.")
    parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
                        help="Final number of hits when selecting only the max passage.")
    parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#',
                        help="Delimiter between docid and passage id.")
    parser.add_argument('--batch-size', type=int, metavar='num', required=False, default=1,
                        help="Search a batch of queries in parallel.")
    parser.add_argument('--threads', type=int, metavar='num', required=False, default=1,
                        help="Maximum number of threads to use during search.")
    # For some test collections, a query is a document from the corpus (e.g., arguana in BEIR).
    # We want to remove the query from the results. This is equivalent to -removeQuery in Java.
    parser.add_argument('--remove-query', action='store_true', default=False, help="Remove query from results list.")
    define_dsearch_args(parser)
    args = parser.parse_args()

    query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format))
    topics = query_iterator.topics

    query_encoder = init_query_encoder(
        args.encoder, args.encoder_class, args.tokenizer, args.topics, args.encoded_queries, args.device,
        args.query_prefix)
    if args.pca_model:
        query_encoder = PcaEncoder(query_encoder, args.pca_model)

    kwargs = {}
    if os.path.exists(args.index):
        # create searcher from index directory
        if args.searcher.lower() == 'bpr':
            kwargs = dict(binary_k=args.binary_hits, rerank=args.rerank)
            searcher = BinaryDenseSearcher(args.index, query_encoder)
        else:
            searcher = FaissSearcher(args.index, query_encoder)
    else:
        # create searcher from prebuilt index name
        if args.searcher.lower() == 'bpr':
            kwargs = dict(binary_k=args.binary_hits, rerank=args.rerank)
            searcher = BinaryDenseSearcher.from_prebuilt_index(args.index, query_encoder)
        else:
            searcher = FaissSearcher.from_prebuilt_index(args.index, query_encoder)

    if args.ef_search:
        searcher.set_hnsw_ef_search(args.ef_search)

    if not searcher:
        exit()
    # check whether PRF is requested
    if args.prf_depth > 0 and type(searcher) == FaissSearcher:
        PRF_FLAG = True
        if args.prf_method.lower() == 'avg':
            prfRule = DenseVectorAveragePrf()
        elif args.prf_method.lower() == 'rocchio':
            prfRule = DenseVectorRocchioPrf(args.rocchio_alpha, args.rocchio_beta, args.rocchio_gamma,
                                            args.rocchio_topk, args.rocchio_bottomk)
        # ANCE-PRF uses its own query encoder, so the input to DenseVectorAncePrf is different
        elif args.prf_method.lower() == 'ance-prf' and type(query_encoder) == AnceQueryEncoder:
            if os.path.exists(args.sparse_index):
                sparse_searcher = LuceneSearcher(args.sparse_index)
            else:
                sparse_searcher = LuceneSearcher.from_prebuilt_index(args.sparse_index)
            prf_query_encoder = AnceQueryEncoder(encoder_dir=args.ance_prf_encoder, tokenizer_name=args.tokenizer,
                                                 device=args.device)
            prfRule = DenseVectorAncePrf(prf_query_encoder, sparse_searcher)
        print(f'Running FaissSearcher with {args.prf_method.upper()} PRF...')
    else:
        PRF_FLAG = False
    # build output path
    output_path = args.output
    print(f'Running {args.topics} topics, saving to {output_path}...')
    tag = 'Faiss'

    output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w',
                                      max_hits=args.hits, tag=tag, topics=topics,
                                      use_max_passage=args.max_passage,
                                      max_passage_delimiter=args.max_passage_delimiter,
                                      max_passage_hits=args.max_passage_hits)

    with output_writer:
        batch_topics = list()
        batch_topic_ids = list()
        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
            if args.batch_size <= 1 and args.threads <= 1:
                if PRF_FLAG:
                    emb_q, prf_candidates = searcher.search(text, k=args.prf_depth, return_vector=True, **kwargs)
                    # ANCE-PRF consumes the raw query text, so query embeddings are not needed
                    if args.prf_method.lower() == 'ance-prf':
                        prf_emb_q = prfRule.get_prf_q_emb(text, prf_candidates)
                    else:
                        prf_emb_q = prfRule.get_prf_q_emb(emb_q[0], prf_candidates)
                    prf_emb_q = np.expand_dims(prf_emb_q, axis=0).astype('float32')
                    hits = searcher.search(prf_emb_q, k=args.hits, **kwargs)
                else:
                    hits = searcher.search(text, args.hits, **kwargs)
                results = [(topic_id, hits)]
            else:
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
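                # flush the batch when it is full or when the last topic has been queued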
                if (index + 1) % args.batch_size == 0 or \
                        index == len(topics.keys()) - 1:
                    if PRF_FLAG:
                        q_embs, prf_candidates = searcher.batch_search(batch_topics, batch_topic_ids,
                                                                       k=args.prf_depth, return_vector=True, **kwargs)
                        # ANCE-PRF consumes the raw query text, so query embeddings are not needed
                        if args.prf_method.lower() == 'ance-prf':
                            prf_embs_q = prfRule.get_batch_prf_q_emb(batch_topics, batch_topic_ids, prf_candidates)
                        else:
                            prf_embs_q = prfRule.get_batch_prf_q_emb(batch_topic_ids, q_embs, prf_candidates)
                        results = searcher.batch_search(prf_embs_q, batch_topic_ids, k=args.hits,
                                                        threads=args.threads, **kwargs)
                        results = [(id_, results[id_]) for id_ in batch_topic_ids]
                    else:
                        results = searcher.batch_search(batch_topics, batch_topic_ids, args.hits,
                                                        threads=args.threads, **kwargs)
                        results = [(id_, results[id_]) for id_ in batch_topic_ids]
                    batch_topic_ids.clear()
                    batch_topics.clear()
                else:
                    continue

            for topic, hits in results:
                # For some test collections, a query is a document from the corpus (e.g., arguana in BEIR).
                # We want to remove the query from the results.
                if args.remove_query:
                    hits = [hit for hit in hits if hit.docid != topic]
                output_writer.write(topic, hits)

            results.clear()