lreining committed
Commit b8fd633
1 Parent(s): 152b8a0
Files changed (2)
  1. app.py +454 -0
  2. requirements.txt +109 -0
app.py ADDED
@@ -0,0 +1,454 @@
+ from __future__ import annotations
+
+ import math
+ import os
+ import pickle
+ import re
+ from abc import abstractmethod
+ from collections import Counter
+ from dataclasses import dataclass
+ from typing import Callable, Dict, Iterable, List, Optional, Type, TypedDict, TypeVar
+
+ import gradio as gr
+ import nltk
+ import numpy as np
+ import tqdm
+ from nlp4web_codebase.ir.data_loaders.dm import Document
+ from nlp4web_codebase.ir.models import BaseRetriever
+ from scipy.sparse._csc import csc_matrix
+
+
+ class Hit(TypedDict):
+     cid: str
+     score: float
+     text: str
+
+
+ demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
+ return_type = List[Hit]
+
+ LANGUAGE = "english"
+ nltk.download("stopwords", quiet=True)
+ from nltk.corpus import stopwords as nltk_stopwords
+
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
+
+
+ def word_splitting(text: str) -> List[str]:
+     return word_splitter(text.lower())
+
+
+ def lemmatization(words: List[str]) -> List[str]:
+     return words  # We ignore lemmatization here for simplicity
+
+
+ def simple_tokenize(text: str) -> List[str]:
+     words = word_splitting(text)
+     tokenized = list(filter(lambda w: w not in stopwords, words))
+     tokenized = lemmatization(tokenized)
+     return tokenized
+
+
+ @dataclass
+ class PostingList:
+     term: str  # The term
+     docid_postings: List[
+         int
+     ]  # docid_postings[i] means the docid (int) of the i-th associated posting
+     tweight_postings: List[
+         float
+     ]  # tweight_postings[i] means the term weight (float) of the i-th associated posting
+
+
+ @dataclass
+ class InvertedIndex:
+     posting_lists: List[PostingList]  # tid -> posting_list
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text
+
+     def save(self, output_dir: str) -> None:
+         os.makedirs(output_dir, exist_ok=True)
+         with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def from_saved(cls: Type[T], saved_dir: str) -> T:
+         index = cls(
+             posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
+         )
+         with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             index = pickle.load(f)
+         return index
+
+
+ T = TypeVar("T", bound="InvertedIndex")
+
+
+ # The output of the counting function:
+ @dataclass
+ class Counting:
+     posting_lists: List[PostingList]
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]
+     collection_ids: List[str]
+     dfs: List[int]  # tid -> df
+     dls: List[int]  # docid -> doc length
+     avgdl: float
+     nterms: int
+     doc_texts: Optional[List[str]] = None
+
+
+ def run_counting(
+     documents: Iterable[Document],
+     tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
+     store_raw: bool = True,  # store the document text in doc_texts
+     ndocs: Optional[int] = None,
+     show_progress_bar: bool = True,
+ ) -> Counting:
+     """Counting TFs, DFs, doc_lengths, etc."""
+     posting_lists: List[PostingList] = []
+     vocab: Dict[str, int] = {}
+     cid2docid: Dict[str, int] = {}
+     collection_ids: List[str] = []
+     dfs: List[int] = []  # tid -> df
+     dls: List[int] = []  # docid -> doc length
+     nterms: int = 0
+     doc_texts: Optional[List[str]] = []
+     for doc in tqdm.tqdm(
+         documents,
+         desc="Counting",
+         total=ndocs,
+         disable=not show_progress_bar,
+     ):
+         if doc.collection_id in cid2docid:
+             continue
+         collection_ids.append(doc.collection_id)
+         docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
+         toks = tokenize_fn(doc.text)
+         tok2tf = Counter(toks)
+         dls.append(sum(tok2tf.values()))
+         for tok, tf in tok2tf.items():
+             nterms += tf
+             tid = vocab.get(tok, None)
+             if tid is None:
+                 posting_lists.append(
+                     PostingList(term=tok, docid_postings=[], tweight_postings=[])
+                 )
+                 tid = vocab.setdefault(tok, len(vocab))
+             posting_lists[tid].docid_postings.append(docid)
+             posting_lists[tid].tweight_postings.append(tf)
+             if tid < len(dfs):
+                 dfs[tid] += 1
+             else:
+                 dfs.append(0)
+         if store_raw:
+             doc_texts.append(doc.text)
+         else:
+             doc_texts = None
+     return Counting(
+         posting_lists=posting_lists,
+         vocab=vocab,
+         cid2docid=cid2docid,
+         collection_ids=collection_ids,
+         dfs=dfs,
+         dls=dls,
+         avgdl=sum(dls) / len(dls),
+         nterms=nterms,
+         doc_texts=doc_texts,
+     )
+
+
+ @dataclass
+ class BM25Index(InvertedIndex):
+
+     @staticmethod
+     def tokenize(text: str) -> List[str]:
+         return simple_tokenize(text)
+
+     @staticmethod
+     def cache_term_weights(
+         posting_lists: List[PostingList],
+         total_docs: int,
+         avgdl: float,
+         dfs: List[int],
+         dls: List[int],
+         k1: float,
+         b: float,
+     ) -> None:
+         """Compute term weights and caching"""
+
+         N = total_docs
+         for tid, posting_list in enumerate(
+             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+         ):
+             idf = BM25Index.calc_idf(df=dfs[tid], N=N)
+             for i in range(len(posting_list.docid_postings)):
+                 docid = posting_list.docid_postings[i]
+                 tf = posting_list.tweight_postings[i]
+                 dl = dls[docid]
+                 regularized_tf = BM25Index.calc_regularized_tf(
+                     tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                 )
+                 posting_list.tweight_postings[i] = regularized_tf * idf
+
+     @staticmethod
+     def calc_regularized_tf(
+         tf: int, dl: float, avgdl: float, k1: float, b: float
+     ) -> float:
+         return tf / (tf + k1 * (1 - b + b * dl / avgdl))
+
+     @staticmethod
+     def calc_idf(df: int, N: int):
+         return math.log(1 + (N - df + 0.5) / (df + 0.5))
+
+     @classmethod
+     def build_from_documents(
+         cls: Type[BM25Index],
+         documents: Iterable[Document],
+         store_raw: bool = True,
+         output_dir: Optional[str] = None,
+         ndocs: Optional[int] = None,
+         show_progress_bar: bool = True,
+         k1: float = 0.9,
+         b: float = 0.4,
+     ) -> BM25Index:
+         # Counting TFs, DFs, doc_lengths, etc.:
+         counting = run_counting(
+             documents=documents,
+             tokenize_fn=BM25Index.tokenize,
+             store_raw=store_raw,
+             ndocs=ndocs,
+             show_progress_bar=show_progress_bar,
+         )
+
+         # Compute term weights and caching:
+         posting_lists = counting.posting_lists
+         total_docs = len(counting.cid2docid)
+         BM25Index.cache_term_weights(
+             posting_lists=posting_lists,
+             total_docs=total_docs,
+             avgdl=counting.avgdl,
+             dfs=counting.dfs,
+             dls=counting.dls,
+             k1=k1,
+             b=b,
+         )
+
+         # Assembly and save:
+         index = BM25Index(
+             posting_lists=posting_lists,
+             vocab=counting.vocab,
+             cid2docid=counting.cid2docid,
+             collection_ids=counting.collection_ids,
+             doc_texts=counting.doc_texts,
+         )
+         return index
+
+
+ @dataclass
+ class CSCInvertedIndex:
+     posting_lists_matrix: csc_matrix  # term weights; [docid, tid] -> term weight
+     vocab: Dict[str, int]
+     cid2docid: Dict[str, int]  # collection_id -> docid
+     collection_ids: List[str]  # docid -> collection_id
+     doc_texts: Optional[List[str]] = None  # docid -> document text
+
+     def save(self, output_dir: str) -> None:
+         os.makedirs(output_dir, exist_ok=True)
+         with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
+             pickle.dump(self, f)
+
+     @classmethod
+     def from_saved(cls: Type[T], saved_dir: str) -> T:
+         index = cls(
+             posting_lists_matrix=None,
+             vocab={},
+             cid2docid={},
+             collection_ids=[],
+             doc_texts=None,
+         )
+         with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
+             index = pickle.load(f)
+         return index
+
+
+ @dataclass
+ class CSCBM25Index(CSCInvertedIndex):
+
+     @staticmethod
+     def tokenize(text: str) -> List[str]:
+         return simple_tokenize(text)
+
+     @staticmethod
+     def cache_term_weights(
+         posting_lists: List[PostingList],
+         total_docs: int,
+         avgdl: float,
+         dfs: List[int],
+         dls: List[int],
+         k1: float,
+         b: float,
+     ) -> csc_matrix:
+         """Compute term weights and caching"""
+         data = []
+         indices = []
+         indptr = [0]
+         max_docid = 0
+         for tid, posting_list in enumerate(
+             tqdm.tqdm(posting_lists, desc="Regularizing TFs")
+         ):
+             idf = CSCBM25Index.calc_idf(df=dfs[tid], N=total_docs)
+             for i in range(len(posting_list.docid_postings)):
+                 docid = posting_list.docid_postings[i]
+                 if docid > max_docid:
+                     max_docid = docid
+                 tf = posting_list.tweight_postings[i]
+                 dl = dls[docid]
+                 regularized_tf = CSCBM25Index.calc_regularized_tf(
+                     tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
+                 )
+                 result = regularized_tf * idf
+                 posting_list.tweight_postings[i] = result  # TODO?
+                 if result != 0:
+                     data.append(result)
+                     indices.append(docid)
+             indptr.append(len(data))
+
+         # Docids are 0-based, so the matrix needs max_docid + 1 rows.
+         shape = (max_docid + 1, len(posting_lists))
+         return csc_matrix((data, indices, indptr), shape=shape)
+
+     @staticmethod
+     def calc_regularized_tf(
+         tf: int, dl: float, avgdl: float, k1: float, b: float
+     ) -> float:
+         return tf / (tf + k1 * (1 - b + b * dl / avgdl))
+
+     @staticmethod
+     def calc_idf(df: int, N: int):
+         return math.log(1 + (N - df + 0.5) / (df + 0.5))
+
+     @classmethod
+     def build_from_documents(
+         cls: Type[CSCBM25Index],
+         documents: Iterable[Document],
+         store_raw: bool = True,
+         output_dir: Optional[str] = None,
+         ndocs: Optional[int] = None,
+         show_progress_bar: bool = True,
+         k1: float = 0.9,
+         b: float = 0.4,
+     ) -> CSCBM25Index:
+         # Counting TFs, DFs, doc_lengths, etc.:
+         counting = run_counting(
+             documents=documents,
+             tokenize_fn=CSCBM25Index.tokenize,
+             store_raw=store_raw,
+             ndocs=ndocs,
+             show_progress_bar=show_progress_bar,
+         )
+
+         # Compute term weights and caching:
+         posting_lists = counting.posting_lists
+         total_docs = len(counting.cid2docid)
+         posting_lists_matrix = CSCBM25Index.cache_term_weights(
+             posting_lists=posting_lists,
+             total_docs=total_docs,
+             avgdl=counting.avgdl,
+             dfs=counting.dfs,
+             dls=counting.dls,
+             k1=k1,
+             b=b,
+         )
+
+         # Assembly and save:
+         index = CSCBM25Index(
+             posting_lists_matrix=posting_lists_matrix,
+             vocab=counting.vocab,
+             cid2docid=counting.cid2docid,
+             collection_ids=counting.collection_ids,
+             doc_texts=counting.doc_texts,
+         )
+         return index
+
+
+ class BaseCSCInvertedIndexRetriever(BaseRetriever):
+
+     @property
+     @abstractmethod
+     def index_class(self) -> Type[CSCInvertedIndex]:
+         pass
+
+     def __init__(self, index_dir: str) -> None:
+         self.index = self.index_class.from_saved(index_dir)
+
+     def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         target_docid = self.index.cid2docid[cid]
+         term_weights = {}
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             weight = self.index.posting_lists_matrix[target_docid, tid]
+             if weight != 0:
+                 term_weights[tok] = weight
+         return term_weights
+
+     def score(self, query: str, cid: str) -> float:
+         return sum(self.get_term_weights(query=query, cid=cid).values())
+
+     def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
+         toks = self.index.tokenize(query)
+         scores = np.zeros(self.index.posting_lists_matrix.shape[0])
+         for tok in toks:
+             if tok not in self.index.vocab:
+                 continue
+             tid = self.index.vocab[tok]
+             col = self.index.posting_lists_matrix[:, tid].toarray().flatten()
+             scores += col
+
+         docids = np.argsort(scores)[::-1][:topk]
+         scores = scores[docids]
+         return {
+             self.index.collection_ids[docid]: score
+             for docid, score in zip(docids, scores)
+         }
+
+
+ class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
+
+     @property
+     def index_class(self) -> Type[CSCBM25Index]:
+         return CSCBM25Index
+
+
+ if __name__ == "__main__":
+     top_k = 10
+     csc_bm25_retriever = CSCBM25Retriever(index_dir="output/csc_bm25_index")
+
+     def query(query: str) -> List[Hit]:
+         hits = []
+         for cid, score in csc_bm25_retriever.retrieve(query).items():
+             hit = Hit(
+                 cid=cid,
+                 score=score,
+                 text=csc_bm25_retriever.index.doc_texts[
+                     csc_bm25_retriever.index.cid2docid[cid]
+                 ],
+             )
+             hits.append(hit)
+         return hits
+
+     demo = gr.Interface(
+         fn=query,
+         inputs=gr.Textbox(lines=1, label="Query"),
+         # outputs=["text" for _ in range(top_k)],
+         outputs=[gr.Textbox(label=f"Result {i+1}") for i in range(top_k)],
+         title="BM25 Retriever",
+         description="Enter a query to retrieve the top BM25-ranked documents.",
+     )
+
+     demo.launch()
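
app.py only loads a prebuilt index from output/csc_bm25_index at startup; nothing in this commit creates that directory. A minimal sketch of the offline build step, reusing the classes above and assuming that Document takes exactly the collection_id and text fields that run_counting reads (the file name and toy corpus are illustrative only):

# build_index.py -- illustrative offline step, not part of this commit
from nlp4web_codebase.ir.data_loaders.dm import Document

from app import CSCBM25Index

# Toy corpus; the real index would be built from the actual document collection.
docs = [
    Document(
        collection_id="doc1",
        text="BM25 is a bag-of-words ranking function used by search engines.",
    ),
    Document(
        collection_id="doc2",
        text="An inverted index maps each term to the documents that contain it.",
    ),
]

index = CSCBM25Index.build_from_documents(documents=docs, ndocs=len(docs))
index.save("output/csc_bm25_index")  # the directory app.py loads at startup

Once output/csc_bm25_index is committed next to app.py, the CSCBM25Retriever in the __main__ block can unpickle it and serve queries through the Gradio interface.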
requirements.txt ADDED
@@ -0,0 +1,109 @@
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.4.3
+ aiohttp==3.10.10
+ aiosignal==1.3.1
+ annotated-types==0.7.0
+ anyio==4.6.2.post1
+ asttokens==2.4.1
+ attrs==24.2.0
+ audioop-lts==0.2.1
+ certifi==2024.8.30
+ charset-normalizer==3.4.0
+ click==8.1.7
+ comm==0.2.2
+ contourpy==1.3.0
+ cycler==0.12.1
+ datasets==3.0.1
+ debugpy==1.8.7
+ decorator==5.1.1
+ dill==0.3.8
+ exceptiongroup==1.2.2
+ executing==2.1.0
+ fastapi==0.115.4
+ ffmpy==0.4.0
+ filelock==3.16.1
+ fonttools==4.54.1
+ frozenlist==1.5.0
+ fsspec==2024.6.1
+ gradio==5.5.0
+ gradio_client==1.4.2
+ h11==0.14.0
+ httpcore==1.0.6
+ httpx==0.27.2
+ huggingface-hub==0.26.2
+ idna==3.10
+ importlib_metadata==8.5.0
+ ipykernel==6.29.5
+ ipython==8.29.0
+ jedi==0.19.1
+ Jinja2==3.1.4
+ joblib==1.4.2
+ jupyter_client==8.6.3
+ jupyter_core==5.7.2
+ kiwisolver==1.4.7
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.2
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ multidict==6.1.0
+ multiprocess==0.70.16
+ nest_asyncio==1.6.0
+ nlp4web-codebase @ git+https://github.com/kwang2049/nlp4web-codebase.git@83f9afbbf7e372c116fdd04997a96449007f861f
+ nltk==3.8.1
+ numpy==1.26.4
+ orjson==3.10.11
+ packaging==24.1
+ pandas==2.2.2
+ parso==0.8.4
+ pexpect==4.9.0
+ pickleshare==0.7.5
+ pillow==11.0.0
+ pip==24.2
+ platformdirs==4.3.6
+ prompt_toolkit==3.0.48
+ propcache==0.2.0
+ psutil==6.1.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ pyarrow==18.0.0
+ pydantic==2.9.2
+ pydantic_core==2.23.4
+ pydub==0.25.1
+ Pygments==2.18.0
+ pyparsing==3.2.0
+ python-dateutil==2.9.0
+ python-multipart==0.0.12
+ pytrec_eval==0.5
+ pytz==2024.2
+ PyYAML==6.0.2
+ pyzmq==26.2.0
+ regex==2024.9.11
+ requests==2.32.3
+ rich==13.9.4
+ ruff==0.7.2
+ safehttpx==0.1.1
+ scipy==1.13.1
+ semantic-version==2.10.0
+ setuptools==75.1.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ stack-data==0.6.2
+ starlette==0.41.2
+ tomlkit==0.12.0
+ tornado==6.4.1
+ tqdm==4.66.5
+ traitlets==5.14.3
+ typer==0.12.5
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ ujson==5.10.0
+ urllib3==2.2.3
+ uvicorn==0.32.0
+ wcwidth==0.2.13
+ websockets==12.0
+ wheel==0.44.0
+ xxhash==3.5.0
+ yarl==1.17.1
+ zipp==3.20.2