# BM25 indexing script for the SciQ dataset (exported from a Hugging Face Space).
from __future__ import annotations

# Standard library.
import math
import os
from dataclasses import asdict, dataclass
from typing import Iterable, List, Optional, Type

# Third-party.
import tqdm

# Project-local.
from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
from sciq_load import Counting, InvertedIndex, PostingList, run_counting, simple_tokenize

# Load the SciQ dataset once at module import time; reused by the build step below.
sciq = load_sciq()
class BM25Index(InvertedIndex):
    """Inverted index whose posting lists cache precomputed BM25 term weights.

    After :meth:`build_from_documents`, each posting's ``tweight_postings``
    entry already holds ``regularized_tf * idf``, so query-time scoring only
    needs to sum cached weights.
    """

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize *text* with the same tokenizer used for the corpus."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Replace raw TFs in *posting_lists* with BM25 weights, in place.

        Args:
            posting_lists: one posting list per term id; mutated in place.
            total_docs: number of documents in the collection (N).
            avgdl: average document length over the collection.
            dfs: document frequency per term id.
            dls: document length per doc id.
            k1: BM25 TF-saturation parameter.
            b: BM25 length-normalization parameter.
        """
        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            # IDF is constant per term, so compute it once per posting list.
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                # Overwrite the raw TF with the final cached BM25 weight.
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """BM25 TF term: tf / (tf + k1 * (1 - b + b * dl / avgdl))."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """Smoothed BM25 IDF: log(1 + (N - df + 0.5) / (df + 0.5))."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Build a BM25Index from *documents*.

        Counts TFs/DFs/doc lengths via ``run_counting``, converts the raw TFs
        into cached BM25 weights, and assembles the index.

        Args:
            documents: iterable of corpus documents.
            store_raw: whether to keep raw document texts in the index.
            output_dir: currently unused here — TODO confirm whether callers
                expect this method to save to it.
            ndocs: expected document count (for the progress bar).
            show_progress_bar: toggle tqdm progress output.
            k1: BM25 TF-saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            The assembled BM25Index.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and cache them in the posting lists:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly:
        index = cls(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index
# Build the BM25 index over the full SciQ corpus (12,160 documents) and save it.
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=12160,
    show_progress_bar=True,
)
# Make sure the target directory exists before writing the index.
os.makedirs("output", exist_ok=True)
bm25_index.save("output/bm25_index")