from __future__ import annotations from dataclasses import dataclass import pickle import os from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar from nlp4web_codebase.ir.data_loaders.dm import Document from collections import Counter import tqdm import re import nltk nltk.download("stopwords", quiet=True) from nltk.corpus import stopwords as nltk_stopwords LANGUAGE = "english" word_splitter = re.compile(r"(?u)\b\w\w+\b").findall stopwords = set(nltk_stopwords.words(LANGUAGE)) def word_splitting(text: str) -> List[str]: return word_splitter(text.lower()) def lemmatization(words: List[str]) -> List[str]: return words # We ignore lemmatization here for simplicity def simple_tokenize(text: str) -> List[str]: words = word_splitting(text) tokenized = list(filter(lambda w: w not in stopwords, words)) tokenized = lemmatization(tokenized) return tokenized T = TypeVar("T", bound="InvertedIndex") @dataclass class PostingList: term: str # The term docid_postings: List[int] # docid_postings[i] means the docid (int) of the i-th associated posting tweight_postings: List[float] # tweight_postings[i] means the term weight (float) of the i-th associated posting @dataclass class InvertedIndex: posting_lists: List[PostingList] # docid -> posting_list vocab: Dict[str, int] cid2docid: Dict[str, int] # collection_id -> docid collection_ids: List[str] # docid -> collection_id doc_texts: Optional[List[str]] = None # docid -> document text def save(self, output_dir: str) -> None: os.makedirs(output_dir, exist_ok=True) with open(os.path.join(output_dir, "index.pkl"), "wb") as f: pickle.dump(self, f) @classmethod def from_saved(cls: Type[T], saved_dir: str) -> T: index = cls( posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None ) with open(os.path.join(saved_dir, "index.pkl"), "rb") as f: index = pickle.load(f) return index # The output of the counting function: @dataclass class Counting: posting_lists: List[PostingList] vocab: Dict[str, int] 
cid2docid: Dict[str, int] collection_ids: List[str] dfs: List[int] # tid -> df dls: List[int] # docid -> doc length avgdl: float nterms: int doc_texts: Optional[List[str]] = None def run_counting( documents: Iterable[Document], tokenize_fn: Callable[[str], List[str]] = simple_tokenize, store_raw: bool = True, # store the document text in doc_texts ndocs: Optional[int] = None, show_progress_bar: bool = True, ) -> Counting: """Counting TFs, DFs, doc_lengths, etc.""" posting_lists: List[PostingList] = [] vocab: Dict[str, int] = {} cid2docid: Dict[str, int] = {} collection_ids: List[str] = [] dfs: List[int] = [] # tid -> df dls: List[int] = [] # docid -> doc length nterms: int = 0 doc_texts: Optional[List[str]] = [] for doc in tqdm.tqdm( documents, desc="Counting", total=ndocs, disable=not show_progress_bar, ): if doc.collection_id in cid2docid: continue collection_ids.append(doc.collection_id) docid = cid2docid.setdefault(doc.collection_id, len(cid2docid)) toks = tokenize_fn(doc.text) tok2tf = Counter(toks) dls.append(sum(tok2tf.values())) for tok, tf in tok2tf.items(): nterms += tf tid = vocab.get(tok, None) if tid is None: posting_lists.append( PostingList(term=tok, docid_postings=[], tweight_postings=[]) ) tid = vocab.setdefault(tok, len(vocab)) posting_lists[tid].docid_postings.append(docid) posting_lists[tid].tweight_postings.append(tf) if tid < len(dfs): dfs[tid] += 1 else: dfs.append(0) if store_raw: doc_texts.append(doc.text) else: doc_texts = None return Counting( posting_lists=posting_lists, vocab=vocab, cid2docid=cid2docid, collection_ids=collection_ids, dfs=dfs, dls=dls, avgdl=sum(dls) / len(dls), nterms=nterms, doc_texts=doc_texts, ) from nlp4web_codebase.ir.data_loaders.sciq import load_sciq sciq = load_sciq() counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus)) from dataclasses import asdict, dataclass import math import os from typing import Iterable, List, Optional, Type import tqdm from nlp4web_codebase.ir.data_loaders.dm 
import Document @dataclass class BM25Index(InvertedIndex): @staticmethod def tokenize(text: str) -> List[str]: return simple_tokenize(text) @staticmethod def cache_term_weights( posting_lists: List[PostingList], total_docs: int, avgdl: float, dfs: List[int], dls: List[int], k1: float, b: float, ) -> None: """Compute term weights and caching""" N = total_docs for tid, posting_list in enumerate( tqdm.tqdm(posting_lists, desc="Regularizing TFs") ): idf = BM25Index.calc_idf(df=dfs[tid], N=N) for i in range(len(posting_list.docid_postings)): docid = posting_list.docid_postings[i] tf = posting_list.tweight_postings[i] dl = dls[docid] regularized_tf = BM25Index.calc_regularized_tf( tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b ) posting_list.tweight_postings[i] = regularized_tf * idf @staticmethod def calc_regularized_tf( tf: int, dl: float, avgdl: float, k1: float, b: float ) -> float: return tf / (tf + k1 * (1 - b + b * dl / avgdl)) @staticmethod def calc_idf(df: int, N: int): return math.log(1 + (N - df + 0.5) / (df + 0.5)) @classmethod def build_from_documents( cls: Type[BM25Index], documents: Iterable[Document], store_raw: bool = True, output_dir: Optional[str] = None, ndocs: Optional[int] = None, show_progress_bar: bool = True, k1: float = 0.9, b: float = 0.4, ) -> BM25Index: # Counting TFs, DFs, doc_lengths, etc.: counting = run_counting( documents=documents, tokenize_fn=BM25Index.tokenize, store_raw=store_raw, ndocs=ndocs, show_progress_bar=show_progress_bar, ) # Compute term weights and caching: posting_lists = counting.posting_lists total_docs = len(counting.cid2docid) BM25Index.cache_term_weights( posting_lists=posting_lists, total_docs=total_docs, avgdl=counting.avgdl, dfs=counting.dfs, dls=counting.dls, k1=k1, b=b, ) # Assembly and save: index = BM25Index( posting_lists=posting_lists, vocab=counting.vocab, cid2docid=counting.cid2docid, collection_ids=counting.collection_ids, doc_texts=counting.doc_texts, ) return index bm25_index = BM25Index.build_from_documents( 
documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True, ) bm25_index.save("output/bm25_index") from nlp4web_codebase.ir.models import BaseRetriever from typing import Type from abc import abstractmethod class BaseInvertedIndexRetriever(BaseRetriever): @property @abstractmethod def index_class(self) -> Type[InvertedIndex]: pass def __init__(self, index_dir: str) -> None: self.index = self.index_class.from_saved(index_dir) def get_term_weights(self, query: str, cid: str) -> Dict[str, float]: toks = self.index.tokenize(query) target_docid = self.index.cid2docid[cid] term_weights = {} for tok in toks: if tok not in self.index.vocab: continue tid = self.index.vocab[tok] posting_list = self.index.posting_lists[tid] for docid, tweight in zip( posting_list.docid_postings, posting_list.tweight_postings ): if docid == target_docid: term_weights[tok] = tweight break return term_weights def score(self, query: str, cid: str) -> float: return sum(self.get_term_weights(query=query, cid=cid).values()) def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]: toks = self.index.tokenize(query) docid2score: Dict[int, float] = {} for tok in toks: if tok not in self.index.vocab: continue tid = self.index.vocab[tok] posting_list = self.index.posting_lists[tid] for docid, tweight in zip( posting_list.docid_postings, posting_list.tweight_postings ): docid2score.setdefault(docid, 0) docid2score[docid] += tweight docid2score = dict( sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk] ) return { self.index.collection_ids[docid]: score for docid, score in docid2score.items() } class BM25Retriever(BaseInvertedIndexRetriever): @property def index_class(self) -> Type[BM25Index]: return BM25Index from nlp4web_codebase.ir.data_loaders import Split import pytrec_eval import numpy as np def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float: metric = "map_cut_10" qrels = sciq.get_qrels_dict(split) evaluator = 
pytrec_eval.RelevanceEvaluator(sciq.get_qrels_dict(split), (metric,)) qps = evaluator.evaluate(rankings) return float(np.mean([qp[metric] for qp in qps.values()])) # Loading dataset: from nlp4web_codebase.ir.data_loaders.sciq import load_sciq sciq = load_sciq() counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus)) # Building BM25 index and save: bm25_index = BM25Index.build_from_documents( documents=iter(sciq.corpus), ndocs=12160, show_progress_bar=True ) bm25_index.save("output/bm25_index") plots_b: Dict[str, List[float]] = { "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], "Y": [] } plots_k1: Dict[str, List[float]] = { "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0], "Y": [] } ## YOUR_CODE_STARTS_HERE # Two steps should be involved: # Step 1. Fix k1 value to the default one 0.9, # go through all the candidate b values (0, 0.1, ..., 1.0), # and record in plots_b["Y"] the corresponding performances obtained via evaluate_map; # Step 2. Fix b to the best one in step 1. and do the same for k1. 
# Reuse the dataset already loaded above so that the corpus and the dev queries
# come from the same object (previously the dataset was loaded a second time).
sciq_dataset = sciq
dev_queries = sciq.get_split_queries(Split.dev)
dev_qrels = sciq.get_qrels_dict(Split.dev)


def evaluate_bm25(k1: float, b: float) -> float:
    """Build a BM25 index with the given ``k1``/``b`` and return its dev-split score.

    The index is saved to ``output/bm25_index_task1`` (overwritten on every
    call) and loaded back through ``BM25Retriever``; the top-10 results of
    every dev query are then scored with ``evaluate_map``.
    """
    bm25_index = BM25Index.build_from_documents(
        documents=iter(sciq_dataset.corpus),
        ndocs=len(sciq_dataset.corpus),
        show_progress_bar=True,
        k1=k1,
        b=b,
    )
    bm25_index.save("output/bm25_index_task1")

    # The retriever loads the index just saved; k1 and b are already baked
    # into its cached term weights.
    bm25_retriever = BM25Retriever(index_dir="output/bm25_index_task1")

    # query_id -> {collection_id: score}, the format expected by evaluate_map
    rankings = {
        query.query_id: bm25_retriever.retrieve(query.text, topk=10)
        for query in dev_queries
    }

    # Evaluate MAP@10 on the dev split
    return evaluate_map(rankings, split=Split.dev)


for b in plots_b["X"]:
    plots_b["Y"].append(evaluate_bm25(k1=0.9, b=b))

# Find the best value of `b`
best_b = plots_b["X"][np.argmax(plots_b["Y"])]

for k1 in plots_k1["X"]:
    plots_k1["Y"].append(evaluate_bm25(k1=k1, b=best_b))

# Find the best value of `k1`
best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]

# Hint (on using the pre-requisite code):
# - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
# - One can build bm25_index with `BM25Index.build_from_documents`;
# - One can use BM25Retriever to load the index and perform retrieval on the dev queries
# (dev queries can be obtained via sciq.get_split_queries(Split.dev))
## YOUR_CODE_ENDS_HERE

# Public import path; the class is the same one exposed by scipy.sparse._csc.
from scipy.sparse import csc_matrix


@dataclass
class CSCInvertedIndex:
    posting_lists_matrix: csc_matrix  # rows: docid, columns: tid; column tid is the posting list of term tid
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``, creating the directory if needed."""
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[CSCInvertedIndex], saved_dir: str) -> CSCInvertedIndex:
        """Load the index previously pickled to ``<saved_dir>/index.pkl``.

        The returned object is whatever was pickled; ``cls`` is not used to
        construct it.
        """
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load index files that this program (or a trusted party) wrote.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index


@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """BM25 index that stores its term weights in a sparse CSC matrix."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize``."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute term weights and caching

        Returns:
            A float32 CSC matrix of shape ``(total_docs, len(posting_lists))``
            whose entry ``[docid, tid]`` is the BM25 weight of term ``tid`` in
            document ``docid``. ``posting_lists`` itself is not modified.
        """
        ## YOUR_CODE_STARTS_HERE
        data = []  # Term weights, concatenated posting list by posting list
        indices = []  # Row index (docid) of each entry in `data`
        indptr = [0]  # indptr[tid]:indptr[tid + 1] delimits the entries of column tid

        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=N)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                data.append(regularized_tf * idf)
                indices.append(docid)
            # Close the column of this term.
            indptr.append(len(data))

        # Create a CSC matrix where rows are documents, columns are terms
        return csc_matrix(
            (data, indices, indptr),
            shape=(total_docs, len(posting_lists)),
            dtype=np.float32,
        )
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return the length-normalized, saturating TF component."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        """Return the IDF component for a term occurring in ``df`` of ``N`` documents."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Count statistics over ``documents`` and build a CSC-backed BM25 index.

        NOTE(review): ``output_dir`` is accepted but not used here; callers
        persist the result with ``save`` themselves.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=CSCBM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        posting_lists_matrix = CSCBM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly:
        index = CSCBM25Index(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index


csc_bm25_index = CSCBM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # progress-bar total only
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
csc_bm25_index.save("output/csc_bm25_index")


class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents from a CSC term-weight matrix."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return ``{query token: weight}`` for the tokens stored for document ``cid``.

        Raises:
            KeyError: if ``cid`` is not in the index.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue  # out-of-vocabulary token contributes nothing
            # Column `tid` holds the posting list of the term: `indices` are
            # the docids and `data` the matching weights.
            column = self.index.posting_lists_matrix.getcol(tid)
            for docid, tweight in zip(column.indices, column.data):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the matched term weights of ``query`` in document ``cid``."""
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` highest-scoring documents as ``{collection_id: score}``.

        The returned dict is ordered by descending score.
        """
        ## YOUR_CODE_STARTS_HERE
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            tid = self.index.vocab.get(tok)
            if tid is None:
                continue
            column = self.index.posting_lists_matrix.getcol(tid)
            for docid, tweight in zip(column.indices, column.data):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        top_hits = sorted(
            docid2score.items(), key=lambda pair: pair[1], reverse=True
        )[:topk]
        return {self.index.collection_ids[docid]: score for docid, score in top_hits}
        ## YOUR_CODE_ENDS_HERE


class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        return CSCBM25Index


import gradio as gr
from typing import TypedDict


class Hit(TypedDict):
    cid: str
    score: float
    text: str


demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
# Building BM25 index and save:
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),  # progress-bar total only
    show_progress_bar=True,
    k1=best_k1,
    b=best_b,
)
bm25_index.save("output/bm25_index_app")

# Loading index and use BM25 retriever to retrieve:
bm25_retriever = BM25Retriever(index_dir="output/bm25_index_app")


def search(query: str) -> List[Hit]:
    """Retrieve the top documents for ``query`` as a list of ``Hit`` records."""
    response = bm25_retriever.retrieve(query)
    index = bm25_retriever.index
    hits = []
    for cid, score in response.items():
        docid = index.cid2docid[cid]
        # Read the text stored in the index (built with store_raw=True): docids
        # are assigned by the index and need not equal positions in sciq.corpus,
        # and sciq.corpus holds Document objects rather than strings.
        hits.append(Hit(cid=cid, score=score, text=index.doc_texts[docid]))
    return hits


# Create the Gradio interface
demo = gr.Interface(
    fn=search,  # Function to call on submit
    inputs=gr.Textbox(label="Enter your query"),  # Input field with label
    outputs=gr.Textbox(label="RESULT HERE"),  # Output field to display result
    live=False,  # Disable real-time updates to require a button click
)
## YOUR_CODE_ENDS_HERE
demo.launch(share=True)