File size: 4,472 Bytes
13880a3
 
 
 
 
 
 
 
 
 
 
b26218c
 
13880a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b26218c
13880a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from dataclasses import dataclass
import pickle
import os
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
from nlp4web_codebase.ir.data_loaders.dm import Document
from collections import Counter
import tqdm
import re
import nltk
# Fetch the NLTK "stopwords" resource at import time, before the corpus reader
# below is imported and used; quiet=True is passed to keep the downloader silent.
nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords
from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
# Load the SciQ dataset once at import time; its `corpus` attribute is what
# gets counted at the bottom of this module.
sciq = load_sciq()

# Language whose NLTK stopword list is used for filtering.
LANGUAGE = "english"
# Callable returning all non-overlapping matches of "two or more word
# characters between word boundaries"; single-character tokens are dropped.
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
# Stopwords as a set so membership tests during tokenization are O(1).
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    """Lowercase ``text`` and return the word tokens found by ``word_splitter``."""
    lowered = text.lower()
    return word_splitter(lowered)

def lemmatization(words: List[str]) -> List[str]:
    """Placeholder lemmatizer: hand back ``words`` itself, unmodified."""
    # Lemmatization is deliberately skipped here for simplicity.
    lemmas = words
    return lemmas

def simple_tokenize(text: str) -> List[str]:
    """Tokenize ``text``: split into lowercase words, drop stopwords, lemmatize.

    The steps run in that order, so stopword filtering sees the lowercased
    words produced by ``word_splitting``.
    """
    kept = [word for word in word_splitting(text) if word not in stopwords]
    return lemmatization(kept)

# Type variable bound to InvertedIndex (forward reference by name), used to
# annotate InvertedIndex.from_saved as returning the class it is called on.
T = TypeVar("T", bound="InvertedIndex")

@dataclass
class PostingList:
    """Postings of a single term, stored as two position-aligned lists.

    Entry ``i`` of ``docid_postings`` and entry ``i`` of ``tweight_postings``
    describe the same posting (one document containing the term).
    """

    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
    tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting


@dataclass
class InvertedIndex:
    """Inverted index over a document collection, persistable via pickle.

    Documents are addressed by an integer ``docid``; ``cid2docid`` and
    ``collection_ids`` translate between that and the external collection id.
    Terms are addressed by an integer term id (``tid``) looked up in ``vocab``.
    """

    posting_lists: List[PostingList]  # tid -> posting_list
    vocab: Dict[str, int]  # term -> tid (index into posting_lists)
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``.

        Args:
            output_dir: Target directory; created (with parents) if missing.
        """
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load an index previously written by ``save``.

        Args:
            saved_dir: Directory containing ``index.pkl``.

        Returns:
            The unpickled object, returned as-is (its type is not checked).

        Raises:
            FileNotFoundError: If ``<saved_dir>/index.pkl`` does not exist.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load index
        # files that come from a trusted source.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index: T = pickle.load(f)
        return index


# The output of the counting function:
@dataclass
class Counting:
    """Corpus statistics gathered by ``run_counting`` in a single pass."""

    posting_lists: List[PostingList]  # tid -> posting list (term weights are raw term frequencies)
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float  # average document length, i.e. the mean of dls
    nterms: int  # total number of tokens counted over all documents
    doc_texts: Optional[List[str]] = None  # docid -> raw document text, if stored

def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc. in one pass over ``documents``.

    Documents whose ``collection_id`` was already seen are skipped, so each
    collection id is counted once. Doc ids and term ids are assigned in order
    of first appearance.

    Args:
        documents: Documents to index; each must expose ``collection_id`` and
            ``text``.
        tokenize_fn: Maps a document text to its list of tokens.
        store_raw: If True, keep every document text in ``doc_texts``;
            otherwise ``doc_texts`` is None.
        ndocs: Expected number of documents; only used as the progress-bar
            total.
        show_progress_bar: Whether to display the tqdm progress bar.

    Returns:
        A ``Counting`` with posting lists (term weight = raw term frequency),
        vocabulary, id mappings, document frequencies, document lengths, the
        average document length (0.0 for an empty corpus) and the total
        number of tokens.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    # Decide once, up front, so an empty corpus also honours store_raw=False.
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue  # duplicate collection id: count each document once
        docid = len(collection_ids)
        cid2docid[doc.collection_id] = docid
        collection_ids.append(doc.collection_id)

        tok2tf = Counter(tokenize_fn(doc.text))
        doc_length = sum(tok2tf.values())
        dls.append(doc_length)
        nterms += doc_length

        for tok, tf in tok2tf.items():
            tid = vocab.get(tok)
            if tid is None:
                # First time this term is seen: allocate its tid, posting
                # list and df slot together so the three stay aligned.
                tid = len(vocab)
                vocab[tok] = tid
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            postings = posting_lists[tid]
            postings.docid_postings.append(docid)
            postings.tweight_postings.append(tf)
            # tok2tf has one entry per distinct term, so this counts each
            # document exactly once per term -- including the first one.
            dfs[tid] += 1

        if doc_texts is not None:
            doc_texts.append(doc.text)

    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        # Guard against an empty corpus instead of raising ZeroDivisionError.
        avgdl=sum(dls) / len(dls) if dls else 0.0,
        nterms=nterms,
        doc_texts=doc_texts,
    )


# Run the counting pass over the SciQ corpus at import time; ndocs is passed
# so the progress bar knows the total number of documents.
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))