|
import heapq
from abc import abstractmethod
from typing import Type, Dict

from nlp4web_codebase.ir.models import BaseRetriever

from BM25Index import BM25Index
from sciq_load import PostingList, InvertedIndex, Counting, run_counting, simple_tokenize
|
|
|
|
|
class BaseInvertedIndexRetriever(BaseRetriever):
    """Retriever that scores documents by summing posting-list term weights.

    Subclasses supply the concrete index type through ``index_class``; the
    index object is obtained in ``__init__`` via ``index_class.from_saved``.
    """

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        """Return the index class whose ``from_saved`` builds ``self.index``."""

    def __init__(self, index_dir: str) -> None:
        """Create the retriever from the index stored at ``index_dir``.

        Args:
            index_dir: Passed unchanged to ``index_class.from_saved``.
        """
        self.index = self.index_class.from_saved(index_dir)

    def _iter_postings(self, toks):
        """Yield ``(token, posting_list)`` for every token present in the vocab.

        Tokens missing from ``self.index.vocab`` are skipped. Tokens are
        yielded in input order, once per occurrence in ``toks``.
        """
        vocab = self.index.vocab
        posting_lists = self.index.posting_lists
        for tok in toks:
            if tok in vocab:
                yield tok, posting_lists[vocab[tok]]

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        """Return the weight of each query token within one document.

        Args:
            query: Raw query text; tokenized with ``self.index.tokenize``.
            cid: Collection id of the document, resolved to an internal docid
                through ``self.index.cid2docid``.

        Returns:
            Mapping from query token to its weight in the document. Tokens
            that are out of vocabulary, or whose posting list does not
            contain the document, are omitted.

        Raises:
            KeyError: If ``cid`` is not a key of ``self.index.cid2docid``
                (assuming it is a plain mapping).
        """
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights: Dict[str, float] = {}
        # The result is keyed by token, so a repeated query token could only
        # rescan the same posting list and store the same weight again.
        # dict.fromkeys drops the repeats while keeping first-seen order.
        for tok, posting_list in self._iter_postings(dict.fromkeys(toks)):
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        """Return the sum of the per-token weights of ``query`` in document ``cid``.

        NOTE: ``get_term_weights`` keys by token, so a token repeated in the
        query contributes once here, whereas ``retrieve`` adds its weight
        once per occurrence.
        """
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        """Return the ``topk`` highest-scoring documents for ``query``.

        Args:
            query: Raw query text; tokenized with ``self.index.tokenize``.
            topk: Maximum number of documents to return. Values ``<= 0``
                yield an empty result.

        Returns:
            Mapping from collection id to accumulated score, ordered from
            highest to lowest score.
        """
        # A plain ``[:topk]`` slice with a negative topk would silently drop
        # documents from the tail of the ranking; treat it as "nothing asked".
        if topk <= 0:
            return {}

        docid2score: Dict[int, float] = {}
        for _, posting_list in self._iter_postings(self.index.tokenize(query)):
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score[docid] = docid2score.get(docid, 0) + tweight

        # nlargest is documented as equivalent to
        # sorted(..., key=..., reverse=True)[:topk], without sorting every
        # candidate when only a few are wanted.
        ranked = heapq.nlargest(
            topk, docid2score.items(), key=lambda pair: pair[1]
        )
        collection_ids = self.index.collection_ids
        return {collection_ids[docid]: score for docid, score in ranked}
|
|
|
|
|
class BM25Retriever(BaseInvertedIndexRetriever):
    """Inverted-index retriever whose index type is ``BM25Index``."""

    @property
    def index_class(self) -> Type[BM25Index]:
        """Return ``BM25Index``, the class used to obtain ``self.index``."""
        return BM25Index
|
|
|
# Demo: build a BM25 retriever over the index directory "output/bm25_index"
# and run one sample query. The mapping returned by retrieve() is not stored
# or printed here.
bm25_retriever = BM25Retriever("output/bm25_index")
bm25_retriever.retrieve(
    query="What type of diseases occur when the immune system attacks normal body cells?"
)