wnathanael commited on
Commit
4b38ca9
1 Parent(s): 9622331

Prepared for submit

Browse files
app.py CHANGED
@@ -1,7 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from dataclasses import dataclass
3
+ import pickle
4
+ import os
5
+ from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
6
+ from nlp4web_codebase.ir.data_loaders.dm import Document
7
+ from collections import Counter
8
+ import tqdm
9
+ import re
10
+ import nltk
11
+ nltk.download("stopwords", quiet=True)
12
+ from nltk.corpus import stopwords as nltk_stopwords
13
+
14
+ LANGUAGE = "english"
15
+ word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
16
+ stopwords = set(nltk_stopwords.words(LANGUAGE))
17
+
18
+
19
def word_splitting(text: str) -> List[str]:
    """Lower-case ``text`` and split it into word tokens using ``word_splitter``."""
    lowered = text.lower()
    return word_splitter(lowered)
21
+
22
def lemmatization(words: List[str]) -> List[str]:
    """Return ``words`` unchanged.

    Lemmatization is deliberately skipped here for simplicity; the function
    exists so the tokenization pipeline has a slot for it.
    """
    return words
24
+
25
def simple_tokenize(text: str) -> List[str]:
    """Tokenize ``text``: split into words, drop stopwords, then lemmatize."""
    kept = [word for word in word_splitting(text) if word not in stopwords]
    return lemmatization(kept)
30
+
31
+ T = TypeVar("T", bound="InvertedIndex")
32
+
33
@dataclass
class PostingList:
    """Postings of a single term, stored as two parallel lists."""

    # The term these postings belong to.
    term: str
    # docid_postings[i] is the docid (int) of the i-th posting.
    docid_postings: List[int]
    # tweight_postings[i] is the term weight (float) of the i-th posting.
    tweight_postings: List[float]
38
+
39
+
40
@dataclass
class InvertedIndex:
    """Inverted index plus the id mappings needed to interpret it.

    The whole object is persisted as one pickle file named ``index.pkl``.
    """

    posting_lists: List[PostingList]  # tid -> posting_list
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``.

        ``output_dir`` is created (including parents) when it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        """Load the index previously written by :meth:`save` into ``saved_dir``.

        Raises:
            FileNotFoundError: if ``<saved_dir>/index.pkl`` does not exist.
        """
        # NOTE: unpickling can execute arbitrary code; only load index files
        # that this application wrote itself.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
61
+
62
+
63
+ # The output of the counting function:
64
@dataclass
class Counting:
    """Raw corpus statistics returned by the counting pass."""

    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    # tid -> df
    dfs: List[int]
    # docid -> doc length
    dls: List[int]
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None
75
+
76
def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Count TFs, DFs, document lengths, etc. over ``documents``.

    Args:
        documents: Documents to index. When a ``collection_id`` occurs more
            than once, only its first document is counted.
        tokenize_fn: Maps a document text to its list of tokens.
        store_raw: Keep the raw document texts in ``Counting.doc_texts``;
            when False, ``doc_texts`` is ``None``.
        ndocs: Expected number of documents; only used as the progress-bar total.
        show_progress_bar: Whether to display the tqdm progress bar.

    Returns:
        A ``Counting`` whose posting lists hold raw term frequencies as weights.
        ``avgdl`` is 0.0 when no document was counted.
    """
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = [] if store_raw else None
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue  # duplicate collection id: keep only the first document
        docid = len(collection_ids)
        cid2docid[doc.collection_id] = docid
        collection_ids.append(doc.collection_id)

        tok2tf = Counter(tokenize_fn(doc.text))
        dl = sum(tok2tf.values())
        dls.append(dl)
        nterms += dl
        for tok, tf in tok2tf.items():
            tid = vocab.get(tok)
            if tid is None:
                # First time this term is seen: open its posting list and df slot.
                tid = len(vocab)
                vocab[tok] = tid
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                dfs.append(0)
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            # Every posting is one more document containing the term. The first
            # document must count too, otherwise each df is one too low.
            dfs[tid] += 1
        if doc_texts is not None:
            doc_texts.append(doc.text)
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        # Guard against an empty corpus instead of raising ZeroDivisionError.
        avgdl=sum(dls) / len(dls) if dls else 0.0,
        nterms=nterms,
        doc_texts=doc_texts,
    )
134
+
135
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
136
+ sciq = load_sciq()
137
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
138
+
139
+ from dataclasses import asdict, dataclass
140
+ import math
141
+ import os
142
+ from typing import Iterable, List, Optional, Type
143
+ import tqdm
144
+ from nlp4web_codebase.ir.data_loaders.dm import Document
145
+
146
@dataclass
class BM25Index(InvertedIndex):
    """Inverted index whose posting weights are precomputed BM25 term scores."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize``."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Replace the raw TFs in ``posting_lists`` by BM25 weights, in place.

        Args:
            posting_lists: Posting lists indexed by tid, holding raw TFs.
            total_docs: Number of documents in the collection (N).
            avgdl: Average document length.
            dfs: tid -> document frequency.
            dls: docid -> document length.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.
        """
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=total_docs)
            weights = posting_list.tweight_postings
            for i, docid in enumerate(posting_list.docid_postings):
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=weights[i], dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                # Overwrite the raw TF with the final term weight.
                weights[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return the length-normalized, saturated TF: tf / (tf + k1 * (1 - b + b * dl / avgdl))."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """Return the IDF: log(1 + (N - df + 0.5) / (df + 0.5))."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        """Build an index from ``documents``.

        Args:
            documents: Documents to index.
            store_raw: Keep the raw document texts in the index.
            output_dir: When given, the built index is also saved there.
            ndocs: Expected number of documents (progress-bar total only).
            show_progress_bar: Whether to show the counting progress bar.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            The assembled index with BM25 weights cached in its posting lists.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        cls.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = cls(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index
231
+
232
+ bm25_index = BM25Index.build_from_documents(
233
+ documents=iter(sciq.corpus),
234
+ ndocs=12160,
235
+ show_progress_bar=True,
236
+ )
237
+ bm25_index.save("output/bm25_index")
238
+
239
+ from nlp4web_codebase.ir.models import BaseRetriever
240
+ from typing import Type
241
+ from abc import abstractmethod
242
+
243
class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents from the cached weights of a saved inverted index.

    NOTE: ``retrieve`` adds a term's weight once per occurrence of the term in
    the tokenized query, whereas ``get_term_weights`` (and therefore ``score``)
    keeps a single entry per distinct term.
    """

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        """Index class whose ``from_saved`` is used to load the index."""
        pass

    def __init__(self, index_dir: str) -> None:
        """Load the index stored in ``index_dir``."""
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the weight of each query term that occurs in document ``cid``.

        Raises:
            KeyError: if ``cid`` is not part of the index.
        """
        query_terms = self.index.tokenize(query)
        wanted_docid = self.index.cid2docid[cid]
        found: Dict[str, float] = {}
        for term in query_terms:
            tid = self.index.vocab.get(term)
            if tid is None:
                continue  # out-of-vocabulary term
            postings = self.index.posting_lists[tid]
            # The first posting of the wanted document carries its weight.
            weight = next(
                (
                    tweight
                    for docid, tweight in zip(
                        postings.docid_postings, postings.tweight_postings
                    )
                    if docid == wanted_docid
                ),
                None,
            )
            if weight is not None:
                found[term] = weight
        return found

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the term weights of ``query`` in document ``cid``."""
        weights = self.get_term_weights(query=query, cid=cid)
        return sum(weights.values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` highest-scoring documents as ``{collection_id: score}``."""
        accumulated: Dict[int, float] = {}
        for term in self.index.tokenize(query):
            tid = self.index.vocab.get(term)
            if tid is None:
                continue  # out-of-vocabulary term
            postings = self.index.posting_lists[tid]
            for docid, tweight in zip(
                postings.docid_postings, postings.tweight_postings
            ):
                accumulated[docid] = accumulated.get(docid, 0) + tweight
        best = sorted(accumulated.items(), key=lambda pair: pair[1], reverse=True)
        return {self.index.collection_ids[docid]: score for docid, score in best[:topk]}
293
+
294
class BM25Retriever(BaseInvertedIndexRetriever):
    """Retriever backed by a saved ``BM25Index``."""

    @property
    def index_class(self) -> Type[BM25Index]:
        """Index class used to load the saved index."""
        return BM25Index
299
+
300
+ from nlp4web_codebase.ir.data_loaders import Split
301
+ import pytrec_eval
302
+ import numpy as np
303
+
304
def evaluate_map(rankings: Dict[str, Dict[str, float]], split=Split.dev) -> float:
    """Return MAP@10 (``map_cut_10``) of ``rankings`` against the qrels of ``split``.

    Args:
        rankings: query_id -> {collection_id: score}.
        split: Dataset split whose relevance judgements are used.

    Returns:
        The mean of the per-query ``map_cut_10`` values reported by pytrec_eval.
    """
    metric = "map_cut_10"
    # Build the qrels once and hand exactly that dict to the evaluator.
    qrels = sciq.get_qrels_dict(split)
    evaluator = pytrec_eval.RelevanceEvaluator(qrels, (metric,))
    per_query = evaluator.evaluate(rankings)
    return float(np.mean([scores[metric] for scores in per_query.values()]))
310
+
311
+ # Loading dataset:
312
+ from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
313
+ sciq = load_sciq()
314
+ counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))
315
+
316
+ # Building BM25 index and save:
317
+ bm25_index = BM25Index.build_from_documents(
318
+ documents=iter(sciq.corpus),
319
+ ndocs=12160,
320
+ show_progress_bar=True
321
+ )
322
+ bm25_index.save("output/bm25_index")
323
+
324
+
325
+ plots_b: Dict[str, List[float]] = {
326
+ "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
327
+ "Y": []
328
+ }
329
+ plots_k1: Dict[str, List[float]] = {
330
+ "X": [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
331
+ "Y": []
332
+ }
333
+
334
+ ## YOUR_CODE_STARTS_HERE
335
+ # Two steps should be involved:
336
+ # Step 1. Fix k1 value to the default one 0.9,
337
+ # go through all the candidate b values (0, 0.1, ..., 1.0),
338
+ # and record in plots_b["Y"] the corresponding performances obtained via evaluate_map;
339
+ # Step 2. Fix b to the best one in step 1. and do the same for k1.
340
+
341
+ sciq_dataset = load_sciq()
342
+ dev_queries = sciq.get_split_queries(Split.dev)
343
+ dev_qrels = sciq.get_qrels_dict(Split.dev)
344
+
345
+
346
+
347
def evaluate_bm25(k1, b):
    """Build a BM25 index with the given ``k1`` and ``b`` and return its MAP@10 on the dev split.

    The index is written to ``output/bm25_index_task1`` and loaded back through
    ``BM25Retriever``; each dev query is then ranked with ``topk=10``.
    """
    index_dir = "output/bm25_index_task1"
    BM25Index.build_from_documents(
        documents=iter(sciq_dataset.corpus),
        ndocs=len(sciq_dataset.corpus),
        show_progress_bar=True,
        k1=k1,
        b=b,
    ).save(index_dir)

    retriever = BM25Retriever(index_dir=index_dir)

    # query_id -> {collection_id: score}, the format expected by evaluate_map
    rankings = {
        query.query_id: retriever.retrieve(query.text, topk=10)
        for query in dev_queries
    }
    return evaluate_map(rankings, split=Split.dev)
375
+
376
+ for b in plots_b["X"]:
377
+ plots_b["Y"].append(evaluate_bm25(k1=0.9, b=b))
378
+
379
+ # Find the best value of `b`
380
+ best_b = plots_b["X"][np.argmax(plots_b["Y"])]
381
+
382
+ for k1 in plots_k1["X"]:
383
+ plots_k1["Y"].append(evaluate_bm25(k1=k1, b=best_b))
384
+
385
+ # Find the best value of `k1`
386
+ best_k1 = plots_k1["X"][np.argmax(plots_k1["Y"])]
387
+
388
+
389
+
390
+ # Hint (on using the pre-requisite code):
391
+ # - One can use the loaded sciq dataset directly (loaded in the pre-requisite code);
392
+ # - One can build bm25_index with `BM25Index.build_from_documents`;
393
+ # - One can use BM25Retriever to load the index and perform retrieval on the dev queries
394
+ # (dev queries can be obtained via sciq.get_split_queries(Split.dev))
395
+ ## YOUR_CODE_ENDS_HERE
396
+
397
# Import from the public scipy.sparse namespace rather than the private
# ``scipy.sparse._csc`` module.
from scipy.sparse import csc_matrix

CSCIndexT = TypeVar("CSCIndexT", bound="CSCInvertedIndex")


@dataclass
class CSCInvertedIndex:
    """Inverted index whose posting lists are stored in one sparse CSC matrix.

    The whole object is persisted as one pickle file named ``index.pkl``.
    """

    posting_lists_matrix: csc_matrix  # column tid = postings of that term, rows are docids
    vocab: Dict[str, int]  # term -> tid
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        """Pickle this index to ``<output_dir>/index.pkl``.

        ``output_dir`` is created (including parents) when it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[CSCIndexT], saved_dir: str) -> CSCIndexT:
        """Load the index previously written by :meth:`save` into ``saved_dir``.

        Raises:
            FileNotFoundError: if ``<saved_dir>/index.pkl`` does not exist.
        """
        # NOTE: unpickling can execute arbitrary code; only load index files
        # that this application wrote itself.
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            return pickle.load(f)
420
+
421
+
422
@dataclass
class CSCBM25Index(CSCInvertedIndex):
    """BM25 index whose term weights are cached in a sparse CSC matrix."""

    @staticmethod
    def tokenize(text: str) -> List[str]:
        """Tokenize ``text`` with ``simple_tokenize``."""
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> csc_matrix:
        """Compute the BM25 term weights and return them as a CSC matrix.

        Args:
            posting_lists: Posting lists indexed by tid, holding raw TFs.
            total_docs: Number of documents in the collection (N).
            avgdl: Average document length.
            dfs: tid -> document frequency.
            dls: docid -> document length.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            A float32 matrix of shape ``(total_docs, len(posting_lists))`` with
            documents as rows and terms as columns.
        """
        ## YOUR_CODE_STARTS_HERE
        data = []  # term weights, one per posting
        indices = []  # row index (docid) of each entry in ``data``
        indptr = [0]  # column tid owns data[indptr[tid]:indptr[tid + 1]]

        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = CSCBM25Index.calc_idf(df=dfs[tid], N=total_docs)
            for docid, tf in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                regularized_tf = CSCBM25Index.calc_regularized_tf(
                    tf=tf, dl=dls[docid], avgdl=avgdl, k1=k1, b=b
                )
                data.append(regularized_tf * idf)
                indices.append(docid)
            # Close the column of this term.
            indptr.append(len(data))

        return csc_matrix(
            (data, indices, indptr),
            shape=(total_docs, len(posting_lists)),
            dtype=np.float32,
        )
        ## YOUR_CODE_ENDS_HERE

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        """Return the length-normalized, saturated TF: tf / (tf + k1 * (1 - b + b * dl / avgdl))."""
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int) -> float:
        """Return the IDF: log(1 + (N - df + 0.5) / (df + 0.5))."""
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[CSCBM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> CSCBM25Index:
        """Build an index from ``documents``.

        Args:
            documents: Documents to index.
            store_raw: Keep the raw document texts in the index.
            output_dir: When given, the built index is also saved there.
            ndocs: Expected number of documents (progress-bar total only).
            show_progress_bar: Whether to show the counting progress bar.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 length-normalization parameter.

        Returns:
            The assembled index with its BM25 weight matrix.
        """
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=cls.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists_matrix = cls.cache_term_weights(
            posting_lists=counting.posting_lists,
            total_docs=len(counting.cid2docid),
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = cls(
            posting_lists_matrix=posting_lists_matrix,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        if output_dir is not None:
            index.save(output_dir)
        return index
524
+
525
+ csc_bm25_index = CSCBM25Index.build_from_documents(
526
+ documents=iter(sciq.corpus),
527
+ ndocs=12160,
528
+ show_progress_bar=True,
529
+ k1=best_k1,
530
+ b=best_b
531
+ )
532
+ csc_bm25_index.save("output/csc_bm25_index")
533
+
534
class BaseCSCInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents from the sparse weight matrix of a saved CSC index."""

    @property
    @abstractmethod
    def index_class(self) -> Type[CSCInvertedIndex]:
        """Index class whose ``from_saved`` is used to load the index."""
        pass

    def __init__(self, index_dir: str) -> None:
        """Load the index stored in ``index_dir``."""
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the weight of each query term that is stored for document ``cid``.

        Raises:
            KeyError: if ``cid`` is not part of the index.
        """
        ## YOUR_CODE_STARTS_HERE
        query_terms = self.index.tokenize(query)
        wanted_docid = self.index.cid2docid[cid]
        found = {}
        for term in query_terms:
            tid = self.index.vocab.get(term)
            if tid is None:
                continue  # out-of-vocabulary term
            # Column ``tid`` holds the postings of the term: its stored row
            # indices are docids and its data are the matching weights.
            column = self.index.posting_lists_matrix.getcol(tid)
            for docid, tweight in zip(column.indices, column.data):
                if docid != wanted_docid:
                    continue
                found[term] = tweight
                break
        return found
        ## YOUR_CODE_ENDS_HERE

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the term weights of ``query`` in document ``cid``."""
        weights = self.get_term_weights(query=query, cid=cid)
        return sum(weights.values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` highest-scoring documents as ``{collection_id: score}``."""
        ## YOUR_CODE_STARTS_HERE
        accumulated: Dict[int, float] = {}
        for term in self.index.tokenize(query):
            tid = self.index.vocab.get(term)
            if tid is None:
                continue  # out-of-vocabulary term
            column = self.index.posting_lists_matrix.getcol(tid)
            for docid, tweight in zip(column.indices, column.data):
                accumulated[docid] = accumulated.get(docid, 0) + tweight
        best = sorted(accumulated.items(), key=lambda pair: pair[1], reverse=True)
        return {self.index.collection_ids[docid]: score for docid, score in best[:topk]}
        ## YOUR_CODE_ENDS_HERE
592
+
593
class CSCBM25Retriever(BaseCSCInvertedIndexRetriever):
    """Retriever backed by a saved ``CSCBM25Index``."""

    @property
    def index_class(self) -> Type[CSCBM25Index]:
        """Index class used to load the saved index."""
        return CSCBM25Index
598
+
599
  import gradio as gr
600
+ from typing import TypedDict
601
+
602
class Hit(TypedDict):
    """One search result returned by the demo."""

    # Collection id of the retrieved document.
    cid: str
    # Retrieval score of the document for the query.
    score: float
    # Text of the document.
    text: str
606
+
607
+ demo: Optional[gr.Interface] = None # Assign your gradio demo to this variable
608
+ return_type = List[Hit]
609
+
610
+ ## YOUR_CODE_STARTS_HERE
611
+ # Building BM25 index and save:
612
+ bm25_index = BM25Index.build_from_documents(
613
+ documents=iter(sciq.corpus),
614
+ ndocs=12160,
615
+ show_progress_bar=True
616
+ )
617
+ bm25_index.save("output/bm25_index_app")
618
+
619
+ # Loading index and use BM25 retriever to retrieve:
620
+ bm25_retriever = BM25Retriever(index_dir="output/bm25_index_app")
621
+
622
+ # print(bm25_retriever.retrieve("What type of diseases occur when the immune system attacks normal body cells?")) # the ranking
623
def search(query: str) -> List[Hit]:
    """Retrieve the top documents for ``query`` and return them as hits.

    Each hit carries the collection id, the retrieval score and the document
    text (a string, not the ``Document`` object).
    """
    # Resolve ids and texts through the index that produced the ranking.
    index = bm25_retriever.index
    hits: List[Hit] = []
    for cid, score in bm25_retriever.retrieve(query).items():
        docid = index.cid2docid[cid]
        if index.doc_texts is not None:
            text = index.doc_texts[docid]
        else:
            # The index was built without raw texts: fall back to the corpus.
            text = sciq.corpus[docid].text
        hits.append(Hit(cid=cid, score=score, text=text))
    return hits
631
 
 
 
632
 
633
+ # Create the Gradio interface
634
+ demo = gr.Interface(
635
+ fn=search, # Function to call on submit
636
+ inputs=gr.Textbox(label="Enter your query"), # Input field with label
637
+ outputs=gr.Textbox(label="RESULT HERE"), # Output field to display result
638
+ live=False # Disable real-time updates to require a button click
639
+ )
640
+ ## YOUR_CODE_ENDS_HERE
641
+ demo.launch(share=True)
nlp4web-codebase-main/nlp4web_codebase/__init__.py ADDED
File without changes
nlp4web-codebase-main/nlp4web_codebase/ir/__init__.py ADDED
File without changes
nlp4web-codebase-main/nlp4web_codebase/ir/analysis.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, List, Optional, Protocol
3
+ import pandas as pd
4
+ import tqdm
5
+ import ujson
6
+ from nlp4web_codebase.ir.data_loaders import IRDataset
7
+
8
+
9
def round_dict(obj: Dict[str, float], ndigits: int = 4) -> Dict[str, float]:
    """Return a copy of ``obj`` with every value rounded to ``ndigits`` decimals."""
    rounded: Dict[str, float] = {}
    for key, value in obj.items():
        rounded[key] = round(value, ndigits=ndigits)
    return rounded
11
+
12
+
13
def sort_dict(obj: Dict[str, float], reverse: bool = True) -> Dict[str, float]:
    """Return a new dict with the items of ``obj`` ordered by value (descending by default)."""
    ordered_items = sorted(obj.items(), key=lambda item: item[1], reverse=reverse)
    return dict(ordered_items)
15
+
16
+
17
def save_ranking_results(
    output_dir: str,
    query_ids: List[str],
    rankings: List[Dict[str, float]],
    query_performances_lists: List[Dict[str, float]],
    cid2tweights_lists: Optional[List[Dict[str, Dict[str, float]]]] = None,
):
    """Write one JSON line per query to ``<output_dir>/ranking_results.jsonl``.

    Each line holds the query id, its ranking, its per-query performances and,
    when ``cid2tweights_lists`` is given, the per-document term weights. All
    float values are rounded with ``round_dict``.
    """
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "ranking_results.jsonl")
    records = []
    per_query = zip(query_ids, rankings, query_performances_lists)
    for position, (query_id, ranking, query_performances) in enumerate(per_query):
        record = {
            "query_id": query_id,
            "ranking": round_dict(ranking),
            "query_performances": round_dict(query_performances),
            "cid2tweights": {},
        }
        if cid2tweights_lists is not None:
            cid2tweights = cid2tweights_lists[position]
            record["cid2tweights"] = {
                cid: round_dict(tweights) for cid, tweights in cid2tweights.items()
            }
        records.append(record)
    frame = pd.DataFrame(records)
    frame.to_json(output_path, orient="records", lines=True)
46
+
47
+
48
class TermWeightingFunction(Protocol):
    """Callable mapping a ``query`` and a collection id ``cid`` to per-term weights."""

    def __call__(self, query: str, cid: str) -> Dict[str, float]:
        ...
50
+
51
+
52
def compare(
    dataset: IRDataset,
    results_path1: str,
    results_path2: str,
    output_dir: str,
    main_metric: str = "recip_rank",
    system1: Optional[str] = None,
    system2: Optional[str] = None,
    term_weighting_fn1: Optional[TermWeightingFunction] = None,
    term_weighting_fn2: Optional[TermWeightingFunction] = None,
) -> None:
    """Compare two ranking-result files query by query and write a TSV report.

    Both files are read as JSON lines and merged on ``query_id``. Every output
    row holds the query, both rankings, the ranked document texts, both
    performance dicts, the qrels and the difference of ``main_metric``
    (system 1 minus system 2). Rows are sorted by that difference, largest
    first, and written to ``<output_dir>/compare-{system1}_vs_{system2}.tsv``.

    When both term-weighting functions are given, the rankings are recomputed
    over the union of both result lists from the summed term weights, and the
    per-document term weights are added to the row.
    """
    os.makedirs(output_dir, exist_ok=True)
    results1 = pd.read_json(results_path1, orient="records", lines=True)
    results2 = pd.read_json(results_path2, orient="records", lines=True)
    assert len(results1) == len(results2)

    # Lookup tables over the whole dataset.
    all_qrels = {}
    for split in dataset.split2qrels:
        all_qrels.update(dataset.get_qrels_dict(split))
    qid2query = {query.query_id: query for query in dataset.queries}
    cid2doc = {doc.collection_id: doc for doc in dataset.corpus}

    diff_col = f"{main_metric}:qp1-qp2"
    with_term_weights = (
        term_weighting_fn1 is not None and term_weighting_fn2 is not None
    )
    merged = pd.merge(results1, results2, on="query_id", how="outer")
    rows = []
    for _, example in tqdm.tqdm(merged.iterrows(), desc="Comparing", total=len(merged)):
        # Texts of every document ranked by either system.
        docs = {}
        for ranking_col in ("ranking_x", "ranking_y"):
            for cid in dict(example[ranking_col]):
                docs[cid] = cid2doc[cid].text

        query_id = example["query_id"]
        performances1 = example["query_performances_x"]
        performances2 = example["query_performances_y"]
        row = {
            "query_id": query_id,
            "query": qid2query[query_id].text,
            diff_col: performances1[main_metric] - performances2[main_metric],
            "ranking1": ujson.dumps(example["ranking_x"], indent=4),
            "ranking2": ujson.dumps(example["ranking_y"], indent=4),
            "docs": ujson.dumps(docs, indent=4),
            "query_performances1": ujson.dumps(performances1, indent=4),
            "query_performances2": ujson.dumps(performances2, indent=4),
            "qrels": ujson.dumps(all_qrels[query_id], indent=4),
        }
        if with_term_weights:
            query_text = qid2query[query_id].text
            all_cids = set(example["ranking_x"]) | set(example["ranking_y"])
            cid2tweights1 = {}
            cid2tweights2 = {}
            ranking1 = {}
            ranking2 = {}
            for cid in all_cids:
                tweights1 = term_weighting_fn1(query=query_text, cid=cid)
                tweights2 = term_weighting_fn2(query=query_text, cid=cid)
                ranking1[cid] = sum(tweights1.values())
                ranking2[cid] = sum(tweights2.values())
                cid2tweights1[cid] = tweights1
                cid2tweights2[cid] = tweights2
            ranking1 = sort_dict(ranking1)
            ranking2 = sort_dict(ranking2)
            row["ranking1"] = ujson.dumps(ranking1, indent=4)
            row["ranking2"] = ujson.dumps(ranking2, indent=4)
            # List the term weights in the order of the recomputed rankings.
            row["cid2tweights1"] = ujson.dumps(
                {cid: cid2tweights1[cid] for cid in ranking1}, indent=4
            )
            row["cid2tweights2"] = ujson.dumps(
                {cid: cid2tweights2[cid] for cid in ranking2}, indent=4
            )
        rows.append(row)

    table = pd.DataFrame(rows).sort_values(by=diff_col, ascending=False)
    report_name = f"compare-{system1}_vs_{system2}.tsv"
    table.to_csv(os.path.join(output_dir, report_name), sep="\t", index=False)
120
+
121
+
122
+ # if __name__ == "__main__":
123
+ # # python -m lecture2.bm25.analysis
124
+ # from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
125
+ # from lecture2.bm25.bm25_retriever import BM25Retriever
126
+ # from lecture2.bm25.tfidf_retriever import TFIDFRetriever
127
+ # import numpy as np
128
+
129
+ # sciq = load_sciq()
130
+ # system1 = "bm25"
131
+ # system2 = "tfidf"
132
+ # results_path1 = f"output/sciq-{system1}/results/ranking_results.jsonl"
133
+ # results_path2 = f"output/sciq-{system2}/results/ranking_results.jsonl"
134
+ # index_dir1 = f"output/sciq-{system1}"
135
+ # index_dir2 = f"output/sciq-{system2}"
136
+ # compare(
137
+ # dataset=sciq,
138
+ # results_path1=results_path1,
139
+ # results_path2=results_path2,
140
+ # output_dir=f"output/sciq-{system1}_vs_{system2}",
141
+ # system1=system1,
142
+ # system2=system2,
143
+ # term_weighting_fn1=BM25Retriever(index_dir1).get_term_weights,
144
+ # term_weighting_fn2=TFIDFRetriever(index_dir2).get_term_weights,
145
+ # )
146
+
147
+ # # bias on #shared_terms of TFIDF:
148
+ # df1 = pd.read_json(results_path1, orient="records", lines=True)
149
+ # df2 = pd.read_json(results_path2, orient="records", lines=True)
150
+ # merged = pd.merge(df1, df2, on="query_id", how="outer")
151
+ # nterms1 = []
152
+ # nterms2 = []
153
+ # for _, row in merged.iterrows():
154
+ # nterms1.append(len(list(dict(row["cid2tweights_x"]).values())[0]))
155
+ # nterms2.append(len(list(dict(row["cid2tweights_y"]).values())[0]))
156
+ # percentiles = (5, 25, 50, 75, 95)
157
+ # print(system1, np.percentile(nterms1, percentiles), np.mean(nterms1).round(2))
158
+ # print(system2, np.percentile(nterms2, percentiles), np.mean(nterms2).round(2))
159
+ # # bm25 [ 3. 4. 5. 7. 11.] 5.64
160
+ # # tfidf [1. 2. 3. 5. 9.] 3.58
nlp4web-codebase-main/nlp4web_codebase/ir/data_loaders/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from typing import Dict, List
4
+ from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel
5
+
6
+
7
class Split(str, Enum):
    """Dataset split; each member is also a ``str`` equal to its name."""

    train = "train"
    dev = "dev"
    test = "test"
11
+
12
+
13
@dataclass
class IRDataset:
    """Retrieval dataset: a corpus, its queries and per-split relevance judgements."""

    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        """Return the corpus size, the query count and the qrel count per split.

        Split keys are rendered by their ``value`` (e.g. ``|qrels-dev|``), so the
        key names do not depend on how the interpreter formats enum members
        inside f-strings.
        """
        stats = {"|corpus|": len(self.corpus), "|queries|": len(self.queries)}
        for split, qrels in self.split2qrels.items():
            split_name = getattr(split, "value", split)
            stats[f"|qrels-{split_name}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Return the qrels of ``split`` as ``{query_id: {collection_id: relevance}}``."""
        qrels_dict: Dict[str, Dict[str, int]] = {}
        for qrel in self.split2qrels[split]:
            qrels_dict.setdefault(qrel.query_id, {})[qrel.collection_id] = qrel.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        """Return the queries that have at least one qrel in ``split``, in dataset order."""
        qids = {qrel.query_id for qrel in self.split2qrels[split]}
        return [query for query in self.queries if query.query_id in qids]
nlp4web-codebase-main/nlp4web_codebase/ir/data_loaders/dm.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+
5
@dataclass
class Document:
    """A corpus entry: a collection id and its text."""

    collection_id: str
    text: str
9
+
10
+
11
@dataclass
class Query:
    """A query: its id and its text."""

    query_id: str
    text: str
15
+
16
+
17
@dataclass
class QRel:
    """A relevance judgement linking a query to a document."""

    query_id: str
    collection_id: str
    relevance: int
    # Optional answer string attached to the judgement.
    answer: Optional[str] = None
nlp4web-codebase-main/nlp4web_codebase/ir/data_loaders/sciq.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ from nlp4web_codebase.ir.data_loaders import IRDataset, Split
3
+ from nlp4web_codebase.ir.data_loaders.dm import Document, Query, QRel
4
+ from datasets import load_dataset
5
+ import joblib
6
+
7
+
8
@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    """Load the ``allenai/sciq`` dataset and convert it into an ``IRDataset``.

    Every kept example yields one document (its ``support`` text), one query
    (its ``question``) and one qrel linking them with relevance 1. The id
    ``"<split>-<row index>"`` is shared by all three. Examples with an empty
    support, or whose support or question was already seen, are skipped.

    Args:
        verbose: Print the raw size of each split.
    """
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Each duplicated record is the same to each other:
    # NOTE(review): `+` on DataFrames is an element-wise, index-aligned addition,
    # not a row-wise concatenation, so this check presumably does not look at
    # duplicates across splits as the comment suggests. Confirm the intent
    # (pd.concat?) before relying on it.
    df = train.to_pandas() + validation.to_pandas() + test.to_pandas()
    for question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == len(group)
        assert len(set(group["correct_answer"].tolist())) == len(group)

    # Build:
    corpus: List[Document] = []
    queries: List[Query] = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id: Dict[str, str] = {}
    support2id: Dict[str, str] = {}
    for split, rows in data.items():
        # Use the enum value explicitly, so the ids read "train-0" no matter how
        # the interpreter formats str-mixin enum members inside f-strings.
        split_name = split.value
        if verbose:
            print(f"|raw_{split_name}|", len(rows))
        split_qrels: List[QRel] = []
        split2qrels[split] = split_qrels
        for i, row in enumerate(rows):
            example_id = f"{split_name}-{i}"
            support: str = row["support"]
            if not support.strip():
                continue
            question = row["question"]
            # NOTE(review): the original repeated the empty-support check at this
            # point. An empty-question check may have been intended — confirm.
            if support in support2id:
                continue
            # The support is registered before the question check, so a
            # duplicate question still marks its support as seen.
            support2id[support] = example_id
            if question in question2id:
                continue
            question2id[question] = example_id
            corpus.append(Document(collection_id=example_id, text=support))
            queries.append(Query(query_id=example_id, text=question))
            split_qrels.append(
                QRel(
                    query_id=example_id,
                    collection_id=example_id,
                    relevance=1,
                    answer=row["correct_answer"],
                )
            )

    # Assembly and return:
    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ # python -m nlp4web_codebase.ir.data_loaders.sciq
65
+ import ujson
66
+ import time
67
+
68
+ start = time.time()
69
+ dataset = load_sciq(verbose=True)
70
+ print(f"Loading costs: {time.time() - start}s")
71
+ print(ujson.dumps(dataset.get_stats(), indent=4))
72
+ # ________________________________________________________________________________
73
+ # [Memory] Calling __main__--home-kwang-research-nlp4web-ir-exercise-nlp4web-nlp4web-ir-data_loaders-sciq.load_sciq...
74
+ # load_sciq(verbose=True)
75
+ # |raw_train| 11679
76
+ # |raw_dev| 1000
77
+ # |raw_test| 1000
78
+ # ________________________________________________________load_sciq - 7.3s, 0.1min
79
+ # Loading costs: 7.260092735290527s
80
+ # {
81
+ # "|corpus|": 12160,
82
+ # "|queries|": 12160,
83
+ # "|qrels-train|": 10409,
84
+ # "|qrels-dev|": 875,
85
+ # "|qrels-test|": 876
86
+ # }
nlp4web-codebase-main/nlp4web_codebase/ir/models/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, Type
3
+
4
+
5
class BaseRetriever(ABC):
    """Abstract interface shared by all retrievers."""

    @property
    @abstractmethod
    def index_class(self) -> Type[Any]:
        """Class of the index this retriever works on."""

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return per-term weights of ``query`` for document ``cid``; optional to implement."""
        raise NotImplementedError

    @abstractmethod
    def score(self, query: str, cid: str) -> float:
        """Return the relevance score of document ``cid`` for ``query``."""

    @abstractmethod
    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return up to ``topk`` documents for ``query`` as ``{collection_id: score}``."""
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ git+https://github.com/kwang2049/nlp4web-codebase.git
3
+ pytrec_eval
4
+ tqdm
5
+ nltk
6
+ scipy
7
+ numpy
8
+ datasets
9
+ joblib