geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
raw
history blame
No virus
16.8 kB
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import os
from tqdm import tqdm
from transformers import AutoTokenizer
from pyserini.analysis import JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer
from pyserini.output_writer import OutputFormat, get_output_writer
from pyserini.pyclass import autoclass
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.search import JDisjunctionMaxQueryGenerator
from . import LuceneImpactSearcher, LuceneSearcher, SlimSearcher
from .reranker import ClassifierType, PseudoRelevanceClassifierReranker
# Tuned BM25 settings for known MS MARCO indexes.
# Each entry is (index names, label used in the log message, k1, b).
_KNOWN_INDEX_BM25_SETTINGS = (
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-passage.md
    (
        ('msmarco-passage', 'msmarco-passage-slim', 'msmarco-v1-passage',
         'msmarco-v1-passage-slim', 'msmarco-v1-passage-full'),
        'MS MARCO passage', 0.82, 0.68,
    ),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-passage-docTTTTTquery.md
    (
        ('msmarco-passage-expanded', 'msmarco-v1-passage-d2q-t5',
         'msmarco-v1-passage-d2q-t5-docvectors'),
        'MS MARCO passage w/ doc2query-T5 expansion', 2.18, 0.86,
    ),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-doc.md
    (
        ('msmarco-doc', 'msmarco-doc-slim', 'msmarco-v1-doc',
         'msmarco-v1-doc-slim', 'msmarco-v1-doc-full'),
        'MS MARCO doc', 4.46, 0.82,
    ),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-doc-segmented.md
    (
        ('msmarco-doc-per-passage', 'msmarco-doc-per-passage-slim',
         'msmarco-v1-doc-segmented', 'msmarco-v1-doc-segmented-slim',
         'msmarco-v1-doc-segmented-full'),
        'MS MARCO doc, per passage', 2.16, 0.61,
    ),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-doc-docTTTTTquery.md
    (
        ('msmarco-doc-expanded-per-doc', 'msmarco-v1-doc-d2q-t5',
         'msmarco-v1-doc-d2q-t5-docvectors'),
        'MS MARCO doc w/ doc2query-T5 (per doc) expansion', 4.68, 0.87,
    ),
    # See https://github.com/castorini/anserini/blob/master/docs/regressions-msmarco-doc-segmented-docTTTTTquery.md
    (
        ('msmarco-doc-expanded-per-passage', 'msmarco-v1-doc-segmented-d2q-t5',
         'msmarco-v1-doc-segmented-d2q-t5-docvectors'),
        'MS MARCO doc w/ doc2query-T5 (per passage) expansion', 2.56, 0.59,
    ),
)


def set_bm25_parameters(searcher, index, k1=None, b=None):
    """Configure BM25 on ``searcher``, explicitly or from per-index presets.

    When ``k1`` and ``b`` are both given they are passed straight to
    ``searcher.set_bm25``.  When neither is given, ``index`` is looked up in
    the table of known index names and the matching preset is applied; an
    unknown index leaves the searcher untouched.

    Args:
        searcher: Object exposing ``set_bm25(k1, b)``.
        index: Index name (or path) used to look up a preset.
        k1: BM25 ``k1`` value, or ``None``.
        b: BM25 ``b`` value, or ``None``.

    Raises:
        SystemExit: With status 1 when only one of ``k1`` / ``b`` is given.
    """
    if k1 is not None or b is not None:
        if k1 is None or b is None:
            print('Must set *both* k1 and b for BM25!')
            # Exit with a failure status; a bare ``exit()`` would report success
            # and relies on the ``site`` module having installed the builtin.
            raise SystemExit(1)
        print(f'Setting BM25 parameters: k1={k1}, b={b}')
        searcher.set_bm25(k1, b)
        return

    # Automatically set bm25 parameters based on known index...
    for index_names, label, preset_k1, preset_b in _KNOWN_INDEX_BM25_SETTINGS:
        if index in index_names:
            print(f'{label}: setting k1={preset_k1}, b={preset_b}')
            searcher.set_bm25(preset_k1, preset_b)
            return
def define_search_args(parser):
    """Register the search-related command-line options on ``parser``.

    The options cover index selection, impact/encoder settings, the scoring
    model (BM25 / QLD) with optional RM3 or Rocchio, pseudo-relevance
    classifier settings (``--prcl*``), multi-field search and custom
    stopwords.  ``parser`` is modified in place; nothing is returned.

    Note that ``--bm25`` is a ``store_true`` flag whose default is already
    ``True``, so the parsed value is ``True`` whether or not it is passed.
    """
    add = parser.add_argument

    # Index and representation.
    add('--index',
        type=str,
        metavar='path to index or index name',
        required=True,
        help="Path to Lucene index or name of prebuilt index.")
    add('--encoded-corpus', type=str, default=None,
        help="path to stored sparse vectors")
    add('--impact', action='store_true',
        help="Use Impact.")
    add('--encoder', type=str, default=None,
        help="encoder name")
    add('--min-idf', type=int, default=0,
        help="minimum idf")

    # Scoring model and query expansion.
    add('--bm25', action='store_true', default=True,
        help="Use BM25 (default).")
    add('--k1', type=float,
        help='BM25 k1 parameter.')
    add('--b', type=float,
        help='BM25 b parameter.')
    add('--rm3', action='store_true',
        help="Use RM3")
    add('--rocchio', action='store_true',
        help="Use Rocchio")
    add('--rocchio-use-negative', action='store_true',
        help="Use nonrelevant labels in Rocchio")
    add('--qld', action='store_true',
        help="Use QLD")

    # Topic language / tokenization.
    add('--language', type=str, default='en',
        help='language code for BM25, e.g. zh for Chinese')
    add('--pretokenized', action='store_true',
        help="Boolean switch to accept pre-tokenized topics")

    # Pseudo-relevance classifier reranking.
    add('--prcl', type=ClassifierType, nargs='+', default=[],
        help='Specify the classifier PseudoRelevanceClassifierReranker uses.')
    add('--prcl.vectorizer', dest='vectorizer', type=str,
        help='Type of vectorizer. Available: TfidfVectorizer, BM25Vectorizer.')
    add('--prcl.r', dest='r', type=int, default=10,
        help='Number of positive labels in pseudo relevance feedback.')
    add('--prcl.n', dest='n', type=int, default=100,
        help='Number of negative labels in pseudo relevance feedback.')
    add('--prcl.alpha', dest='alpha', type=float, default=0.5,
        help='Alpha value for interpolation in pseudo relevance feedback.')

    # Multi-field search.
    add('--fields', metavar="key=value", nargs='+',
        help='Fields to search with assigned float weights.')
    add('--dismax', action='store_true', default=False,
        help='Use disjunction max queries when searching multiple fields.')
    add('--dismax.tiebreaker', dest='tiebreaker', type=float, default=0.0,
        help='The tiebreaker weight to use in disjunction max queries.')

    # Analysis.
    add('--stopwords', type=str,
        help='Path to file with customstopwords.')
if __name__ == "__main__":
    # Command-line entry point: parse options, build and configure a searcher,
    # run every topic through it (one at a time or in batches) and hand the
    # hits to an output writer.

    # NOTE(review): the returned class is not referenced again in this block;
    # the call is kept in case loading it matters — confirm before removing.
    JLuceneSearcher = autoclass('io.anserini.search.SimpleSearcher')

    parser = argparse.ArgumentParser(description='Search a Lucene index.')
    define_search_args(parser)
    parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
                        help="Name of topics. Available: robust04, robust05, core17, core18.")
    parser.add_argument('--hits', type=int, metavar='num',
                        required=False, default=1000, help="Number of hits.")
    parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
                        help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
    parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
                        help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}")
    parser.add_argument('--output', type=str, metavar='path',
                        help="Path to output file.")
    parser.add_argument('--max-passage', action='store_true',
                        default=False, help="Select only max passage from document.")
    parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100,
                        help="Final number of hits when selecting only max passage.")
    parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#',
                        help="Delimiter between docid and passage id.")
    parser.add_argument('--batch-size', type=int, metavar='num', required=False,
                        default=1, help="Specify batch size to search the collection concurrently.")
    parser.add_argument('--threads', type=int, metavar='num', required=False,
                        default=1, help="Maximum number of threads to use.")
    parser.add_argument('--tokenizer', type=str, help='tokenizer used to preprocess topics')
    parser.add_argument('--remove-duplicates', action='store_true', default=False, help="Remove duplicate docs.")
    # For some test collections, a query is doc from the corpus (e.g., arguana in BEIR).
    # We want to remove the query from the results. This is equivalent to -removeQuery in Java.
    parser.add_argument('--remove-query', action='store_true', default=False, help="Remove query from results list.")
    args = parser.parse_args()

    query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format))
    topics = query_iterator.topics

    # --index is treated as a local index directory if the path exists,
    # otherwise as the name of a prebuilt index.
    index_is_local = os.path.exists(args.index)
    if not args.impact:
        if index_is_local:
            searcher = LuceneSearcher(args.index)
        else:
            searcher = LuceneSearcher.from_prebuilt_index(args.index)
    else:
        # Impact search; SlimSearcher is used when an encoded corpus is supplied.
        if index_is_local:
            if args.encoded_corpus is not None:
                searcher = SlimSearcher(args.encoded_corpus, args.index, args.encoder, args.min_idf)
            else:
                searcher = LuceneImpactSearcher(args.index, args.encoder, args.min_idf)
        else:
            if args.encoded_corpus is not None:
                searcher = SlimSearcher.from_prebuilt_index(args.encoded_corpus, args.index, args.encoder, args.min_idf)
            else:
                searcher = LuceneImpactSearcher.from_prebuilt_index(args.index, args.encoder, args.min_idf)

    # Bail out (with a failure status) before the searcher is first used.
    if not searcher:
        raise SystemExit(1)

    if args.language != 'en':
        searcher.set_language(args.language)

    # Names of the rankers in use; they become part of the default output file name.
    search_rankers = []

    if args.qld:
        search_rankers.append('qld')
        searcher.set_qld()
    elif args.bm25:
        search_rankers.append('bm25')
        set_bm25_parameters(searcher, args.index, args.k1, args.b)

    if args.rm3:
        search_rankers.append('rm3')
        searcher.set_rm3()

    if args.rocchio:
        search_rankers.append('rocchio')
        if args.rocchio_use_negative:
            searcher.set_rocchio(gamma=0.15, use_negative=True)
        else:
            searcher.set_rocchio()

    # "key=value" pairs from --fields; values stay strings as given on the command line.
    fields = dict()
    if args.fields:
        fields = dict(pair.split('=') for pair in args.fields)
        print(f'Searching over fields: {fields}')

    query_generator = None
    if args.dismax:
        query_generator = JDisjunctionMaxQueryGenerator(args.tiebreaker)
        print(f'Using dismax query generator with tiebreaker={args.tiebreaker}')

    use_tokenizer = args.tokenizer is not None

    if args.pretokenized:
        analyzer = JWhiteSpaceAnalyzer()
        searcher.set_analyzer(analyzer)
        if use_tokenizer:
            raise ValueError("--tokenizer is not supported with when setting --pretokenized.")

    if use_tokenizer:
        # Topics are tokenized here and re-joined with spaces, so the searcher
        # only needs to split on whitespace.
        analyzer = JWhiteSpaceAnalyzer()
        searcher.set_analyzer(analyzer)
        print('Using whitespace analyzer because of pretokenized topics')
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        print(f'Using {args.tokenizer} to preprocess topics')

    if args.stopwords:
        analyzer = JDefaultEnglishAnalyzer.fromArguments('porter', False, args.stopwords)
        searcher.set_analyzer(analyzer)
        print(f'Using custom stopwords={args.stopwords}')

    # get re-ranker
    # Reranking needs at least one classifier and a positive alpha.
    use_prcl = bool(args.prcl) and args.alpha > 0
    if use_prcl:
        ranker = PseudoRelevanceClassifierReranker(
            searcher.index_dir, args.vectorizer, args.prcl, r=args.r, n=args.n, alpha=args.alpha)

    # build output path
    output_path = args.output
    if output_path is None:
        if use_prcl:
            clf_rankers = []
            for t in args.prcl:
                if t == ClassifierType.LR:
                    clf_rankers.append('lr')
                elif t == ClassifierType.SVM:
                    clf_rankers.append('svm')

            r_str = f'prcl.r_{args.r}'
            n_str = f'prcl.n_{args.n}'
            a_str = f'prcl.alpha_{args.alpha}'
            clf_str = 'prcl_' + '+'.join(clf_rankers)
            tokens1 = ['run', args.topics, '+'.join(search_rankers)]
            tokens2 = [args.vectorizer, clf_str, r_str, n_str, a_str]
            output_path = '.'.join(tokens1) + '-' + '-'.join(tokens2) + ".txt"
        else:
            tokens = ['run', args.topics, '+'.join(search_rankers), 'txt']
            output_path = '.'.join(tokens)

    print(f'Running {args.topics} topics, saving to {output_path}...')
    # For an auto-generated path the tag is the path minus its ".txt" suffix.
    tag = output_path[:-4] if args.output is None else 'Anserini'

    output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w',
                                      max_hits=args.hits, tag=tag, topics=topics,
                                      use_max_passage=args.max_passage,
                                      max_passage_delimiter=args.max_passage_delimiter,
                                      max_passage_hits=args.max_passage_hits)

    num_topics = len(topics.keys())
    # One search call per topic unless batching or multi-threading was requested.
    single_query_mode = args.batch_size <= 1 and args.threads <= 1

    with output_writer:
        batch_topics = []
        batch_topic_ids = []
        for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=num_topics)):
            if use_tokenizer:
                text = ' '.join(tokenizer.tokenize(text))

            if single_query_mode:
                if args.impact:
                    hits = searcher.search(text, args.hits, fields=fields)
                else:
                    hits = searcher.search(text, args.hits, query_generator=query_generator, fields=fields)
                results = [(topic_id, hits)]
            else:
                batch_topic_ids.append(str(topic_id))
                batch_topics.append(text)
                # Flush when the batch is full or the last topic has been queued.
                if not ((index + 1) % args.batch_size == 0 or index == num_topics - 1):
                    continue
                if args.impact:
                    results = searcher.batch_search(
                        batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields
                    )
                else:
                    results = searcher.batch_search(
                        batch_topics, batch_topic_ids, args.hits, args.threads,
                        query_generator=query_generator, fields=fields
                    )
                # Restore the order in which the topics were queued.
                results = [(id_, results[id_]) for id_ in batch_topic_ids]
                batch_topic_ids.clear()
                batch_topics.clear()

            for topic, hits in results:
                # do rerank
                # Only when there are more hits than the r + n feedback labels.
                if use_prcl and len(hits) > (args.r + args.n):
                    docids = [hit.docid.strip() for hit in hits]
                    scores = [hit.score for hit in hits]
                    scores, docids = ranker.rerank(docids, scores)
                    docid_score_map = dict(zip(docids, scores))
                    # Scores are updated in place; the order of `hits` is not changed here.
                    for hit in hits:
                        hit.score = docid_score_map[hit.docid.strip()]

                if args.remove_duplicates:
                    seen_docids = set()
                    dedup_hits = []
                    for hit in hits:
                        docid = hit.docid.strip()
                        if docid in seen_docids:
                            continue
                        seen_docids.add(docid)
                        dedup_hits.append(hit)
                    hits = dedup_hits

                # For some test collections, a query is doc from the corpus (e.g., arguana in BEIR).
                # We want to remove the query from the results.
                if args.remove_query:
                    # The single-query path keeps the topic id as yielded by the
                    # iterator while the batch path stringifies it; compare as
                    # strings so both paths filter the same way.
                    query_docid = str(topic)
                    hits = [hit for hit in hits if hit.docid != query_docid]

                # write results
                output_writer.write(topic, hits)

            results.clear()