#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os
import sys

from tqdm import tqdm

from pyserini.output_writer import get_output_writer, OutputFormat
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.search.faiss import FaissSearcher
from pyserini.search.faiss.__main__ import define_dsearch_args, init_query_encoder
from pyserini.search.hybrid import HybridSearcher
from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher
from pyserini.search.lucene.__main__ import define_search_args, set_bm25_parameters

# Fixes this error: "OMP: Error #15: Initializing libomp.a, but found libomp.dylib already initialized."
# https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def define_fusion_args(parser):
    parser.add_argument('--alpha', type=float, metavar='num', required=False, default=0.1,
                        help='Interpolation weight (alpha) for combining dense and sparse scores.')
    parser.add_argument('--hits', type=int, required=False, default=1000,
                        help='Number of hits to retrieve from each of the dense and sparse indexes before fusion.')
    parser.add_argument('--normalization', action='store_true', required=False,
                        help='Normalize dense and sparse scores before fusion.')
    parser.add_argument('--weight-on-dense', action='store_true', required=False,
                        help='Apply alpha to the dense scores instead of the sparse scores.')


def parse_args(parser, commands):
    # Divide argv by commands
    split_argv = [[]]
    for c in sys.argv[1:]:
        if c in commands.choices:
            split_argv.append([c])
        else:
            split_argv[-1].append(c)

    # Initialize namespace
    args = argparse.Namespace()
    for c in commands.choices:
        setattr(args, c, None)

    # Parse each command
    parser.parse_args(split_argv[0], namespace=args)  # Without command
    for argv in split_argv[1:]:  # Commands
        n = argparse.Namespace()
        setattr(args, argv[0], n)
        parser.parse_args(argv, namespace=n)
    return args
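
# Illustrative note: parse_args() splits sys.argv on the sub-command names, so a
# single command line can carry a separate flag group for each sub-parser. A
# hypothetical invocation (module path and index names are examples, not
# prescriptive) might look like:
#
#   python -m pyserini.search.hybrid \
#       dense  --index msmarco-passage-tct_colbert-hnsw \
#       sparse --index msmarco-passage \
#       fusion --alpha 0.1 \
#       run    --topics msmarco-passage-dev-subset --output run.hybrid.trec
#
# After parsing, args.dense, args.sparse, args.fusion, and args.run each hold
# their own namespace; sub-commands absent from the command line remain None.
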
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Conduct a hybrid search on dense+sparse indexes.')
    commands = parser.add_subparsers(title='sub-commands')

    dense_parser = commands.add_parser('dense')
    define_dsearch_args(dense_parser)

    sparse_parser = commands.add_parser('sparse')
    define_search_args(sparse_parser)

    fusion_parser = commands.add_parser('fusion')
    define_fusion_args(fusion_parser)

    run_parser = commands.add_parser('run')
    run_parser.add_argument('--topics', type=str, metavar='topic_name', required=False,
                            help='Name of topics. Available: msmarco-passage-dev-subset.')
    run_parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000,
                            help='Number of hits.')
    run_parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
                            help=f'Format of topics. Available: {[x.value for x in list(TopicsFormat)]}')
    run_parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
                            help=f'Format of output. Available: {[x.value for x in list(OutputFormat)]}')
    run_parser.add_argument('--output', type=str, metavar='path', required=False,
                            help='Path to output file.')
    run_parser.add_argument('--max-passage', action='store_true', default=False,
                            help='Select only the max-scoring passage from each document.')
    run_parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
                            help='Final number of hits when selecting only the max passage.')
    run_parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#',
                            help='Delimiter between docid and passage id.')
    run_parser.add_argument('--batch-size', type=int, metavar='num', required=False, default=1,
                            help='Batch size for searching the collection concurrently.')
    run_parser.add_argument('--threads', type=int, metavar='num', required=False, default=1,
                            help='Maximum number of threads to use.')

    args = parse_args(parser, commands)

    query_iterator = get_query_iterator(args.run.topics, TopicsFormat(args.run.topics_format))
    topics = query_iterator.topics

    query_encoder = init_query_encoder(args.dense.encoder, args.dense.encoder_class, args.dense.tokenizer,
                                       args.run.topics, args.dense.encoded_queries, args.dense.device,
                                       args.dense.query_prefix)

    # Create the dense searcher from a local index directory if it exists,
    # otherwise treat the argument as the name of a prebuilt index.
    if os.path.exists(args.dense.index):
        dsearcher = FaissSearcher(args.dense.index, query_encoder)
    else:
        dsearcher = FaissSearcher.from_prebuilt_index(args.dense.index, query_encoder)

    if not dsearcher:
        exit()

    # Same for the sparse searcher: local index directory first, then prebuilt index name.
    if os.path.exists(args.sparse.index):
        if args.sparse.impact:
            ssearcher = LuceneImpactSearcher(args.sparse.index, args.sparse.encoder, args.sparse.min_idf)
        else:
            ssearcher = LuceneSearcher(args.sparse.index)
    else:
        if args.sparse.impact:
            ssearcher = LuceneImpactSearcher.from_prebuilt_index(args.sparse.index, args.sparse.encoder,
                                                                 args.sparse.min_idf)
        else:
            ssearcher = LuceneSearcher.from_prebuilt_index(args.sparse.index)

    if not ssearcher:
        exit()

    set_bm25_parameters(ssearcher, args.sparse.index, args.sparse.k1, args.sparse.b)

    if args.sparse.language != 'en':
        ssearcher.set_language(args.sparse.language)

    hsearcher = HybridSearcher(dsearcher, ssearcher)
    if not hsearcher:
        exit()
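
    # A minimal sketch of the fusion HybridSearcher performs, assuming the usual
    # alpha-weighted interpolation (hedged; see pyserini.search.hybrid for the
    # authoritative implementation):
    #
    #   score = alpha * sparse_score + dense_score   # default
    #   score = sparse_score + alpha * dense_score   # with --weight-on-dense
    #
    # When --normalization is set, each retriever's scores are rescaled to a
    # common range before interpolation.
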
    output_path = args.run.output

    print(f'Running {args.run.topics} topics, saving to {output_path}...')
    tag = 'hybrid'

    output_writer = get_output_writer(output_path, OutputFormat(args.run.output_format), 'w',
                                      max_hits=args.run.hits, tag=tag, topics=topics,
                                      use_max_passage=args.run.max_passage,
                                      max_passage_delimiter=args.run.max_passage_delimiter,
                                      max_passage_hits=args.run.max_passage_hits)

    with output_writer:
        batch_topics = list()
        batch_topic_ids = list()
        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
            if args.run.batch_size <= 1 and args.run.threads <= 1:
                # Sequential search: fuse one query at a time.
                hits = hsearcher.search(text, args.fusion.hits, args.run.hits, args.fusion.alpha,
                                        args.fusion.normalization, args.fusion.weight_on_dense)
                results = [(topic_id, hits)]
            else:
                # Batched search: accumulate queries, then search once per full
                # batch (or at the final query).
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
                if (index + 1) % args.run.batch_size == 0 or \
                        index == len(topics.keys()) - 1:
                    results = hsearcher.batch_search(batch_topics, batch_topic_ids, args.fusion.hits,
                                                     args.run.hits, args.run.threads, args.fusion.alpha,
                                                     args.fusion.normalization, args.fusion.weight_on_dense)
                    results = [(id_, results[id_]) for id_ in batch_topic_ids]
                    batch_topic_ids.clear()
                    batch_topics.clear()
                else:
                    continue

            for topic, hits in results:
                output_writer.write(topic, hits)

            results.clear()
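
# Example follow-up (run path and metric are illustrative): a TREC-format run
# written by this script can be scored with Pyserini's trec_eval wrapper, e.g.
#
#   python -m pyserini.eval.trec_eval -c -m recip_rank \
#       msmarco-passage-dev-subset run.hybrid.trec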