VagoX1 commited on
Commit
d4b2a30
1 Parent(s): 810634e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -583,24 +583,28 @@ return_type = List[Hit]
583
 
584
  ## YOUR_CODE_STARTS_HERE
585
 
586
- def search(query) -> List[Hit]:
587
- return_type: List[Hit] = []
588
- bm_25_retriever = BM25Retriever(index_dir="output/bm25_index")
589
- ranking = bm_25_retriever.retrieve(query)
590
- for rank in ranking:
591
- hit = {
592
- "cid": rank,
593
- "score": ranking[rank],
594
- "text": bm_25_retriever.index.doc_texts[bm_25_retriever.index.cid2docid[rank]]
595
- }
596
- return_type.append(hit)
597
-
598
- return return_type
 
 
599
 
600
  demo = gr.Interface(
601
  fn=search,
602
- inputs=["text"],
603
- outputs=gr.Textbox()
 
 
604
  )
605
 
606
  ## YOUR_CODE_ENDS_HERE
 
583
 
584
  ## YOUR_CODE_STARTS_HERE
585
 
586
+ def search(query: str) -> List[Hit]:
587
+ bm25_index = BM25Index.build_from_documents(
588
+ documents=iter(sciq.corpus),
589
+ ndocs=12160,
590
+ show_progress_bar=True
591
+ )
592
+ bm25_index.save("output/bm25_index")
593
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
594
+ ranking = bm25_retriever.retrieve(query=query)
595
+ hits = []
596
+ for cid, score in ranking.items():
597
+ doc = next((doc for doc in sciq.corpus if doc.collection_id == cid), None)
598
+ if doc:
599
+ hits.append({"cid": cid, "score": score, "text": doc.text})
600
+ return hits
601
 
602
  demo = gr.Interface(
603
  fn=search,
604
+ inputs=gr.Textbox(lines=2, placeholder="Enter your query here..."),
605
+ outputs=gr.JSON(label="Search Results"),
606
+ title="SciQ Search Engine",
607
+ description="Enter a query to search the SciQ dataset using BM25.",
608
  )
609
 
610
  ## YOUR_CODE_ENDS_HERE