# app.py - Gradio demo for a BM25 search engine.
# Last change: commit 444cbed ("Update app.py") by pratyushpaliwal.
import gradio as gr
from typing import List, TypedDict
from BM25Retriever import BM25Retriever
# from CSCBM25Retriever import CSCBM25Retriever # Replace with your retriever module
class Hit(TypedDict):
    """A single search result: which document matched, how well, and its text."""

    cid: str  # document id as used by the index's ``cid2docid`` mapping
    score: float  # relevance score reported by the retriever
    text: str  # stored text of the matched document
# Load the pre-built index
# index_dir = "output/csc_bm25_index"
# retriever = CSCBM25Retriever(index_dir=index_dir)
retriever = BM25Retriever(index_dir="output/bm25_index")
# retriever = BM25Retriever(index_dir="output/bm25_index")
# Define the search function
def search(query: str, topk: int = 10) -> List[Hit]:
    """
    Search the BM25 index and return the best-matching documents.

    Args:
        query: Input query string. Empty or whitespace-only queries return
            an empty list without calling the retriever.
        topk: Maximum number of results to request from the retriever
            (defaults to 10, the value this demo has always used).

    Returns:
        Up to ``topk`` hits, each carrying the document id, its score and
        its stored text, sorted by descending score.
    """
    # An empty textbox yields "" (or possibly None); there is nothing to
    # search for, so don't hand it to the retriever.
    if not query or not query.strip():
        return []

    results = retriever.retrieve(query, topk=topk)
    index = retriever.index  # read once instead of once per hit

    # Sort explicitly so the "ranked by relevance" promise holds regardless of
    # the order of the mapping `retrieve` returns. Python's sort is stable
    # (also with reverse=True), so tied scores keep the retriever's order and
    # an already-ranked mapping is left unchanged.
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)

    return [
        {
            "cid": cid,
            # float() normalizes numeric scalar types (e.g. numpy floats, in
            # case the retriever produces them) into plain Python floats so
            # the JSON output can serialize them; no-op for Python floats.
            "score": float(score),
            "text": index.doc_texts[index.cid2docid[cid]],
        }
        for cid, score in ranked
    ]
# Define the Gradio interface
_query_input = gr.Textbox(label="Query")  # single free-text query field
_results_output = gr.JSON(label="Search Results")  # search results shown as JSON

# Wire `search` into a one-input / one-output web UI.
demo = gr.Interface(
    fn=search,
    inputs=_query_input,
    outputs=_results_output,
    title="BM25 Search Engine Demo",
    description="Demo of a BM25-based search engine using a sparse index on the SciQ dataset.",
)
# Launch the app
def main() -> None:
    """Start the Gradio server that serves the search demo."""
    demo.launch()


if __name__ == "__main__":
    main()