pyserini-bm25 / app.py
orionweller's picture
docs
dc4d5eb
raw
history blame contribute delete
No virus
3.44 kB
import functools
import html
import json
import os

# Set before importing pyserini, which needs a JVM; keep this ordering.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/default-java"

import gradio as gr
from pyserini.search.lucene import LuceneSearcher
@functools.lru_cache(maxsize=4)
def _load_searcher(index_name):
    """Open the prebuilt index *index_name* configured for BM25 (k1=0.9, b=0.4).

    Results are cached per index name so an index is opened once instead of
    on every query. The cache is bounded because every open searcher holds
    index resources. Failures raise, and lru_cache does not cache exceptions,
    so a failed load is retried on the next request.
    """
    # from_prebuilt_index downloads the index on first use, so no separate
    # pre-download step (and no subprocess) is needed.
    searcher = LuceneSearcher.from_prebuilt_index(index_name)
    if searcher is None:
        # NOTE(review): pyserini signals an unresolvable prebuilt name by
        # returning None rather than raising; turn that into a readable error
        # instead of an AttributeError on set_bm25.
        raise ValueError(f"Could not load prebuilt index {index_name!r}.")
    searcher.set_bm25(k1=0.9, b=0.4)
    return searcher


def initialize_searcher(index_name):
    """Return a BM25-configured LuceneSearcher for a Pyserini prebuilt index.

    Args:
        index_name: Name of a Pyserini prebuilt Lucene index. Surrounding
            whitespace is ignored.

    Returns:
        A cached ``LuceneSearcher`` with BM25 parameters k1=0.9, b=0.4. The
        same instance is returned for repeated calls with the same name.

    Raises:
        ValueError: If the name is empty or the index cannot be loaded.
    """
    # The name comes straight from a UI textbox. It is only ever passed to the
    # Pyserini API and never interpolated into a shell command.
    name = (index_name or "").strip()
    if not name:
        raise ValueError("Please enter a prebuilt index name.")
    return _load_searcher(name)
def search_pyserini(query, top_k, index_name):
try:
searcher = initialize_searcher(index_name)
hits = searcher.search(query, k=top_k)
results = []
for i, hit in enumerate(hits):
doc = searcher.doc(hit.docid)
doc_dict = json.loads(doc.raw())
results.append({
"rank": i + 1,
"doc_id": hit.docid,
"score": hit.score,
"content": doc_dict['contents']
})
return format_results(results)
except Exception as e:
return f"<div class='error'>An error occurred: {str(e)}</div>"
def _render_result_card(result):
    """Render one search hit as a ``result-item`` card, escaping its text fields."""
    doc_id = html.escape(str(result["doc_id"]))
    content = html.escape(str(result["content"]))
    return f"""
<div class='result-item'>
<h3>Rank {result['rank']} (Score: {result['score']:.4f})</h3>
<p class='doc-id'>Doc ID: {doc_id}</p>
<p class='content'>{content}</p>
</div>
"""


def format_results(results):
    """Render a list of search hits as an HTML fragment.

    Args:
        results: Iterable of dicts with keys ``rank``, ``doc_id``, ``score``
            (a number, shown with four decimals) and ``content``.

    Returns:
        A ``<div class='results-container'>`` wrapping one ``result-item`` card
        per hit, in input order. Document ids and contents are HTML-escaped
        because they are plain text from the index. Unescaped, characters such
        as ``<`` or ``&`` would corrupt the page or inject markup.
    """
    cards = "".join(_render_result_card(result) for result in results)
    return f"<div class='results-container'>{cards}</div>"
# Stylesheet passed to gr.Blocks(css=...). The class names match the HTML
# emitted by format_results (results-container, result-item, doc-id, content)
# and the error <div> returned by search_pyserini (error).
css = """
.gradio-container {
font-family: 'Arial', sans-serif;
}
.results-container {
display: flex;
flex-direction: column;
gap: 20px;
}
.result-item {
border: 1px solid #ddd;
border-radius: 8px;
padding: 15px;
width: 100%;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.result-item h3 {
margin-top: 0;
color: #333;
}
.doc-id {
font-size: 0.9em;
color: #666;
margin-bottom: 10px;
}
.content {
font-size: 0.95em;
line-height: 1.4;
}
.error {
color: red;
font-weight: bold;
}
"""
# Gradio UI: index name, result count, and query inputs feeding search_pyserini,
# whose HTML is shown in a single output pane.
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Pyserini Search Interface")
    gr.Markdown(
        "Enter a query to search using Pyserini with BM25 scoring (k1=0.9, b=0.4). "
        "See all possible prebuilt index names at "
        "[https://github.com/castorini/pyserini/blob/master/docs/prebuilt-indexes.md#standard-lucene-indexes]"
        "(https://github.com/castorini/pyserini/blob/master/docs/prebuilt-indexes.md#standard-lucene-indexes)"
    )
    with gr.Row():
        index_input = gr.Textbox(
            value="msmarco-passage",
            lines=1,
            label="Prebuilt Index Name",
            placeholder="Enter the name of the prebuilt index",
        )
    with gr.Row():
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=10,
            step=1,
            label="Number of top results to return",
        )
    with gr.Row():
        query_input = gr.Textbox(
            lines=1,
            placeholder="Enter your search query here...",
            label="Search Query",
        )
    with gr.Row():
        search_button = gr.Button("Search", variant="primary")
    with gr.Row():
        output = gr.HTML(label="Search Results")

    # Argument order must match search_pyserini(query, top_k, index_name).
    search_inputs = [query_input, top_k_slider, index_input]
    search_button.click(fn=search_pyserini, inputs=search_inputs, outputs=output)
    # Pressing Enter in the query box runs the same search as the button.
    query_input.submit(fn=search_pyserini, inputs=search_inputs, outputs=output)

iface.launch()