geonmin-kim's picture
Upload folder using huggingface_hub
d6585f5
raw
history blame
3.41 kB
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import cmd
import json
import random
from pyserini.search.lucene import LuceneSearcher
from pyserini.search.faiss import FaissSearcher, DprQueryEncoder
from pyserini.search.hybrid import HybridSearcher
from pyserini import search
class DPRDemo(cmd.Cmd):
    """Interactive command-line demo for retrieval over the 'wikipedia-dpr' index.

    Any input line that is not a recognised command is treated as a query and
    sent to the currently selected searcher (sparse by default).  Commands may
    optionally be prefixed with '/', e.g. ``/k 20`` or ``/mode dense``.

    All resources below are class attributes, so they are loaded once, when
    the class body is executed (i.e. at import time of this module).
    """

    # Question sets used by the /random command.
    nq_dev_topics = list(search.get_topics('dpr-nq-dev').values())
    trivia_dev_topics = list(search.get_topics('dpr-trivia-dev').values())

    # Sparse (Lucene) searcher; also the default active searcher.
    ssearcher = LuceneSearcher.from_prebuilt_index('wikipedia-dpr')
    searcher = ssearcher

    # Dense (Faiss) searcher driven by a DPR question encoder.
    encoder = DprQueryEncoder("facebook/dpr-question_encoder-multiset-base")
    index = 'wikipedia-dpr-multi-bf'
    dsearcher = FaissSearcher.from_prebuilt_index(
        index,
        encoder
    )

    # Hybrid searcher built from the dense and sparse searchers above.
    hsearcher = HybridSearcher(dsearcher, ssearcher)

    # Number of hits to return per query; changed with /k.
    k = 10
    prompt = '>>> '

    def precmd(self, line):
        """Strip one optional leading '/' so '/cmd' and 'cmd' are equivalent.

        Uses ``startswith`` rather than ``line[0]`` so that an empty input
        line does not raise ``IndexError`` and terminate the prompt.
        """
        if line.startswith('/'):
            line = line[1:]
        return line

    def emptyline(self):
        """Ignore blank input.

        The base-class behaviour is to repeat the last command, which would
        silently re-run a search when the user just presses Enter.
        """
        return None

    def do_help(self, arg):
        """Print the list of supported commands (``arg`` is ignored)."""
        print('/help : returns this message')
        print('/k [NUM] : sets k (number of hits to return) to [NUM]')
        print('/mode [MODE] : sets retriever type to [MODE] (one of sparse, dense, hybrid)')
        print('/random [COLLECTION]: returns results for a random question from the dev subset [COLLECTION] (one of nq, trivia).')

    def do_k(self, arg):
        """Set the number of hits to return; ``arg`` must be a positive integer.

        Invalid input is reported and the current value of ``k`` is kept,
        instead of letting ``ValueError`` escape and end the session.
        """
        try:
            new_k = int(arg)
        except ValueError:
            new_k = None
        if new_k is None or new_k <= 0:
            print(f'k "{arg}" is invalid. k should be a positive integer.')
            return
        print(f'setting k = {new_k}')
        self.k = new_k

    def do_mode(self, arg):
        """Select the active searcher: one of 'sparse', 'dense' or 'hybrid'."""
        if arg == "sparse":
            self.searcher = self.ssearcher
        elif arg == "dense":
            self.searcher = self.dsearcher
        elif arg == "hybrid":
            self.searcher = self.hsearcher
        else:
            print(
                f'Mode "{arg}" is invalid. Mode should be one of [sparse, dense, hybrid].')
            return
        print(f'setting retriver = {arg}')

    def do_random(self, arg):
        """Run a randomly chosen question from the 'nq' or 'trivia' dev set."""
        if arg == "nq":
            topics = self.nq_dev_topics
        elif arg == "trivia":
            topics = self.trivia_dev_topics
        else:
            print(
                f'Collection "{arg}" is invalid. Collection should be one of [nq, trivia].')
            return
        q = random.choice(topics)['title']
        print(f'question: {q}')
        self.default(q)

    def do_EOF(self, line):
        """Returning True tells the command loop to stop (end of input)."""
        return True

    def default(self, q):
        """Treat ``q`` as a query: search with the active searcher and print hits.

        Each hit is printed as ``rank score contents``, where ``contents`` is
        the "contents" field of the hit's raw JSON document.
        """
        hits = self.searcher.search(q, self.k)
        is_sparse = isinstance(self.searcher, LuceneSearcher)
        for rank, hit in enumerate(hits, start=1):
            if is_sparse:
                raw_doc = hit.raw
            else:
                # NOTE(review): assumes the dense and hybrid searchers expose
                # doc(docid) -- TODO confirm against the installed pyserini.
                doc = self.searcher.doc(hit.docid)
                raw_doc = doc.raw() if doc else None
            if raw_doc is None:
                # Without this guard json.loads(None) raises TypeError and
                # the whole session ends on a single missing document.
                print(f'{rank:2} {hit.score:.5f} [document {hit.docid} not available]')
                continue
            jsondoc = json.loads(raw_doc)
            print(f'{rank:2} {hit.score:.5f} {jsondoc["contents"]}')
if __name__ == '__main__':
    # Start the interactive prompt only when this file is run as a script.
    demo = DPRDemo()
    demo.cmdloop()