# BM25implementation / sciq_load.py
# pratyushpaliwal's picture
# Update sciq_load.py
# b26218c verified
from dataclasses import dataclass
import pickle
import os
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
from nlp4web_codebase.ir.data_loaders.dm import Document
from collections import Counter
import tqdm
import re
import nltk
nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
# Load the SciQ dataset at import time; module-level code further down reads `sciq.corpus`.
sciq = load_sciq()
# Language name handed to NLTK when fetching the stopword list.
LANGUAGE = "english"
# Returns every run of two or more word characters in a string, so
# single-character tokens are dropped.
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
# NLTK stopwords for LANGUAGE, held in a set for membership tests during tokenization.
stopwords = set(nltk_stopwords.words(LANGUAGE))
def word_splitting(text: str) -> List[str]:
    """Lowercase ``text`` and return the words that ``word_splitter`` finds in it."""
    lowered = text.lower()
    return word_splitter(lowered)
def lemmatization(words: List[str]) -> List[str]:
    """Return ``words`` untouched.

    Lemmatization is skipped here for simplicity, so the very same list
    object is handed back to the caller.
    """
    return words
def simple_tokenize(text: str) -> List[str]:
    """Tokenize ``text``: split into lowercase words, drop stopwords, then lemmatize."""
    kept = [word for word in word_splitting(text) if word not in stopwords]
    return lemmatization(kept)
# Type variable bound to InvertedIndex; used to annotate `InvertedIndex.from_saved`
# so that it is typed as returning the class it is called on.
T = TypeVar("T", bound="InvertedIndex")
@dataclass
class PostingList:
    """Postings of a single term, kept as two parallel lists.

    Position ``i`` of ``docid_postings`` and position ``i`` of
    ``tweight_postings`` together describe the i-th posting of ``term``.
    """

    # The term these postings belong to.
    term: str
    # Document id (int) of each associated posting.
    docid_postings: List[int]
    # Term weight (float) of each associated posting.
    tweight_postings: List[float]
@dataclass
class InvertedIndex:
    """Inverted index over a document collection, persisted with pickle.

    The whole object is written to / read from a single ``index.pkl`` file
    inside a directory.
    """

    # tid -> posting list (the counting code in this file indexes posting lists by term id)
    posting_lists: List[PostingList]
    # term -> tid
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``.

        Args:
            output_dir: Target directory; created (with parents) if missing.
        """
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load the index previously written by :meth:`save`.

        Args:
            saved_dir: Directory containing ``index.pkl``.

        Returns:
            The unpickled object, exactly as it was saved.

        Raises:
            FileNotFoundError: If ``<saved_dir>/index.pkl`` does not exist.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
        # file. Only load indexes from directories you trust.
        # No placeholder instance is built first: the unpickled object is the
        # result, so constructing an empty index beforehand was dead work.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
@dataclass
class Counting:
    """The output of the counting function.

    Bundles the raw statistics gathered in one pass over the documents.
    """

    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    # tid -> df
    dfs: List[int]
    # docid -> doc length
    dls: List[int]
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None
def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc. in one pass over ``documents``.

    Documents whose ``collection_id`` was already seen are skipped, so each
    collection id receives exactly one docid (assigned in first-seen order).

    Args:
        documents: Objects exposing ``collection_id`` and ``text``.
        tokenize_fn: Maps a document text to its list of tokens.
        store_raw: When true, keep every document text in ``doc_texts``;
            otherwise ``doc_texts`` is ``None``.
        ndocs: Only used as the progress-bar total.
        show_progress_bar: Set to false to silence the progress bar.

    Returns:
        A ``Counting`` holding the posting lists (term weight = raw term
        frequency), the vocabulary, the id mappings, per-term document
        frequencies, per-document lengths, the average document length
        (``0.0`` for an empty corpus) and the total number of tokens.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    # Decide once whether raw texts are kept, instead of resetting inside the loop.
    doc_texts: Optional[List[str]] = [] if store_raw else None

    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue  # duplicate collection id: keep only the first occurrence

        docid = len(cid2docid)
        cid2docid[doc.collection_id] = docid
        collection_ids.append(doc.collection_id)

        tok2tf = Counter(tokenize_fn(doc.text))
        doc_len = sum(tok2tf.values())
        dls.append(doc_len)
        nterms += doc_len

        for tok, tf in tok2tf.items():
            tid = vocab.get(tok)
            if tid is None:
                # First time this term is seen: register it everywhere at once
                # so posting_lists, dfs and vocab stay aligned by tid.
                tid = len(vocab)
                vocab[tok] = tid
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # tok2tf has each term once per document, so this counts documents.
            # The first document containing the term must count too.
            dfs[tid] += 1

        if doc_texts is not None:
            doc_texts.append(doc.text)

    # Guard against an empty corpus instead of dividing by zero.
    avgdl = sum(dls) / len(dls) if dls else 0.0

    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=avgdl,
        nterms=nterms,
        doc_texts=doc_texts,
    )
# Run the counting pass over the SciQ corpus at import time; `ndocs` is only
# used as the progress-bar total.
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))