#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Command-line entry point: run a set of topics against a Lucene index and write a run file."""

import argparse
import os
import sys

from tqdm import tqdm
from transformers import AutoTokenizer

from pyserini.analysis import JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer
from pyserini.output_writer import OutputFormat, get_output_writer
from pyserini.pyclass import autoclass
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.search import JDisjunctionMaxQueryGenerator
from . import LuceneImpactSearcher, LuceneSearcher, SlimSearcher
from .reranker import ClassifierType, PseudoRelevanceClassifierReranker


# BM25 presets for known index names: (description, k1, b, index names). Each group's values
# come from the Anserini regression page linked above it.
_BM25_PRESET_GROUPS = (
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-passage.md
    ('MS MARCO passage', 0.82, 0.68,
     ('msmarco-passage', 'msmarco-passage-slim', 'msmarco-v1-passage',
      'msmarco-v1-passage-slim', 'msmarco-v1-passage-full')),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-passage-docTTTTTquery.md
    ('MS MARCO passage w/ doc2query-T5 expansion', 2.18, 0.86,
     ('msmarco-passage-expanded', 'msmarco-v1-passage-d2q-t5',
      'msmarco-v1-passage-d2q-t5-docvectors')),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-doc.md
    ('MS MARCO doc', 4.46, 0.82,
     ('msmarco-doc', 'msmarco-doc-slim', 'msmarco-v1-doc',
      'msmarco-v1-doc-slim', 'msmarco-v1-doc-full')),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-doc-segmented.md
    ('MS MARCO doc, per passage', 2.16, 0.61,
     ('msmarco-doc-per-passage', 'msmarco-doc-per-passage-slim',
      'msmarco-v1-doc-segmented', 'msmarco-v1-doc-segmented-slim',
      'msmarco-v1-doc-segmented-full')),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-doc-docTTTTTquery.md
    ('MS MARCO doc w/ doc2query-T5 (per doc) expansion', 4.68, 0.87,
     ('msmarco-doc-expanded-per-doc', 'msmarco-v1-doc-d2q-t5',
      'msmarco-v1-doc-d2q-t5-docvectors')),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-doc-segmented-docTTTTTquery.md
    ('MS MARCO doc w/ doc2query-T5 (per passage) expansion', 2.56, 0.59,
     ('msmarco-doc-expanded-per-passage', 'msmarco-v1-doc-segmented-d2q-t5',
      'msmarco-v1-doc-segmented-d2q-t5-docvectors')),
)

# Flattened lookup: index name -> (description, k1, b).
_BM25_PRESETS = {
    name: (description, k1, b)
    for description, k1, b, names in _BM25_PRESET_GROUPS
    for name in names
}


def set_bm25_parameters(searcher, index, k1=None, b=None):
    """Configure BM25 on ``searcher``, either explicitly or from a known-index preset.

    Args:
        searcher: Searcher exposing ``set_bm25(k1, b)``.
        index: Index path or prebuilt index name; only consulted when neither k1 nor b is given.
        k1: BM25 k1 parameter; must be supplied together with ``b``.
        b: BM25 b parameter; must be supplied together with ``k1``.

    If exactly one of ``k1``/``b`` is supplied, an error is printed and the process exits with
    status 1. If neither is supplied and ``index`` is not a known preset name, the searcher's
    BM25 settings are left unchanged.
    """
    if k1 is not None or b is not None:
        if k1 is None or b is None:
            print('Must set *both* k1 and b for BM25!')
            # Non-zero status so calling scripts notice the misconfiguration.
            sys.exit(1)
        print(f'Setting BM25 parameters: k1={k1}, b={b}')
        searcher.set_bm25(k1, b)
        return

    # Automatically set bm25 parameters based on known index...
    preset = _BM25_PRESETS.get(index)
    if preset is None:
        return
    description, preset_k1, preset_b = preset
    print(f'{description}: setting k1={preset_k1}, b={preset_b}')
    searcher.set_bm25(preset_k1, preset_b)


def define_search_args(parser):
    """Register the searcher-configuration command-line options on ``parser``.

    Args:
        parser: An ``argparse.ArgumentParser`` to add arguments to (mutated in place).
    """
    parser.add_argument('--index', type=str, metavar='path to index or index name', required=True,
                        help="Path to Lucene index or name of prebuilt index.")
    parser.add_argument('--encoded-corpus', type=str, default=None, help="path to stored sparse vectors")

    parser.add_argument('--impact', action='store_true', help="Use Impact.")
    parser.add_argument('--encoder', type=str, default=None, help="encoder name")
    parser.add_argument('--min-idf', type=int, default=0, help="minimum idf")

    parser.add_argument('--bm25', action='store_true', default=True, help="Use BM25 (default).")
    parser.add_argument('--k1', type=float, help='BM25 k1 parameter.')
    parser.add_argument('--b', type=float, help='BM25 b parameter.')

    parser.add_argument('--rm3', action='store_true', help="Use RM3")
    parser.add_argument('--rocchio', action='store_true', help="Use Rocchio")
    parser.add_argument('--rocchio-use-negative', action='store_true', help="Use nonrelevant labels in Rocchio")
    parser.add_argument('--qld', action='store_true', help="Use QLD")

    parser.add_argument('--language', type=str, help='language code for BM25, e.g. zh for Chinese', default='en')
    parser.add_argument('--pretokenized', action='store_true', help="Boolean switch to accept pre-tokenized topics")

    parser.add_argument('--prcl', type=ClassifierType, nargs='+', default=[],
                        help='Specify the classifier PseudoRelevanceClassifierReranker uses.')
    parser.add_argument('--prcl.vectorizer', dest='vectorizer', type=str,
                        help='Type of vectorizer. Available: TfidfVectorizer, BM25Vectorizer.')
    parser.add_argument('--prcl.r', dest='r', type=int, default=10,
                        help='Number of positive labels in pseudo relevance feedback.')
    parser.add_argument('--prcl.n', dest='n', type=int, default=100,
                        help='Number of negative labels in pseudo relevance feedback.')
    parser.add_argument('--prcl.alpha', dest='alpha', type=float, default=0.5,
                        help='Alpha value for interpolation in pseudo relevance feedback.')

    parser.add_argument('--fields', metavar="key=value", nargs='+',
                        help='Fields to search with assigned float weights.')
    parser.add_argument('--dismax', action='store_true', default=False,
                        help='Use disjunction max queries when searching multiple fields.')
    parser.add_argument('--dismax.tiebreaker', dest='tiebreaker', type=float, default=0.0,
                        help='The tiebreaker weight to use in disjunction max queries.')

    parser.add_argument('--stopwords', type=str, help='Path to file with customstopwords.')


def _create_searcher(args):
    """Build the searcher selected by --impact / --encoded-corpus.

    ``args.index`` is treated as a local index directory when it exists on disk, otherwise as a
    prebuilt index name. NOTE(review): ``from_prebuilt_index`` presumably returns None for an
    unknown name -- the caller treats a falsy result as fatal.
    """
    is_local = os.path.exists(args.index)
    if not args.impact:
        if is_local:
            # create searcher from index directory
            return LuceneSearcher(args.index)
        # create searcher from prebuilt index name
        return LuceneSearcher.from_prebuilt_index(args.index)

    if args.encoded_corpus is not None:
        if is_local:
            return SlimSearcher(args.encoded_corpus, args.index, args.encoder, args.min_idf)
        return SlimSearcher.from_prebuilt_index(args.encoded_corpus, args.index, args.encoder, args.min_idf)

    if is_local:
        return LuceneImpactSearcher(args.index, args.encoder, args.min_idf)
    return LuceneImpactSearcher.from_prebuilt_index(args.index, args.encoder, args.min_idf)


def _default_output_path(args, search_rankers, use_prcl):
    """Derive the run-file name used when --output is not given; always ends in ".txt"."""
    if not use_prcl:
        return '.'.join(['run', args.topics, '+'.join(search_rankers), 'txt'])

    clf_rankers = []
    for clf_type in args.prcl:
        if clf_type == ClassifierType.LR:
            clf_rankers.append('lr')
        elif clf_type == ClassifierType.SVM:
            clf_rankers.append('svm')

    r_str = f'prcl.r_{args.r}'
    n_str = f'prcl.n_{args.n}'
    a_str = f'prcl.alpha_{args.alpha}'
    clf_str = 'prcl_' + '+'.join(clf_rankers)
    tokens1 = ['run', args.topics, '+'.join(search_rankers)]
    tokens2 = [args.vectorizer, clf_str, r_str, n_str, a_str]
    return '.'.join(tokens1) + '-' + '-'.join(tokens2) + ".txt"


def _batch_search(searcher, args, batch_topics, batch_topic_ids, query_generator, fields):
    """Run one ``batch_search`` call over the queued topics.

    Returns a list of (topic_id, hits) pairs in the same order as ``batch_topic_ids``. The list
    is fully built before returning, so the caller may clear the input lists afterwards.
    """
    if args.impact:
        hits_by_id = searcher.batch_search(
            batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields)
    else:
        hits_by_id = searcher.batch_search(
            batch_topics, batch_topic_ids, args.hits, args.threads,
            query_generator=query_generator, fields=fields)
    return [(id_, hits_by_id[id_]) for id_ in batch_topic_ids]


def _postprocess_hits(topic, hits, args, ranker):
    """Apply optional PRCL re-scoring, de-duplication and query removal to one topic's hits.

    Args:
        topic: Topic id the hits belong to (str in batch mode, as yielded by the iterator otherwise).
        hits: Hits returned by the searcher; their ``score`` may be overwritten in place.
        args: Parsed command-line arguments.
        ranker: PRCL re-ranker, or None when PRCL is disabled.

    Returns:
        The (possibly filtered) list of hits to write.
    """
    # do rerank
    if ranker is not None and len(hits) > (args.r + args.n):
        docids = [hit.docid.strip() for hit in hits]
        scores = [hit.score for hit in hits]
        scores, docids = ranker.rerank(docids, scores)
        docid_score_map = dict(zip(docids, scores))
        for hit in hits:
            hit.score = docid_score_map[hit.docid.strip()]

    if args.remove_duplicates:
        seen_docids = set()
        dedup_hits = []
        for hit in hits:
            docid = hit.docid.strip()
            if docid in seen_docids:
                continue
            seen_docids.add(docid)
            dedup_hits.append(hit)
        hits = dedup_hits

    # For some test collections, a query is doc from the corpus (e.g., arguana in BEIR).
    # We want to remove the query from the results. Compare as strings: sequential mode passes
    # the topic id through unconverted, so a non-str id would otherwise never match a docid.
    if args.remove_query:
        topic_str = str(topic)
        hits = [hit for hit in hits if hit.docid != topic_str]

    return hits


def _write_results(output_writer, results, args, ranker):
    """Post-process and write every (topic_id, hits) pair in ``results``."""
    for topic, hits in results:
        output_writer.write(topic, _postprocess_hits(topic, hits, args, ranker))


def main():
    """Parse CLI arguments, run every topic through the configured searcher, and write the run."""
    parser = argparse.ArgumentParser(description='Search a Lucene index.')
    define_search_args(parser)
    parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
                        help="Name of topics. Available: robust04, robust05, core17, core18.")
    parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.")
    parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
                        help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
    parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
                        help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
    parser.add_argument('--output', type=str, metavar='path', help="Path to output file.")
    parser.add_argument('--max-passage', action='store_true', default=False,
                        help="Select only max passage from document.")
    parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
                        help="Final number of hits when selecting only max passage.")
    parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#',
                        help="Delimiter between docid and passage id.")
    parser.add_argument('--batch-size', type=int, metavar='num', required=False, default=1,
                        help="Specify batch size to search the collection concurrently.")
    parser.add_argument('--threads', type=int, metavar='num', required=False, default=1,
                        help="Maximum number of threads to use.")
    parser.add_argument('--tokenizer', type=str, help='tokenizer used to preprocess topics')
    parser.add_argument('--remove-duplicates', action='store_true', default=False, help="Remove duplicate docs.")
    # For some test collections, a query is doc from the corpus (e.g., arguana in BEIR).
    # We want to remove the query from the results. This is equivalent to -removeQuery in Java.
    parser.add_argument('--remove-query', action='store_true', default=False, help="Remove query from results list.")
    args = parser.parse_args()

    query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format))
    topics = query_iterator.topics

    searcher = _create_searcher(args)
    # Check before any searcher method is called (set_language previously ran first and would
    # crash on a missing searcher); exit non-zero so the failure is visible to callers.
    if not searcher:
        sys.exit(1)

    if args.language != 'en':
        searcher.set_language(args.language)

    search_rankers = []

    if args.qld:
        search_rankers.append('qld')
        searcher.set_qld()
    elif args.bm25:
        search_rankers.append('bm25')
        set_bm25_parameters(searcher, args.index, args.k1, args.b)

    if args.rm3:
        search_rankers.append('rm3')
        searcher.set_rm3()

    if args.rocchio:
        search_rankers.append('rocchio')
        if args.rocchio_use_negative:
            searcher.set_rocchio(gamma=0.15, use_negative=True)
        else:
            searcher.set_rocchio()

    fields = {}
    if args.fields:
        # Weights stay as strings here. NOTE(review): presumably the searcher converts them to
        # floats -- confirm before relying on it.
        fields = dict(pair.split('=') for pair in args.fields)
        print(f'Searching over fields: {fields}')

    query_generator = None
    if args.dismax:
        query_generator = JDisjunctionMaxQueryGenerator(args.tiebreaker)
        print(f'Using dismax query generator with tiebreaker={args.tiebreaker}')

    if args.pretokenized:
        analyzer = JWhiteSpaceAnalyzer()
        searcher.set_analyzer(analyzer)
        if args.tokenizer is not None:
            raise ValueError("--tokenizer is not supported with when setting --pretokenized.")

    tokenizer = None
    if args.tokenizer is not None:
        # Topics are tokenized below and re-joined with single spaces, so the query analyzer is
        # switched to whitespace splitting to keep those tokens intact.
        analyzer = JWhiteSpaceAnalyzer()
        searcher.set_analyzer(analyzer)
        print('Using whitespace analyzer because of pretokenized topics')
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        print(f'Using {args.tokenizer} to preprocess topics')

    if args.stopwords:
        analyzer = JDefaultEnglishAnalyzer.fromArguments('porter', False, args.stopwords)
        searcher.set_analyzer(analyzer)
        print(f'Using custom stopwords={args.stopwords}')

    # get re-ranker: PRCL is active only with at least one classifier and a positive alpha.
    use_prcl = bool(args.prcl) and args.alpha > 0
    ranker = None
    if use_prcl:
        ranker = PseudoRelevanceClassifierReranker(
            searcher.index_dir, args.vectorizer, args.prcl, r=args.r, n=args.n, alpha=args.alpha)

    # build output path
    output_path = args.output
    if output_path is None:
        output_path = _default_output_path(args, search_rankers, use_prcl)

    print(f'Running {args.topics} topics, saving to {output_path}...')
    # Auto-generated paths end in ".txt"; their stem doubles as the run tag.
    tag = output_path[:-4] if args.output is None else 'Anserini'

    output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w',
                                      max_hits=args.hits, tag=tag, topics=topics,
                                      use_max_passage=args.max_passage,
                                      max_passage_delimiter=args.max_passage_delimiter,
                                      max_passage_hits=args.max_passage_hits)

    num_topics = len(topics.keys())
    sequential = args.batch_size <= 1 and args.threads <= 1
    # A non-positive --batch-size combined with --threads > 1 used to hit a modulo by zero.
    batch_size = max(args.batch_size, 1)

    with output_writer:
        batch_topics = []
        batch_topic_ids = []
        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=num_topics)):
            if tokenizer is not None:
                text = ' '.join(tokenizer.tokenize(text))

            if sequential:
                if args.impact:
                    hits = searcher.search(text, args.hits, fields=fields)
                else:
                    hits = searcher.search(text, args.hits, query_generator=query_generator, fields=fields)
                _write_results(output_writer, [(topic_id, hits)], args, ranker)
                continue

            batch_topic_ids.append(str(topic_id))
            batch_topics.append(text)
            # Flush once the batch is full or the last expected topic has been queued.
            if (index + 1) % batch_size == 0 or index == num_topics - 1:
                results = _batch_search(searcher, args, batch_topics, batch_topic_ids, query_generator, fields)
                batch_topic_ids.clear()
                batch_topics.clear()
                _write_results(output_writer, results, args, ranker)

        # The in-loop flush assumes the iterator yields exactly num_topics items; flush any
        # remainder so queued topics are never silently dropped.
        if batch_topic_ids:
            results = _batch_search(searcher, args, batch_topics, batch_topic_ids, query_generator, fields)
            batch_topic_ids.clear()
            batch_topics.clear()
            _write_results(output_writer, results, args, ranker)


if __name__ == "__main__":
    main()