"""Interactive sparse / dense / hybrid search over Pyserini indexes."""
import argparse

from pyserini.search.faiss import (
    AnceQueryEncoder,
    AutoQueryEncoder,
    DprQueryEncoder,
    TctColBertQueryEncoder,
)

# Values filled into ``args`` by OnlineSearcher when an attribute is absent.
# TEMPORARY: deal with missing args
_FALLBACK_ARGS = {
    'index_type': 'hybrid',
    'index': '/root/indexes/mrtydi-korean/sparse,/root/indexes/mrtydi-korean/dense',
    'lang_abbr': 'ko',
    'encoder': 'castorini/mdpr-question-nq',
    'device': 'cuda:0',
    'alpha': 0.5,
    'normalization': True,
    'k': 10,
}


def _init_encoder_from_str(encoder, device='cpu'):
    """Build the query encoder whose class matches the encoder name.

    The class is picked by case-insensitive substring match on ``encoder``,
    checked in this order: ``dpr``, ``tct_colbert``, ``ance``, ``sentence``.
    Anything else falls back to a plain ``AutoQueryEncoder``.

    Args:
        encoder: Encoder name or checkpoint path, forwarded as ``encoder_dir``.
        device: Device string forwarded to the encoder.

    Returns:
        The constructed query encoder instance.
    """
    encoder_lower = encoder.lower()
    # Order matters: the first matching substring wins.
    if 'dpr' in encoder_lower:
        return DprQueryEncoder(encoder_dir=encoder, device=device)
    if 'tct_colbert' in encoder_lower:
        return TctColBertQueryEncoder(encoder_dir=encoder, device=device)
    if 'ance' in encoder_lower:
        return AnceQueryEncoder(encoder_dir=encoder, device=device)
    if 'sentence' in encoder_lower:
        return AutoQueryEncoder(encoder_dir=encoder, pooling='mean', l2_norm=True, device=device)
    return AutoQueryEncoder(encoder_dir=encoder, device=device)


def load_index(searcher_class, index_path, query_encoder=None):
    """Instantiate ``searcher_class`` for ``index_path``.

    Args:
        searcher_class: Callable that builds the searcher.
        index_path: First positional argument passed to ``searcher_class``.
        query_encoder: When not ``None``, passed as the second positional
            argument; otherwise ``searcher_class`` is called with the path only.

    Returns:
        The object returned by ``searcher_class``.
    """
    if query_encoder is None:
        return searcher_class(index_path)
    return searcher_class(index_path, query_encoder)


class OnlineSearcher(object):
    """Wrap a sparse, dense, or hybrid searcher chosen by ``args.index_type``."""

    def __init__(self, args):
        """Load the index (or index pair) described by ``args``.

        Args:
            args: Namespace-like object supporting ``in`` and attribute
                assignment. Attributes missing from it are filled in place
                from ``_FALLBACK_ARGS``. For a hybrid index, ``args.index``
                is replaced in place by the ``[sparse, dense]`` path list.

        Raises:
            ValueError: If ``index_type`` is not sparse, dense, or hybrid, or
                if a hybrid index does not name exactly two paths.
        """
        self.args = args

        # TEMPORARY: deal with missing args
        for name, fallback in _FALLBACK_ARGS.items():
            if name not in args:
                setattr(args, name, fallback)

        if args.index_type == 'sparse':
            query_encoder = None
        elif args.index_type in ('dense', 'hybrid'):
            query_encoder = _init_encoder_from_str(encoder=args.encoder, device=args.device)
        else:
            raise ValueError(f"index_type {args.index_type} should be chosen among sparse, dense, or hybrid")

        # load index
        # Searcher backends are imported inside the branch that needs them, so
        # a backend that is not selected is never imported here.
        if args.index_type == 'hybrid':
            # Only split a raw string: the same namespace may already have
            # been used to build another searcher.
            if isinstance(args.index, str):
                args.index = args.index.split(',')
            if len(args.index) != 2:
                # Raised explicitly instead of asserted so the check survives -O.
                raise ValueError("require both sparse and dense index delimited by comma")

            from pyserini.search.lucene import LuceneSearcher
            self.ssearcher = load_index(searcher_class=LuceneSearcher, index_path=args.index[0])
            self.ssearcher.set_language(args.lang_abbr)

            from pyserini.search.faiss import FaissSearcher
            self.dsearcher = load_index(
                searcher_class=FaissSearcher,
                index_path=args.index[1],
                query_encoder=query_encoder,
            )

            from pyserini.search.hybrid import HybridSearcher
            self.searcher = HybridSearcher(self.dsearcher, self.ssearcher)
            print(f"load {self.ssearcher.num_docs} documents from {args.index}")
        else:
            if args.index_type == 'sparse':
                from pyserini.search.lucene import LuceneSearcher as Searcher
            else:  # 'dense'; any other value was rejected above
                from pyserini.search.faiss import FaissSearcher as Searcher

            self.searcher = load_index(
                searcher_class=Searcher,
                index_path=args.index,
                query_encoder=query_encoder,
            )
            if args.index_type == 'sparse':
                self.searcher.set_language(args.lang_abbr)
            print(f"load {self.searcher.num_docs} documents from {args.index}")

    def search(self, query, k=10):
        """Run ``query`` against the loaded searcher.

        Args:
            query: Query text.
            k: Number of hits requested from the underlying searcher.

        Returns:
            Whatever the underlying searcher's ``search`` returns.
        """
        if self.args.index_type == 'hybrid':
            return self.searcher.search(
                query,
                alpha=self.args.alpha,
                normalization=self.args.normalization,
                k=k,
            )
        # Forward k here too; otherwise the library default is used regardless
        # of what the caller asked for.
        return self.searcher.search(query, k=k)

    def print_result(self, hits, k):
        """Print rank, docid and score for up to ``k`` hits, plus raw text.

        Raw text is printed only for sparse and hybrid indexes, and only when
        the document lookup returns something other than ``None``.

        Args:
            hits: Indexable, sized collection of results exposing ``docid``
                and ``score``.
            k: Maximum number of hits to print.
        """
        index_type = self.args.index_type
        # Print the first k hits (fewer if the searcher returned fewer).
        for i in range(min(k, len(hits))):
            hit = hits[i]
            print(f'{i+1:2} {hit.docid:15} {hit.score:.5f}')
            if index_type == 'sparse':
                doc = self.searcher.doc(hit.docid)
            elif index_type == 'hybrid':
                doc = self.searcher.sparse_searcher.doc(hit.docid)
            else:
                # faiss searcher does not store document raw text
                doc = None
            if doc is not None:
                print(doc.raw())


def main():
    """Parse command-line arguments, run one query, and print the hits."""
    parser = argparse.ArgumentParser(description='Search interactively')
    parser.add_argument('--index_type', type=str, required=True,
                        help="choose indexing type", choices=['sparse', 'dense', 'hybrid'])
    parser.add_argument('--index', type=str, required=True,
                        help="Path to index or name of prebuilt index.")
    parser.add_argument('--query', type=str, required=True, help="Query text")
    parser.add_argument('--lang_abbr', type=str, required=False, default='ko',
                        help="for language specific algorithms for sparse retrieveal)")
    parser.add_argument('--encoder', type=str, required=False,
                        help="encoder name or checkpoint path")
    parser.add_argument('--device', type=str, required=False, default='cpu',
                        help="device to use for encoding queries (cf. pyserini does not support faiss-gpu)")
    # for hybrid search
    parser.add_argument('--alpha', type=float, default=0.5,
                        help="weight for hybrid search: alpha*score(sparse) + score(dense)")
    parser.add_argument('--normalization', action='store_true',
                        help="normalize sparse & dens score before fusion")
    # search range
    parser.add_argument('--k', type=int, default=10,
                        help="the number of passages to return (default: 10)")
    args = parser.parse_args()

    # An encoder is only built for dense/hybrid; without a name that step
    # would fail on None, so reject the combination up front.
    if args.index_type in ('dense', 'hybrid') and not args.encoder:
        parser.error("--encoder is required when --index_type is dense or hybrid")

    # make searcher
    searcher = OnlineSearcher(args)
    print(f"given query: {args.query}")

    # search
    hits = searcher.search(args.query, k=args.k)

    # print results
    searcher.print_result(hits, args.k)


if __name__ == "__main__":
    main()