"""Hybrid retrieval pipeline: BM25 -> dense bi-encoder retrieval -> cross-encoder reranking."""

import ast
import re
import string

import numpy as np
import torch
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer, util

# Bi-encoder used for dense (embedding-based) retrieval.
DENSE_RETRIEVER_MODEL_NAME = "all-MiniLM-L6-v2"
# Cross-encoder used to rerank (query, document) pairs.
CROSS_ENCODER_MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
# Downstream LLM identifier; not used in this module, kept as shared config.
LLM_CORE_MODEL_NAME = "groq/llama3-8b-8192"


def clean_text(text):
    """Normalize text for BM25: lowercase, strip non-alphanumerics, collapse whitespace.

    Args:
        text (str): Raw input text.

    Returns:
        str: Lowercased text containing only [a-z0-9] and single spaces.
    """
    text = text.lower()
    # One regex pass removes punctuation and every other non-alphanumeric
    # character; a prior str.translate(string.punctuation) step was a strict
    # subset of this and has been dropped (identical output).
    text = re.sub(r'[^a-z0-9\s]', '', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HybridRetrieverReranker:
    """Three-stage retriever: BM25 candidate generation, dense re-scoring, cross-encoder rerank.

    Expects a pandas-like dataset with at least a 'cleaned_text' column; the
    dense and rerank stages additionally read 'chunk_embedding' and 'raw_text'
    columns respectively.
    """

    def __init__(self, dataset, dense_model_name=DENSE_RETRIEVER_MODEL_NAME,
                 cross_encoder_model=CROSS_ENCODER_MODEL_NAME):
        """Build the BM25 index and load the dense and cross-encoder models.

        Args:
            dataset: DataFrame with a 'cleaned_text' column (one row per chunk).
            dense_model_name (str): SentenceTransformer model for dense retrieval.
            cross_encoder_model (str): CrossEncoder model for reranking.

        Raises:
            ValueError: If 'cleaned_text' is missing from the dataset.
        """
        if 'cleaned_text' not in dataset.columns:
            raise ValueError("Dataset must contain a 'cleaned_text' column.")
        self.dataset = dataset
        self.bm25_corpus = dataset['cleaned_text'].tolist()
        self.tokenized_corpus = [chunk.split() for chunk in self.bm25_corpus]
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        self.dense_model = SentenceTransformer(dense_model_name)
        self.cross_encoder = CrossEncoder(cross_encoder_model)

    def bm25_retrieve(self, query, top_k=70):
        """Retrieve top K documents using BM25.

        Args:
            query (str): Query text.
            top_k (int): Number of top BM25 documents to retrieve.

        Returns:
            list of dict: Top K BM25 results (dataset rows as records).
        """
        cleaned_query = clean_text(query)
        query_tokens = cleaned_query.split()
        bm25_scores = self.bm25.get_scores(query_tokens)
        top_k_indices = np.argsort(bm25_scores)[::-1][:top_k]
        return self.dataset.iloc[top_k_indices].to_dict(orient='records')

    @staticmethod
    def _parse_embedding(value):
        """Convert a stored chunk embedding to a detached torch.Tensor.

        Embeddings appear to be persisted as strings like "tensor([0.1, ...])"
        (TODO confirm against the dataset-building code). The numeric list is
        extracted and parsed with ast.literal_eval instead of eval(), which
        would execute arbitrary code embedded in the dataset.

        Args:
            value: A torch.Tensor, a "tensor([...])"-style string, or any
                sequence torch.tensor() accepts.

        Returns:
            torch.Tensor: The embedding as a detached tensor.

        Raises:
            ValueError: If a string value contains no parseable list literal.
        """
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        if isinstance(value, str):
            match = re.search(r'\[.*\]', value, re.DOTALL)
            if match is None:
                raise ValueError(f"Unparseable embedding string: {value[:80]!r}")
            return torch.tensor(ast.literal_eval(match.group(0)))
        return torch.tensor(value)

    def dense_retrieve(self, query, candidates=None, top_n=35):
        """Retrieve top N documents by cosine similarity with the dense bi-encoder.

        Args:
            query (str): Query text.
            candidates (list of dict): Candidate documents (e.g. from BM25);
                defaults to the full dataset when None.
            top_n (int): Number of top dense results to retrieve.

        Returns:
            list of dict: Top N dense results (at most len(candidates)).
        """
        if candidates is None:
            candidates = self.dataset.to_dict(orient='records')
        if not candidates:
            return []
        query_embedding = self.dense_model.encode(query, convert_to_tensor=True)
        candidate_embeddings = torch.stack([
            self._parse_embedding(doc['chunk_embedding']) for doc in candidates
        ])
        similarities = util.pytorch_cos_sim(query_embedding, candidate_embeddings).squeeze(0)
        # Clamp k: torch.topk raises if k exceeds the number of candidates.
        k = min(top_n, len(candidates))
        top_n_indices = torch.topk(similarities, k).indices
        return [candidates[idx] for idx in top_n_indices]

    def rerank(self, query, candidates=None, top_n=3):
        """Rerank top documents using a CrossEncoder over (query, raw_text) pairs.

        Args:
            query (str): Query text.
            candidates (list of dict): Candidate documents (e.g. from the dense
                retriever); defaults to the full dataset when None.
            top_n (int): Number of top reranked results to return.

        Returns:
            list of dict: Top N reranked documents.
        """
        if candidates is None:
            candidates = self.dataset.to_dict(orient='records')
        if not candidates:
            return []
        query_document_pairs = [(query, doc['raw_text']) for doc in candidates]
        scores = self.cross_encoder.predict(query_document_pairs)
        top_n_indices = np.argsort(scores)[::-1][:top_n]
        return [candidates[idx] for idx in top_n_indices]

    def hybrid_retrieve(self, query, enable_bm25=True, enable_dense=True, enable_rerank=True,
                        top_k_bm25=60, top_n_dense=30, top_n_rerank=2):
        """Perform hybrid retrieval: BM25, then dense retrieval, then optional reranking.

        Any disabled stage is skipped and its input is passed through; if BM25
        is disabled the downstream stages fall back to the full dataset.

        Args:
            query (str): Query text.
            enable_bm25 (bool): Whether the BM25 stage should be enabled.
            enable_dense (bool): Whether dense retrieval should be enabled.
            enable_rerank (bool): Whether reranking should be enabled.
            top_k_bm25 (int): Number of top BM25 documents to retrieve.
            top_n_dense (int): Number of top dense results to retrieve.
            top_n_rerank (int): Number of top reranked documents to return.

        Returns:
            list of dict: Final top results after hybrid retrieval and reranking.
        """
        if enable_bm25:
            bm25_results = self.bm25_retrieve(query, top_k=top_k_bm25)
        else:
            bm25_results = None

        if enable_dense:
            dense_results = self.dense_retrieve(query, bm25_results, top_n=top_n_dense)
        else:
            dense_results = bm25_results

        if enable_rerank:
            final_results = self.rerank(query, dense_results, top_n=top_n_rerank)
        else:
            final_results = dense_results

        return final_results