"""Streamlit search demo: Pyserini impact search with a PyTorch or ONNX Runtime query encoder."""

import html
import json
import sys
import time
from pathlib import Path

import streamlit as st
from pyserini.search.lucene import LuceneImpactSearcher, LuceneSearcher

path_root = Path("./")
sys.path.append(str(path_root))

# Slider label -> (ONNX encoder name, PyTorch encoder name, index directory under ./indexes).
encoder_index_map = {
    'uniCOIL': (
        'UniCoil',
        'castorini/unicoil-noexp-msmarco-passage',
        'index-unicoil',
    ),
    'SPLADE++ Ensemble Distil': (
        'SpladePlusPlusEnsembleDistil',
        'naver/splade-cocondenser-ensembledistil',
        'index-splade-pp-ed',
    ),
    'SPLADE++ Self Distil': (
        'SpladePlusPlusSelfDistil',
        'naver/splade-cocondenser-selfdistil',
        'index-splade-pp-sd',
    ),
}

# Runtime slider label -> (encoder_type for LuceneImpactSearcher, position in the tuples above).
RUNTIME_OPTIONS = {
    'PyTorch': ('pytorch', 1),
    'ONNX Runtime': ('onnx', 0),
}

# Number of hits requested from the searcher and shown on the page.
NUM_HITS = 10


def _parse_contents(raw):
    """Return the text stored in a raw JSON document.

    Reads the ``contents`` field, falling back to ``content``; returns ``''``
    when *raw* is empty or neither field is present.
    """
    if not raw:
        return ''
    document = json.loads(raw)
    if 'contents' in document:
        return document['contents']
    return document.get('content', '')


def _resolve_contents(result, corpus):
    """Return displayable text for a search hit.

    When the hit's own raw JSON carries no text, the document with the same
    docid is looked up in *corpus*; ``''`` is returned if that lookup finds
    nothing either.
    """
    contents = _parse_contents(result.raw)
    if contents == '':
        fallback = corpus.doc(result.docid)
        if fallback is not None:  # doc() may not find the id; don't crash the page
            contents = _parse_contents(fallback.raw())
    return contents


st.set_page_config(page_title="Pyserini with ONNX Runtime", page_icon='🌸', layout="centered")

_, logo_col, _ = st.columns([5, 4, 5])
with logo_col:
    st.image("logo.jpeg")

_, runtime_col, _ = st.columns([1, 8, 1])
with runtime_col:
    runtime = st.select_slider(
        'Select a runtime type',
        options=['PyTorch', 'ONNX Runtime'])
    st.write('Now using: ', runtime)

_, encoder_col, _ = st.columns([1, 8, 1])
with encoder_col:
    encoder_label = st.select_slider(
        'Select a query encoder',
        options=['uniCOIL', 'SPLADE++ Ensemble Distil', 'SPLADE++ Self Distil'])
    st.write('Now Running Encoder: ', encoder_label)

encoder_type, encoder_position = RUNTIME_OPTIONS[runtime]
encoder = encoder_index_map[encoder_label][encoder_position]
index = encoder_index_map[encoder_label][2]

# NOTE(review): both searchers (index + query-encoder model) are rebuilt on every
# Streamlit rerun, i.e. on each widget interaction; consider caching them.
searcher = LuceneImpactSearcher(f'indexes/{index}', encoder, encoder_type=encoder_type)
# Used only to recover document text when a hit's own raw JSON has none.
corpus = LuceneSearcher('indexes/index-unicoil')

col1, col2 = st.columns([9, 1])
with col1:
    search_query = st.text_input(label="search query", placeholder="Search")
with col2:
    st.write('#')  # presumably a vertical spacer aligning the button with the text box
    button_clicked = st.button("🔎")

query = search_query.strip()
if button_clicked and not query:
    # Clicking the button with an empty box used to run a search for ''.
    st.warning('Please enter a search query.')
elif query:
    start = time.perf_counter()  # monotonic, high-resolution clock for intervals
    search_results = searcher.search(query, k=NUM_HITS)
    search_time = time.perf_counter() - start
    # NOTE(review): the original HTML markup around these strings was lost;
    # only the visible text is reproduced here.
    st.write(
        f'Retrieved {len(search_results):,} documents in {search_time * 1000:.2f} ms',
        unsafe_allow_html=True)

    for rank, result in enumerate(search_results[:NUM_HITS], start=1):
        contents = _resolve_contents(result, corpus)
        # Index data is rendered with unsafe_allow_html, so escape it first.
        header = (f'Rank: {rank} | Document ID: {html.escape(str(result.docid))} '
                  f'| Score:{result.score:.2f}')
        try:
            st.write(header, unsafe_allow_html=True)
            st.write(html.escape(str(contents)), unsafe_allow_html=True)
        except Exception as err:  # best-effort display: report instead of hiding it
            st.warning(f'Could not display document {result.docid}: {err}')
        st.write('---')