NetsPresso_QA / search_online_demo_TEMPORARY.py
geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
import argparse
from pyserini.search.faiss import AutoQueryEncoder, AnceQueryEncoder, DprQueryEncoder, TctColBertQueryEncoder
def _init_encoder_from_str(encoder, device='cpu'):
encoder_lower = encoder.lower()
if 'dpr' in encoder_lower:
return DprQueryEncoder(encoder_dir=encoder, device=device)
elif 'tct_colbert' in encoder_lower:
return TctColBertQueryEncoder(encoder_dir=encoder, device=device)
elif 'ance' in encoder_lower:
return AnceQueryEncoder(encoder_dir=encoder, device=device)
elif 'sentence' in encoder_lower:
return AutoQueryEncoder(encoder_dir=encoder, pooling='mean', l2_norm=True, device=device)
else:
return AutoQueryEncoder(encoder_dir=encoder, device=device)
def load_index(searcher_class, index_path, query_encoder=None):
    """Instantiate ``searcher_class`` for the index stored at ``index_path``.

    The query encoder is forwarded as a second positional argument only when
    one is supplied (i.e. it is not ``None``); otherwise the searcher is
    constructed from the index path alone.
    """
    ctor_args = (index_path,) if query_encoder is None else (index_path, query_encoder)
    return searcher_class(*ctor_args)
class OnlineSearcher(object):
    """Search a sparse (Lucene), dense (Faiss) or hybrid pyserini index.

    ``args`` is an argparse-style namespace. Any field listed in
    ``_DEFAULT_ARGS`` that is missing from ``args`` or set to ``None`` is
    filled in (the namespace is mutated in place). This lets the class be
    built from a partial namespace as well as from the command line.
    """

    # TEMPORARY: fallback values for fields the caller did not supply.
    _DEFAULT_ARGS = {
        'index_type': 'hybrid',
        'index': '/root/indexes/mrtydi-korean/sparse,/root/indexes/mrtydi-korean/dense',
        'lang_abbr': 'ko',
        'encoder': 'castorini/mdpr-question-nq',
        'device': 'cuda:0',
        'alpha': 0.5,
        'normalization': True,
        'k': 10,
    }

    def __init__(self, args):
        """Fill missing args, build the query encoder and load the index(es).

        Raises:
            ValueError: on an unknown ``index_type``, or when a hybrid
                ``index`` does not name exactly two paths (sparse, dense).
        """
        self.args = args
        for name, default in self._DEFAULT_ARGS.items():
            # argparse stores None for optional flags that were not given
            # (e.g. --encoder), so treat None the same as a missing field.
            if getattr(args, name, None) is None:
                setattr(args, name, default)

        if args.index_type == 'sparse':
            query_encoder = None
        elif args.index_type in ('dense', 'hybrid'):
            query_encoder = _init_encoder_from_str(encoder=args.encoder, device=args.device)
        else:
            raise ValueError(f"index_type {args.index_type} should be chosen among sparse, dense, or hybrid")

        # load index
        if args.index_type == 'hybrid':
            # Expect "<sparse_index>,<dense_index>". args.index is replaced by
            # the split list, so accept an already-split sequence too (e.g. when
            # a second searcher is built from the same args object).
            if isinstance(args.index, str):
                args.index = [path.strip() for path in args.index.split(',')]
            if len(args.index) != 2:
                raise ValueError("require both sparse and dense index delimited by comma")
            from pyserini.search.lucene import LuceneSearcher
            from pyserini.search.faiss import FaissSearcher
            from pyserini.search.hybrid import HybridSearcher
            self.ssearcher = load_index(searcher_class=LuceneSearcher, index_path=args.index[0])
            self.ssearcher.set_language(args.lang_abbr)
            self.dsearcher = load_index(searcher_class=FaissSearcher, index_path=args.index[1], query_encoder=query_encoder)
            self.searcher = HybridSearcher(self.dsearcher, self.ssearcher)
            print(f"load {self.ssearcher.num_docs} documents from {args.index}")
        else:
            if args.index_type == 'sparse':
                from pyserini.search.lucene import LuceneSearcher as Searcher
            else:
                from pyserini.search.faiss import FaissSearcher as Searcher
            self.searcher = load_index(searcher_class=Searcher, index_path=args.index, query_encoder=query_encoder)
            if args.index_type == 'sparse':
                self.searcher.set_language(args.lang_abbr)
            print(f"load {self.searcher.num_docs} documents from {args.index}")

    def search(self, query, k=10):
        """Return up to ``k`` hits for ``query``.

        Hybrid search fuses scores using ``self.args.alpha`` and
        ``self.args.normalization``. ``k0`` is the number of candidates
        pyserini's HybridSearcher requests from each sub-searcher. It is raised
        to ``k`` when ``k`` > 10, so fusion starts from at least ``k``
        candidates per side. For ``k`` <= 10 it keeps pyserini's default of 10.
        """
        if self.args.index_type == 'hybrid':
            return self.searcher.search(query, k0=max(k, 10), k=k,
                                        alpha=self.args.alpha,
                                        normalization=self.args.normalization)
        # Forward k so sparse/dense searches honour it instead of the
        # searcher's own default.
        return self.searcher.search(query, k=k)

    def print_result(self, hits, k):
        """Print rank, docid and score for the first ``k`` hits.

        Fewer lines are printed if fewer hits were returned. For sparse and
        hybrid indexes the raw stored document is printed after each hit. The
        Faiss searcher does not store raw document text, so nothing extra is
        printed for dense indexes.
        """
        index_type = self.args.index_type
        for rank, hit in enumerate(hits[:max(k, 0)], start=1):
            print(f'{rank:2} {hit.docid:15} {hit.score:.5f}')
            if index_type == 'sparse':
                doc = self.searcher.doc(hit.docid)
            elif index_type == 'hybrid':
                doc = self.searcher.sparse_searcher.doc(hit.docid)
            else:
                doc = None
            if doc is not None:
                print(doc.raw())
if __name__ == "__main__":
    # Command-line entry point. `args` stays module-level on purpose: older
    # versions of OnlineSearcher.print_result read the global `args`.
    parser = argparse.ArgumentParser(description='Search interactively')
    parser.add_argument('--index_type', type=str, required=True,
                        help="choose indexing type", choices=['sparse', 'dense', 'hybrid'])
    parser.add_argument('--index', type=str, required=True,
                        help="Path to index or name of prebuilt index.")
    parser.add_argument('--query', type=str, required=True,
                        help="Query text")
    parser.add_argument('--lang_abbr', type=str, required=False, default='ko',
                        help="for language specific algorithms for sparse retrieveal)")
    parser.add_argument('--encoder', type=str, required=False,
                        help="encoder name or checkpoint path")
    parser.add_argument('--device', type=str, required=False, default='cpu',
                        help="device to use for encoding queries (cf. pyserini does not support faiss-gpu)")
    # for hybrid search
    parser.add_argument('--alpha', type=float, default=0.5,
                        help="weight for hybrid search: alpha*score(sparse) + score(dense)")
    parser.add_argument('--normalization', action='store_true',
                        help="normalize sparse & dens score before fusion")
    # search range
    parser.add_argument('--k', type=int, default=10,
                        help="the number of passages to return (default: 10)")
    args = parser.parse_args()

    # make searcher
    searcher = OnlineSearcher(args)
    print(f"given query: {args.query}")

    # search: request --k hits so print_result has that many to show
    # (without k the searcher default of 10 was used regardless of --k)
    hits = searcher.search(args.query, k=args.k)

    # print results
    searcher.print_result(hits, args.k)