from dataclasses import dataclass
from collections import Counter
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
import os
import pickle
import re

import nltk
import tqdm

nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords

from nlp4web_codebase.ir.data_loaders.dm import Document
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq

sciq = load_sciq()

LANGUAGE = "english"
# Keep only tokens made of >= 2 word characters (letters, digits, underscore).
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    """Lowercase the text and split it into word tokens."""
    return word_splitter(text.lower())


def lemmatization(words: List[str]) -> List[str]:
    """Placeholder lemmatizer: currently a no-op that returns the words unchanged."""
    return words


def simple_tokenize(text: str) -> List[str]:
    """Tokenize the text: split into words, remove stopwords, then lemmatize."""
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized
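
# A quick, illustrative sanity check of the tokenizer (the sample sentence is
# arbitrary and not part of the SciQ data): everything is lowercased and
# stopwords such as "the", "of", and "on" are dropped.
# >>> simple_tokenize("The force of gravity acts on the object")
# ['force', 'gravity', 'acts', 'object']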

T = TypeVar("T", bound="InvertedIndex")


@dataclass
class PostingList:
    term: str  # the term itself
    docid_postings: List[int]  # ids of the documents that contain the term
    tweight_postings: List[float]  # term weights aligned with docid_postings (here: term frequencies)


@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # posting_lists[tid] belongs to the term with id tid
    vocab: Dict[str, int]  # term -> term id (tid)
    cid2docid: Dict[str, int]  # collection id -> document id (docid)
    collection_ids: List[str]  # docid -> collection id
    doc_texts: Optional[List[str]] = None  # docid -> raw document text (if stored)

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index
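
# Illustrative persistence round trip (assumes `index` is an already-built
# InvertedIndex; the output directory name is arbitrary):
# >>> index.save("output/sciq_index")  # writes output/sciq_index/index.pkl
# >>> index = InvertedIndex.from_saved("output/sciq_index")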

@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # dfs[tid] is the document frequency of the term with id tid
    dls: List[int]  # dls[docid] is the length (token count) of the document with id docid
    avgdl: float  # average document length over the collection
    nterms: int  # total number of tokens in the collection
    doc_texts: Optional[List[str]] = None


def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc. over the document collection."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []
    dls: List[int] = []
    nterms: int = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue  # skip duplicated documents
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                # New term: register it in the vocabulary with an empty posting list.
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            dfs[tid] += 1  # the term occurs in one more document
        if store_raw:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )

counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
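
# Illustrative: print a few summary statistics from the counting pass (the
# exact numbers depend on the SciQ corpus).
print(f"#documents: {len(counting.collection_ids)}")
print(f"vocabulary size: {len(counting.vocab)}")
print(f"average document length: {counting.avgdl:.2f}")
print(f"total number of tokens: {counting.nterms}")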