from collections import Counter, defaultdict
from typing import List, Mapping, Optional, Tuple, Union

import numpy as np

try:
    from transformers.utils import logging

    logger = logging.get_logger(__name__)
except ImportError:
    # transformers is only used for its logger; fall back to the standard
    # library so the retriever stays usable when it is not installed.
    import logging

    logger = logging.getLogger(__name__)


class BM25Retriever:
    """In-memory BM25-style lexical retriever.

    Documents and queries are either whitespace-separated strings or lists of
    token ids. The index keeps, per document, its term frequencies and its
    length, and, per token, its document frequency and an inverted list of
    the offsets of the documents that contain it.
    """

    def __init__(self, k1: float = 0.9, b: float = 0.4) -> None:
        """Create an empty index.

        Args:
            k1: Term-frequency saturation parameter used by default in search.
            b: Length-normalization parameter used by default in search.
        """
        self.name = "bm25"
        self.k1 = k1
        self.b = b
        self.remove_all()

    @property
    def num_keys(self):
        """Number of documents currently indexed."""
        return self.N

    def add(self, docs: List[Union[str, List[int]]], stop_tokens: Optional[set] = None):
        """Build in-memory BM25 index.

        Args:
            docs: Documents to append to the index. A string is split on
                whitespace; a list is used as a token sequence as-is.
            stop_tokens: Tokens to leave out of the index. They still count
                towards the stored document length. Defaults to no stop
                tokens.
        """
        # An immutable empty default avoids sharing one mutable object
        # between calls.
        excluded = stop_tokens if stop_tokens is not None else frozenset()

        for doc in docs:
            if isinstance(doc, str):
                doc = doc.split()

            tf = Counter(token for token in doc if token not in excluded)
            self.tfs.append(dict(tf))

            for token in tf:
                self.dfs[token] += 1
                # store the doc offset in the inverted lists of the corresponding token
                self.inverted_lists[token].append(self.N)

            self.N += 1
            self.doc_lengths.append(len(doc))

    def remove_all(self):
        """Remove all keys from the index."""
        self.dfs = defaultdict(float)
        self.tfs = []
        self.inverted_lists = defaultdict(list)
        self.doc_lengths = []
        self.N = 0

    def search(
        self,
        queries: Union[str, List[int], List[str], List[List[int]]],
        hits: int = 100,
        k1: Optional[float] = None,
        b: Optional[float] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Search over the BM25 index.

        Args:
            queries: A single query (string or list of token ids) or a list
                of such queries. An empty list is treated as zero queries.
            hits: Maximum number of results per query; capped at the number
                of indexed documents.
            k1: Overrides the instance-level ``k1`` when given.
            b: Overrides the instance-level ``b`` when given.

        Returns:
            ``(scores, indices)``: a float32 and an int64 array, both of
            shape ``(num_queries, hits)``, sorted by descending score per
            row. Slots whose score is zero (no query token matched) are
            filled with score ``-inf`` and index ``-1``.
        """
        if k1 is None:
            k1 = self.k1
        if b is None:
            b = self.b

        if isinstance(queries, str):
            queries = [queries]
        elif isinstance(queries, list) and queries and isinstance(queries[0], int):
            # A flat list of token ids is one query, not a batch.
            queries = [queries]

        hits = max(0, min(self.N, hits))
        all_scores = np.zeros((len(queries), hits), dtype=np.float32)
        all_indices = np.zeros((len(queries), hits), dtype=np.int64)
        if hits == 0:
            # Empty index (or hits <= 0): nothing to rank, and argpartition
            # would reject an empty score vector.
            return all_scores, all_indices

        doc_lengths = np.array(self.doc_lengths)

        for i, query in enumerate(queries):
            if isinstance(query, str):
                # Same whitespace tokenization as `add`.
                # TODO: stem
                query = query.split()

            # Fresh accumulator per query so scores never leak from one
            # query of the batch into the next.
            query_scores = np.zeros(self.N, dtype=np.float32)

            for token in query:
                candidates = self.inverted_lists.get(token)
                if not candidates:
                    continue

                tfs = np.array(
                    [self.tfs[candidate][token] for candidate in candidates],
                    dtype=np.float32,
                )
                df = self.dfs[token]
                idf = np.log((self.N - df + 0.5) / (df + 0.5) + 1)
                # NOTE(review): the length term uses the raw document length,
                # whereas textbook BM25 uses length / average length. Kept
                # as-is to preserve existing scores; confirm whether this is
                # intended.
                candidate_scores = idf * (k1 + 1) * tfs / (
                    tfs + k1 * (1 - b + b * doc_lengths[candidates])
                )
                query_scores[candidates] += candidate_scores

            # Select the top `hits` documents, then order them by score.
            indice = np.argpartition(-query_scores, hits - 1)[:hits]
            score = query_scores[indice]
            sorted_idx = np.argsort(score)[::-1]
            indice = indice[sorted_idx]
            score = score[sorted_idx]

            # Documents that matched no query token are not real hits.
            invalid_pos = score == 0
            indice[invalid_pos] = -1
            score[invalid_pos] = -float('inf')

            all_scores[i] = score
            all_indices[i] = indice

        return all_scores, all_indices