from __future__ import annotations from dataclasses import dataclass import pickle import os from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar from nlp4web_codebase.ir.data_loaders.dm import Document from collections import Counter import tqdm import re import nltk nltk.download("stopwords", quiet=True) from nltk.corpus import stopwords as nltk_stopwords from nlp4web_codebase.ir.data_loaders.sciq import load_sciq import gradio as gr from typing import TypedDict from dataclasses import asdict, dataclass import math import os from typing import Iterable, List, Optional, Type import tqdm from nlp4web_codebase.ir.data_loaders.dm import Document from nlp4web_codebase.ir.models import BaseRetriever from typing import Type from abc import abstractmethod from nlp4web_codebase.ir.data_loaders import Split import pytrec_eval import numpy as np from matplotlib import pyplot as plt from scipy.sparse._csc import csc_matrix # -*- coding: utf-8 -*- """Kopie von HW1 (more instructed).ipynb Automatically generated by Colab. Original file is located at https://colab.research.google.com/drive/1vFJ6AROcCYNkZRIxpyHs9T1sf-bdB_jW """ """## Pre-requisite code The code within this section will be used in the tasks. Please do not change these code lines. 
### SciQ loading and counting """ LANGUAGE = "english" word_splitter = re.compile(r"(?u)\b\w\w+\b").findall stopwords = set(nltk_stopwords.words(LANGUAGE)) def word_splitting(text: str) -> List[str]: return word_splitter(text.lower()) def lemmatization(words: List[str]) -> List[str]: return words # We ignore lemmatization here for simplicity def simple_tokenize(text: str) -> List[str]: words = word_splitting(text) tokenized = list(filter(lambda w: w not in stopwords, words)) tokenized = lemmatization(tokenized) return tokenized T = TypeVar("T", bound="InvertedIndex") @dataclass class PostingList: term: str # The term docid_postings: List[int] # docid_postings[i] means the docid (int) of the i-th associated posting tweight_postings: List[float] # tweight_postings[i] means the term weight (float) of the i-th associated posting @dataclass class InvertedIndex: posting_lists: List[PostingList] # docid -> posting_list vocab: Dict[str, int] cid2docid: Dict[str, int] # collection_id -> docid collection_ids: List[str] # docid -> collection_id doc_texts: Optional[List[str]] = None # docid -> document text def save(self, output_dir: str) -> None: os.makedirs(output_dir, exist_ok=True) with open(os.path.join(output_dir, "index.pkl"), "wb") as f: pickle.dump(self, f) @classmethod def from_saved(cls: Type[T], saved_dir: str) -> T: index = cls( posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None ) with open(os.path.join(saved_dir, "index.pkl"), "rb") as f: index = pickle.load(f) return index # The output of the counting function: @dataclass class Counting: posting_lists: List[PostingList] vocab: Dict[str, int] cid2docid: Dict[str, int] collection_ids: List[str] dfs: List[int] # tid -> df dls: List[int] # docid -> doc length avgdl: float nterms: int doc_texts: Optional[List[str]] = None def run_counting( documents: Iterable[Document], tokenize_fn: Callable[[str], List[str]] = simple_tokenize, store_raw: bool = True, # store the document text in doc_texts 
ndocs: Optional[int] = None, show_progress_bar: bool = True, ) -> Counting: """Counting TFs, DFs, doc_lengths, etc.""" posting_lists: List[PostingList] = [] vocab: Dict[str, int] = {} cid2docid: Dict[str, int] = {} collection_ids: List[str] = [] dfs: List[int] = [] # tid -> df dls: List[int] = [] # docid -> doc length nterms: int = 0 doc_texts: Optional[List[str]] = [] for doc in tqdm.tqdm( documents, desc="Counting", total=ndocs, disable=not show_progress_bar, ): if doc.collection_id in cid2docid: continue collection_ids.append(doc.collection_id) docid = cid2docid.setdefault(doc.collection_id, len(cid2docid)) toks = tokenize_fn(doc.text) tok2tf = Counter(toks) dls.append(sum(tok2tf.values())) for tok, tf in tok2tf.items(): nterms += tf tid = vocab.get(tok, None) if tid is None: posting_lists.append( PostingList(term=tok, docid_postings=[], tweight_postings=[]) ) tid = vocab.setdefault(tok, len(vocab)) posting_lists[tid].docid_postings.append(docid) posting_lists[tid].tweight_postings.append(tf) if tid < len(dfs): dfs[tid] += 1 else: dfs.append(0) if store_raw: doc_texts.append(doc.text) else: doc_texts = None return Counting( posting_lists=posting_lists, vocab=vocab, cid2docid=cid2docid, collection_ids=collection_ids, dfs=dfs, dls=dls, avgdl=sum(dls) / len(dls), nterms=nterms, doc_texts=doc_texts, ) sciq = load_sciq() counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus)) """### BM25 Index""" @dataclass class BM25Index(InvertedIndex): @staticmethod def tokenize(text: str) -> List[str]: return simple_tokenize(text) @staticmethod def cache_term_weights( posting_lists: List[PostingList], total_docs: int, avgdl: float, dfs: List[int], dls: List[int], k1: float, b: float, ) -> None: """Compute term weights and caching""" N = total_docs for tid, posting_list in enumerate( tqdm.tqdm(posting_lists, desc="Regularizing TFs") ): idf = BM25Index.calc_idf(df=dfs[tid], N=N) for i in range(len(posting_list.docid_postings)): docid = 
posting_list.docid_postings[i] tf = posting_list.tweight_postings[i] dl = dls[docid] regularized_tf = BM25Index.calc_regularized_tf( tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b ) posting_list.tweight_postings[i] = regularized_tf * idf @staticmethod def calc_regularized_tf( tf: int, dl: float, avgdl: float, k1: float, b: float ) -> float: return tf / (tf + k1 * (1 - b + b * dl / avgdl)) @staticmethod def calc_idf(df: int, N: int): return math.log(1 + (N - df + 0.5) / (df + 0.5)) @classmethod def build_from_documents( cls: Type[BM25Index], documents: Iterable[Document], store_raw: bool = True, output_dir: Optional[str] = None, ndocs: Optional[int] = None, show_progress_bar: bool = True, k1: float = 0.9, b: float = 0.4, ) -> BM25Index: # Counting TFs, DFs, doc_lengths, etc.: counting = run_counting( documents=documents, tokenize_fn=BM25Index.tokenize, store_raw=store_raw, ndocs=ndocs, show_progress_bar=show_progress_bar, ) # Compute term weights and caching: posting_lists = counting.posting_lists total_docs = len(counting.cid2docid) BM25Index.cache_term_weights( posting_lists=posting_lists, total_docs=total_docs, avgdl=counting.avgdl, dfs=counting.dfs, dls=counting.dls, k1=k1, b=b, ) # Assembly and save: index = BM25Index( posting_lists=posting_lists, vocab=counting.vocab, cid2docid=counting.cid2docid, collection_ids=counting.collection_ids, doc_texts=counting.doc_texts, ) return index bm25_index = BM25Index.build_from_documents( documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True, ) bm25_index.save("output/bm25_index") """### BM25 Retriever""" class BaseInvertedIndexRetriever(BaseRetriever): @property @abstractmethod def index_class(self) -> Type[InvertedIndex]: pass def __init__(self, index_dir: str) -> None: self.index = self.index_class.from_saved(index_dir) def get_term_weights(self, query: str, cid: str) -> Dict[str, float]: toks = self.index.tokenize(query) target_docid = self.index.cid2docid[cid] term_weights = {} for tok in toks: if tok not in 
self.index.vocab: continue tid = self.index.vocab[tok] posting_list = self.index.posting_lists[tid] for docid, tweight in zip( posting_list.docid_postings, posting_list.tweight_postings ): if docid == target_docid: term_weights[tok] = tweight break return term_weights def score(self, query: str, cid: str) -> float: return sum(self.get_term_weights(query=query, cid=cid).values()) def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]: toks = self.index.tokenize(query) docid2score: Dict[int, float] = {} for tok in toks: if tok not in self.index.vocab: continue tid = self.index.vocab[tok] posting_list = self.index.posting_lists[tid] for docid, tweight in zip( posting_list.docid_postings, posting_list.tweight_postings ): docid2score.setdefault(docid, 0) docid2score[docid] += tweight docid2score = dict( sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk] ) return { self.index.collection_ids[docid]: score for docid, score in docid2score.items() } class BM25Retriever(BaseInvertedIndexRetriever): @property def index_class(self) -> Type[BM25Index]: return BM25Index bm25_retriever = BM25Retriever(index_dir="output/bm25_index") bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?") """# TASK1: tune b and k1 (4 points) Tune b and k1 on the **dev** split of SciQ using the metric MAP@10. The evaluation function (`evalaute_map`) is provided. Record the values in `plots_k1` and `plots_b`. Do it in a greedy manner: as the influence from b is larger, please first tune b (with k1 fixed to the default value 0.9) and use the best value of b to further tune k1. 
$${\displaystyle {\text{score}}(D,Q)=\sum _{i=1}^{n}{\text{IDF}}(q_{i})\cdot {\frac {f(q_{i},D)\cdot (k_{1}+1)}{f(q_{i},D)+k_{1}\cdot \left(1-b+b\cdot {\frac {|D|}{\text{avgdl}}}\right)}}}$$ """ def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float: metric = "map_cut_10" qrels = sciq.get_qrels_dict(split) evaluator = pytrec_eval.RelevanceEvaluator(sciq.get_qrels_dict(split), (metric,)) qps = evaluator.evaluate(rankings) return float(np.mean([qp[metric] for qp in qps.values()])) """Example of using the pre-requisite code:""" # Loading dataset: sciq = load_sciq() counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus)) # Building BM25 index and save: bm25_index = BM25Index.build_from_documents( documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True ) bm25_index.save("output/bm25_index") # Loading index and use BM25 retriever to retrieve: bm25_retriever = BM25Retriever(index_dir="output/bm25_index") print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")) # the ranking plots_b: Dict[str, List[float]] = { "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], "Y": [] } plots_k1: Dict[str, List[float]] = { "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], "Y": [] } ## YOUR_CODE_STARTS_HERE # Two steps should be involved: # Step 1. Fix k1 value to the default one 0.9, # go through all the candidate b values (0, 0.1, ..., 1.0), # and record in plots_b["Y"] the corresponding performances obtained via evaluate_map; # Step 2. Fix b to the best one in step 1. and do the same for k1. 
# Hint (on using the pre-requisite code):
# - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
# - One can build bm25_index with `BM25Index.build_from_documents`;
# - One can use BM25Retriever to load the index and perform retrieval on the dev queries
#   (dev queries can be obtained via sciq.get_split_queries(Split.dev))

# `sciq` is already loaded by the pre-requisite code, so it is neither reloaded nor
# re-counted here; the dev queries do not depend on (k1, b) and are collected once.
dev_queries: Dict[str, str] = {
    query.query_id: query.text for query in sciq.get_split_queries(Split.dev)
}


def _dev_map(k1: float, b: float) -> float:
    """Build and save a BM25 index with ``(k1, b)``, retrieve all dev queries, return MAP@10."""
    index = BM25Index.build_from_documents(
        documents=iter(sciq.corpus),
        ndocs=len(sciq.corpus),
        k1=k1,
        b=b,
        show_progress_bar=True,
    )
    index.save("output/bm25_index")
    retriever = BM25Retriever(index_dir="output/bm25_index")
    dev_rankings: Dict[str, Dict[str, float]] = {
        qid: retriever.retrieve(text) for qid, text in dev_queries.items()
    }
    return evaluate_map(dev_rankings, split=Split.dev)


# Step 1: tune b with k1 fixed to its default value.
fixed_k1 = 0.9
for b in plots_b["X"]:
    print(b)
    score = _dev_map(k1=fixed_k1, b=b)
    plots_b["Y"].append(score)
    print(f"appended {score} to the plots_b list")
tuned_b = plots_b["X"][np.argmax(plots_b["Y"])]
print(f"The best value for b is: {tuned_b}")

# Step 2: tune k1 with b fixed to the best value from step 1.
for k1 in plots_k1["X"]:
    print(k1)
    score = _dev_map(k1=k1, b=tuned_b)
    plots_k1["Y"].append(score)
    print(f"appended {score} to the plots_k1 list")
tuned_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
print(f"The best value for k1 is: {tuned_k1}")
## YOU_CODE_ENDS_HERE

## TEST_CASES (should be close to 0.8135637188208616 and 0.7512916099773244)
print(plots_k1["Y"][9])
print(plots_b["Y"][1])

## RESULT_CHECKING_POINT
print(plots_k1)
print(plots_b)

plt.plot(plots_b["X"], plots_b["Y"], label="b")
plt.plot(plots_k1["X"], plots_k1["Y"], label="k1")
plt.ylabel("MAP")
plt.legend()
plt.grid()
plt.show()

"""Let's check the effectiveness gain on test after this tuning on dev"""

default_map = 0.7849
best_b = plots_b["X"][np.argmax(plots_b["Y"])]
best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True, k1=best_k1, b=best_b
)
bm25_index.save("output/bm25_index")
bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
rankings = {}
for query in sciq.get_split_queries(Split.test):  # note this is now on test
    ranking = bm25_retriever.retrieve(query=query.text)
    rankings[query.query_id] = ranking
optimized_map = evaluate_map(rankings, split=Split.test)  # note this is now on test
print(default_map, optimized_map)

# Raw string: "\b" in \begin would otherwise be a backspace, "\e" an invalid escape,
# and the LaTeX row separator "\\" would collapse to a single backslash.
r"""# TASK2: CSC matrix and `CSCBM25Index` (12 points)

Recall that we use Python lists to implement posting lists, mapping term IDs to the documents in which they appear. This is inefficient due to its naive design. Actually [Compressed Sparse Column matrix](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csc_matrix.html) is very suitable for storing the posting lists and can boost the efficiency.

## TASK2.1: learn about `scipy.sparse.csc_matrix` (2 point)

Convert the matrix \begin{bmatrix}
0 & 1 & 0 & 3 \\
10 & 2 & 1 & 0 \\
0 & 0 & 0 & 9
\end{bmatrix} to a `csc_matrix` by specifying `data`, `indices`, `indptr` and `shape`.
"""

input_matrix = [[0, 1, 0, 3], [10, 2, 1, 0], [0, 0, 0, 9]]
data = None
indices = None
indptr = None
shape = None
## YOUR_CODE_STARTS_HERE
# Please assign the values to data, indices, indptr and shape
# One can just do it in a hard-coded manner
# Column j's non-zeros are data[indptr[j]:indptr[j + 1]]; their row numbers are
# indices[indptr[j]:indptr[j + 1]].
data = [10, 1, 2, 1, 3, 9]
indices = [1, 0, 1, 1, 0, 2]
indptr = [0, 1, 3, 4, 6]
shape = (3, 4)
## YOUR_CODE_ENDS_HERE
output_matrix = csc_matrix((data, indices, indptr), shape=shape)

## TEST_CASES (should be 3 and 11)
print((output_matrix.indices + output_matrix.data).tolist()[2])
print((output_matrix.indices + output_matrix.data).tolist()[-1])

## RESULT_CHECKING_POINT
print((output_matrix.indices + output_matrix.data).tolist())

"""## TASK2.2: implement `CSCBM25Index` (4 points)

Implement `CSCBM25Index` by completing the missing code. Note that `CSCInvertedIndex` is similar to `InvertedIndex` which we talked about during the class. The main difference is posting lists are represented by a CSC sparse matrix.
"""

TCSC = TypeVar("TCSC", bound="CSCInvertedIndex")


@dataclass
class CSCInvertedIndex:
    """Inverted index whose posting lists are the columns of one CSC matrix (row = docid, column = tid)."""

    posting_lists_matrix: csc_matrix  # column tid holds the weights of term tid, row = docid
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle the whole index to ``<output_dir>/index.pkl``, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[TCSC], saved_dir: str) -> TCSC:
        """Load the index previously written by :meth:`save` into ``saved_dir``.

        NOTE: this unpickles the file, which can execute arbitrary code -- only load
        index directories you produced yourself.
        """
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)


@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """CSC-based index whose stored values are cached BM25 term scores."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute term weights and caching.

        Returns a float32 ``csc_matrix`` of shape ``(total_docs, len(posting_lists))``
        whose column ``tid`` holds ``calc_regularized_tf(...) * calc_idf(...)`` for every
        document in term ``tid``'s posting list (row = docid).
        """
        ## YOUR_CODE_STARTS_HERE
        N = total_docs
        doc_lengths = np.asarray(dls, dtype=np.float64)
        data_chunks: List[np.ndarray] = []
        index_chunks: List[np.ndarray] = []
        indptr: List[int] = [0]
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=N)
            docids = np.asarray(posting_list.docid_postings, dtype=np.int32)
            tfs = np.asarray(posting_list.tweight_postings, dtype=np.float64)
            # calc_regularized_tf is plain arithmetic, so it applies element-wise to arrays.
            regularized_tfs = CSCBM25Index.calc_regularized_tf(
                tf=tfs, dl=doc_lengths[docids], avgdl=avgdl, k1=k1, b=b
            )
            data_chunks.append((regularized_tfs * idf).astype(np.float32))
            index_chunks.append(docids)
            indptr.append(indptr[-1] + len(docids))
        data = (
            np.concatenate(data_chunks) if data_chunks else np.zeros(0, dtype=np.float32)
        )
        indices = (
            np.concatenate(index_chunks) if index_chunks else np.zeros(0, dtype=np.int32)
        )
        # Explicit shape: an inferred one would stop at the largest docid that has a
        # term, making trailing term-less documents un-addressable (IndexError later).
        return csc_matrix(
            (data, indices, np.asarray(indptr)),
            shape=(total_docs, len(posting_lists)),
        )
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """BM25 tf saturation without the ``(k1 + 1)`` numerator factor (rank-preserving for fixed k1).

        Also works element-wise when ``tf`` and ``dl`` are NumPy arrays.
        """
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """BM25 idf: ``log(1 + (N - df + 0.5) / (df + 0.5))``."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count ``documents``, build the BM25 weight matrix and return the index.

        If ``output_dir`` is given, the index is also saved there.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=CSCBM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        posting_lists_matrix = CSCBM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = CSCBM25Index(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index
csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True, k1=best_k1, b=best_b
)
csc_bm25_index.save("output/csc_bm25_index")

## TEST_CASES (should be 7 and 95)
print(len(str(os.path.getsize("output/csc_bm25_index/index.pkl"))))
print(os.path.getsize("output/csc_bm25_index/index.pkl") // int(1e5))

## RESULT_CHECKING_POINT
print(os.path.getsize("output/csc_bm25_index/index.pkl"))

"""We can compare the size of the CSC-based index with the Python-list-based index:"""

print(os.path.getsize("output/bm25_index/index.pkl"))

"""## TASK2.3: implement `CSCInvertedIndexRetriever` (6 points)

Implement `CSCInvertedIndexRetriever` by completing the missing code.
"""


class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever over a CSC index: each term's postings are read straight from the CSC arrays."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def _term_postings(self, tid: int) -> tuple[np.ndarray, np.ndarray]:
        """Return ``(docids, weights)`` of column ``tid`` as slices of the CSC arrays.

        This touches only the term's non-zeros instead of densifying an N-row column.
        """
        matrix = self.index.posting_lists_matrix
        start, end = matrix.indptr[tid], matrix.indptr[tid + 1]
        return matrix.indices[start:end], matrix.data[start:end]

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return ``{token: weight}`` for the query tokens with a positive weight in document ``cid``."""
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            docids, weights = self._term_postings(self.index.vocab[tok])
            positions = np.flatnonzero(docids == target_docid)
            if positions.size and weights[positions[0]] > 0:
                term_weights[tok] = weights[positions[0]]
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` best documents as ``{collection_id: score}``, highest first.

        A token repeated in the query contributes its weight once per occurrence.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            docids, weights = self._term_postings(self.index.vocab[tok])
            # Postings are stored in increasing docid order by cache_term_weights, so the
            # insertion (and thus tie-breaking) order matches a dense column scan.
            for docid, tweight in zip(docids.tolist(), weights):
                if tweight > 0:
                    docid2score.setdefault(docid, 0)
                    docid2score[docid] += tweight
        top = sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        return {self.index.collection_ids[docid]: score for docid, score in top}
        ## YOUR_CODE_ENDS_HERE


class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index


## TEST_CASES (should be close to
# {'theory': 3.1838157176971436, 'evolution': 3.488086223602295, 'natural': 2.629807710647583, 'selection': 3.552377462387085}
# {'train-11632': 16.241527557373047, 'train-10931': 13.352127075195312, 'train-2006': 12.854086875915527, 'train-7040': 12.690572738647461, 'train-1719': 11.01913833618164, 'train-9875': 10.886155128479004, 'train-1971': 10.796306610107422, 'train-9882': 10.535819053649902, 'train-2018': 10.481085777282715, 'test-586': 10.478515625}
# )
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index")
query = "Who proposed the theory of evolution by natural selection?"
print(csc_bm25_retriever.get_term_weights(query=query, cid="train-2006"))
print(csc_bm25_retriever.retrieve(query))

## RESULT_CHECKING_POINT
csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index")
query = "What are the differences between immunodeficiency and autoimmune diseases?"
print(csc_bm25_retriever.get_term_weights(query=query, cid="train-1691"))
print(csc_bm25_retriever.retrieve("What are the differences between immunodeficiency and autoimmune diseases?"))

"""# TASK3: a search-engine demo based on Huggingface space (4 points)

## TASK3.1: create the gradio app (2 point)

Create a gradio app to demo the BM25 search engine index on SciQ. The app should have a single input variable for the query (of type `str`) and a single output variable for the returned ranking (of type `List[Hit]` in the code below). Please use the BM25 system with default k1 and b values.

Hint: it should use a "search" function of signature:

```python
def search(query: str) -> List[Hit]:
    ...
```
"""


class Hit(TypedDict):
    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# Build the BM25 index with the default k1/b of `build_from_documents` ONCE at start-up;
# rebuilding it inside `search` re-tokenized the whole corpus on every single query.
search_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    show_progress_bar=True,
)
search_index.save("output/bm25_index")
search_retriever = BM25Retriever(index_dir="output/bm25_index")


def search(query: str) -> List[Hit]:
    """Return the BM25 ranking for ``query`` as hits carrying cid, score and document text."""
    ranking = search_retriever.retrieve(query)
    return [
        {
            "cid": cid,
            "score": score,
            "text": search_index.doc_texts[search_index.cid2docid[cid]],
        }
        for cid, score in ranking.items()
    ]


demo = gr.Interface(
    fn=search,
    inputs="text",
    outputs=gr.JSON(),  # the ranking is a list of dicts; a Textbox would only show its repr
    title="Search-engine demo",
    description="Please enter your search query",
)
## YOUR_CODE_ENDS_HERE
demo.launch()