#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import cmd
import json
import os
import random

from pyserini.search.lucene import LuceneSearcher
from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder, AnceQueryEncoder
from pyserini.search.hybrid import HybridSearcher
from pyserini import search


class MsMarcoDemo(cmd.Cmd):
    """Interactive shell for querying the MS MARCO passage collection.

    Any input line that is not a command is run as a query against the
    currently selected retriever. Commands may be prefixed with '/'
    (e.g. '/k 20', '/mode dense'); see '/help' for the list.
    """

    # These class attributes are evaluated when the class body executes,
    # i.e. the dev topics and the prebuilt Lucene index load at import time.
    dev_topics = list(search.get_topics('msmarco-passage-dev-subset').values())
    ssearcher = LuceneSearcher.from_prebuilt_index('msmarco-passage')
    # Dense and hybrid searchers are built by /model; None until then.
    dsearcher = None
    hsearcher = None
    # Retriever used for queries; starts as the sparse Lucene searcher.
    searcher = ssearcher
    k = 10
    prompt = '>>> '

    # https://stackoverflow.com/questions/35213134/command-prefixes-in-python-cli-using-cmd-in-pythons-standard-library
    def precmd(self, line):
        """Strip one leading '/' so '/k 5' dispatches exactly like 'k 5'.

        startswith() is used instead of line[0] so that an empty input line
        no longer raises IndexError and terminates the command loop.
        """
        if line.startswith('/'):
            line = line[1:]
        return line

    def emptyline(self):
        """Do nothing on blank input.

        cmd.Cmd's default would repeat the previous command, which here could
        silently re-run a search or reload a model.
        """

    def do_help(self, arg):
        """Print the list of supported commands."""
        print(f'/help : returns this message')
        print(f'/k [NUM] : sets k (number of hits to return) to [NUM]')
        print(f'/model [MODEL] : sets encoder to use the model [MODEL] (one of tct, ance)')
        print(f'/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)')
        print(f'/random : returns results for a random question from dev subset')

    def do_k(self, arg):
        """Set the number of hits returned per query.

        Non-integer or non-positive values are reported and leave k unchanged,
        instead of raising ValueError out of the command loop.
        """
        try:
            k = int(arg)
        except ValueError:
            print(f'k must be an integer, got "{arg}".')
            return
        if k < 1:
            print(f'k must be a positive integer, got {k}.')
            return
        print(f'setting k = {k}')
        self.k = k

    def do_mode(self, arg):
        """Select the retriever: 'sparse', 'dense' or 'hybrid'.

        'dense' and 'hybrid' require a model to have been loaded via /model.
        """
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            if self.dsearcher is None:
                print(f'Specify model through /model before using dense retrieval.')
                return
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            if self.hsearcher is None:
                print(f'Specify model through /model before using hybrid retrieval.')
                return
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        print(f'setting retriever = {arg}')

    def do_model(self, arg):
        """Load the query encoder and prebuilt Faiss index for 'tct' or 'ance'.

        Rebuilds both the dense and the hybrid searcher. If one of them was
        the active retriever, the active retriever is switched to its rebuilt
        counterpart so subsequent queries really use the new model.
        """
        if arg == "tct":
            encoder = TctColBertQueryEncoder("castorini/tct_colbert-msmarco")
            index = "msmarco-passage-tct_colbert-hnsw"
        elif arg == "ance":
            encoder = AnceQueryEncoder("castorini/ance-msmarco-passage")
            index = "msmarco-passage-ance-bf"
        else:
            print(
                f'Model "{arg}" is invalid. Model should be one of [tct, ance].')
            return

        # Capture which retriever is active before the old searchers are replaced.
        was_dense = self.dsearcher is not None and self.searcher is self.dsearcher
        was_hybrid = self.hsearcher is not None and self.searcher is self.hsearcher

        self.dsearcher = FaissSearcher.from_prebuilt_index(index, encoder)
        self.hsearcher = HybridSearcher(self.dsearcher, self.ssearcher)

        if was_dense:
            self.searcher = self.dsearcher
        elif was_hybrid:
            self.searcher = self.hsearcher
        print(f'setting model = {arg}')

    def do_random(self, arg):
        """Pick a random dev-subset question, echo it, and run it as a query."""
        q = random.choice(self.dev_topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        """Return True so cmdloop() exits on end of input."""
        return True

    def default(self, q):
        """Run q against the active retriever and print rank, score and text.

        The text printed is the "contents" field of each hit's raw JSON
        document. Hits whose document cannot be found are reported instead of
        crashing on json.loads(None).
        """
        # NOTE(review): for HybridSearcher the second positional argument of
        # search() may not be the final result count — confirm /k is honoured
        # in hybrid mode.
        hits = self.searcher.search(q, self.k)

        for rank, hit in enumerate(hits, start=1):
            if isinstance(self.searcher, LuceneSearcher):
                # Sparse hits carry the stored raw document directly.
                raw_doc = hit.raw
            else:
                # Dense/hybrid hits only carry a docid; fetch the document from
                # the Lucene index, whose hits are parsed the same way above.
                doc = self.ssearcher.doc(hit.docid)
                raw_doc = doc.raw() if doc else None

            if raw_doc is None:
                print(f'{rank:2} {hit.score:.5f} [document {hit.docid} not found]')
                continue
            jsondoc = json.loads(raw_doc)
            print(f'{rank:2} {hit.score:.5f} {jsondoc["contents"]}')


if __name__ == '__main__':
    MsMarcoDemo().cmdloop()