import argparse
import json

from pyserini.search.faiss import (
    AutoQueryEncoder,
    AnceQueryEncoder,
    DprQueryEncoder,
    TctColBertQueryEncoder,
)


def _init_encoder_from_str(encoder, device="cpu"):
    """Pick a query-encoder class from substrings of ``encoder`` and build it.

    The match is case-insensitive and checked in this order: ``dpr``,
    ``tct_colbert``, ``ance``, ``sentence`` (mean pooling + L2 norm), and
    finally a plain ``AutoQueryEncoder`` as the fallback.

    Args:
        encoder: Encoder name or checkpoint path; passed through unchanged as
            ``encoder_dir``.
        device: Device string handed to the encoder constructor.

    Returns:
        The constructed query encoder instance.

    Raises:
        ValueError: If ``encoder`` is ``None`` or empty.
    """
    # ``--encoder`` is optional on the CLI, so a dense/hybrid run without it
    # would otherwise fail on ``None.lower()`` with an opaque AttributeError.
    if not encoder:
        raise ValueError(
            "an encoder name or checkpoint path is required for dense or hybrid search"
        )

    encoder_lower = encoder.lower()
    if "dpr" in encoder_lower:
        return DprQueryEncoder(encoder_dir=encoder, device=device)
    elif "tct_colbert" in encoder_lower:
        return TctColBertQueryEncoder(encoder_dir=encoder, device=device)
    elif "ance" in encoder_lower:
        return AnceQueryEncoder(encoder_dir=encoder, device=device)
    elif "sentence" in encoder_lower:
        return AutoQueryEncoder(
            encoder_dir=encoder, pooling="mean", l2_norm=True, device=device
        )
    else:
        return AutoQueryEncoder(encoder_dir=encoder, device=device)


def load_index(searcher_class, index_dir, query_encoder=None):
    """Instantiate ``searcher_class`` for ``index_dir``.

    ``query_encoder`` is only forwarded when it is not ``None``, so searcher
    classes that do not accept that keyword can be built through the same
    helper.
    """
    if query_encoder is not None:
        return searcher_class(index_dir=index_dir, query_encoder=query_encoder)
    return searcher_class(index_dir=index_dir)


class OnlineSearcher(object):
    """Wrap a sparse, dense, or hybrid pyserini searcher chosen by ``args``.

    Attributes read from ``args``: ``index_type``, ``index``, ``lang_abbr``,
    ``encoder``, ``device`` (construction); ``alpha``, ``normalization``
    (hybrid search); ``hide_text`` (printing).
    """

    def __init__(self, args):
        """Build the searcher(s) described by ``args``.

        Raises:
            ValueError: If ``args.index_type`` is not one of ``sparse``,
                ``dense``, ``hybrid``; if a dense/hybrid run has no encoder;
                or if a hybrid run does not supply exactly two indexes.
        """
        self.args = args

        if args.index_type == "sparse":
            query_encoder = None
        elif args.index_type in ("dense", "hybrid"):
            query_encoder = _init_encoder_from_str(
                encoder=args.encoder, device=args.device
            )
        else:
            raise ValueError(
                f"index_type {args.index_type} should be chosen among sparse, dense, or hybrid"
            )

        # load index
        if args.index_type == "hybrid":
            # Split into a local instead of overwriting ``args.index``: the
            # caller's namespace stays reusable for a second construction.
            # An already-split sequence is accepted as well.
            if isinstance(args.index, str):
                index_dirs = args.index.split(",")
            else:
                index_dirs = list(args.index)
            # Explicit raise rather than ``assert`` so the check survives -O.
            if len(index_dirs) != 2:
                raise ValueError(
                    "require both sparse and dense index delimited by comma"
                )

            from pyserini.search.lucene import LuceneSearcher

            self.ssearcher = load_index(
                searcher_class=LuceneSearcher, index_dir=index_dirs[0]
            )
            self.ssearcher.set_language(args.lang_abbr)

            from pyserini.search.faiss import FaissSearcher

            self.dsearcher = load_index(
                searcher_class=FaissSearcher,
                index_dir=index_dirs[1],
                query_encoder=query_encoder,
            )

            from pyserini.search.hybrid import HybridSearcher

            self.searcher = HybridSearcher(self.dsearcher, self.ssearcher)
            print(f"load {self.ssearcher.num_docs} documents from {index_dirs}")
        else:
            if args.index_type == "sparse":
                from pyserini.search.lucene import LuceneSearcher as Searcher
            elif args.index_type == "dense":
                from pyserini.search.faiss import FaissSearcher as Searcher

            self.searcher = load_index(
                searcher_class=Searcher,
                index_dir=args.index,
                query_encoder=query_encoder,
            )
            if args.index_type == "sparse":
                self.searcher.set_language(args.lang_abbr)
            print(f"load {self.searcher.num_docs} documents from {args.index}")

    def search(self, query, k=10):
        """Run ``query`` and return the underlying searcher's hits.

        Args:
            query: Query text.
            k: Number of hits requested from the searcher.

        Returns:
            Whatever the wrapped searcher's ``search`` returns.
        """
        if self.args.index_type == "hybrid":
            # NOTE(review): HybridSearcher may bound the per-searcher candidate
            # depth separately from ``k`` — confirm if large ``k`` is needed.
            return self.searcher.search(
                query,
                alpha=self.args.alpha,
                normalization=self.args.normalization,
                k=k,
            )
        # Forward ``k``; previously it was dropped and the searcher's own
        # default depth was always used.
        return self.searcher.search(query, k=k)

    def print_result(self, hits, k):
        """Print the first ``k`` hits and return their contents as one string.

        For each hit a ``rank docid score`` line is printed. For sparse and
        hybrid indexes the stored document is looked up and, unless
        ``args.hide_text`` is set, its raw JSON is printed and collected.

        Returns:
            The ``"contents"`` fields of the collected documents, each
            prefixed with a numbered label and joined by blank lines; an
            empty string when nothing was collected.
        """
        docs = []
        for i in range(0, min(k, len(hits))):
            hit = hits[i]
            print(f"{i+1:2} {hit.docid:15} {hit.score:.5f}")

            if self.args.index_type == "sparse":
                doc = self.searcher.doc(hit.docid)
            elif self.args.index_type == "hybrid":
                doc = self.searcher.sparse_searcher.doc(hit.docid)
            else:
                # faiss searcher does not store document raw text
                doc = None

            if doc is not None and not self.args.hide_text:
                doc_raw = doc.raw()
                docs.append(json.loads(doc_raw))
                print(doc_raw)

        return "\n\n".join(
            [f'문서 {idx+1}\n{doc["contents"]}' for idx, doc in enumerate(docs)]
        )


def _build_arg_parser():
    """Create the command-line parser for the interactive search script."""
    parser = argparse.ArgumentParser(description="Search interactively")
    parser.add_argument(
        "--index_type",
        type=str,
        required=True,
        help="choose indexing type",
        choices=["sparse", "dense", "hybrid"],
    )
    parser.add_argument(
        "--index",
        type=str,
        required=True,
        help="Path to index or name of prebuilt index.",
    )
    parser.add_argument("--query", type=str, required=True, help="Query text")
    parser.add_argument(
        "--lang_abbr",
        type=str,
        required=False,
        default="ko",
        help="for language specific algorithms for sparse retrieveal)",
    )
    parser.add_argument(
        "--encoder", type=str, required=False, help="encoder name or checkpoint path"
    )
    parser.add_argument(
        "--device",
        type=str,
        required=False,
        default="cpu",
        help="device to use for encoding queries (cf. pyserini does not support faiss-gpu)",
    )
    # for hybrid search
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="weight for hybrid search: alpha*score(sparse) + score(dense)",
    )
    parser.add_argument(
        "--normalization",
        action="store_true",
        help="normalize sparse & dens score before fusion",
    )
    # search range
    parser.add_argument(
        "--k",
        type=int,
        default=10,
        help="the number of passages to return (default: 10)",
    )
    # print option
    parser.add_argument(
        "--hide_text", action="store_true", help="do not print if this is true"
    )
    return parser


def main():
    """Parse CLI arguments, run one query, and print the top ``--k`` hits."""
    args = _build_arg_parser().parse_args()

    # make searcher
    searcher = OnlineSearcher(args)
    print(f"given query: {args.query}")

    # search — request ``--k`` hits so values above the searcher default
    # are honoured rather than silently capped.
    hits = searcher.search(args.query, k=args.k)

    # print results
    searcher.print_result(hits, args.k)


if __name__ == "__main__":
    main()