from rank_bm25 import BM25Okapi
import numpy as np
import torch
import re
import string
from sentence_transformers import SentenceTransformer, util, CrossEncoder
# Model identifiers: the dense retriever and cross-encoder reranker are used below;
# the LLM identifier is not used in this module.
DENSE_RETRIEVER_MODEL_NAME = "all-MiniLM-L6-v2"
CROSS_ENCODER_MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
LLM_CORE_MODEL_NAME = "groq/llama3-8b-8192"
def clean_text(text):
    """Lowercase, strip punctuation and non-alphanumeric characters, and collapse whitespace."""
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
class HybridRetrieverReranker:
    def __init__(self, dataset, dense_model_name=DENSE_RETRIEVER_MODEL_NAME, cross_encoder_model=CROSS_ENCODER_MODEL_NAME):
        if 'cleaned_text' not in dataset.columns:
            raise ValueError("Dataset must contain a 'cleaned_text' column.")
        self.dataset = dataset
        # Build the sparse BM25 index over the pre-cleaned chunk texts.
        self.bm25_corpus = dataset['cleaned_text'].tolist()
        self.tokenized_corpus = [chunk.split() for chunk in self.bm25_corpus]
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        # Load the dense bi-encoder and the cross-encoder used for reranking.
        self.dense_model = SentenceTransformer(dense_model_name)
        self.cross_encoder = CrossEncoder(cross_encoder_model)
    def bm25_retrieve(self, query, top_k=70):
        """
        Retrieve top K documents using BM25.

        Args:
            query (str): Query text.
            top_k (int): Number of top BM25 documents to retrieve.

        Returns:
            list of dict: Top K BM25 results.
        """
        cleaned_query = clean_text(query)
        query_tokens = cleaned_query.split()
        bm25_scores = self.bm25.get_scores(query_tokens)
        top_k_indices = np.argsort(bm25_scores)[::-1][:top_k]
        return self.dataset.iloc[top_k_indices].to_dict(orient='records')
    def dense_retrieve(self, query, candidates=None, top_n=35):
        """
        Retrieve top N documents using the SentenceTransformer dense retriever.

        Args:
            query (str): Query text.
            candidates (list of dict): Candidate documents from BM25.
            top_n (int): Number of top dense results to retrieve.

        Returns:
            list of dict: Top N dense results.
        """
        if candidates is None:
            candidates = self.dataset.to_dict(orient='records')
        query_embedding = self.dense_model.encode(query, convert_to_tensor=True)
        # Chunk embeddings are stored as stringified tensors; rebuild them via eval.
        candidate_embeddings = torch.stack([
            eval(doc['chunk_embedding'].replace('tensor', 'torch.tensor')).clone().detach()
            for doc in candidates
        ])
        similarities = util.pytorch_cos_sim(query_embedding, candidate_embeddings).squeeze(0)
        top_n_indices = torch.topk(similarities, top_n).indices
        return [candidates[idx] for idx in top_n_indices]
    def rerank(self, query, candidates=None, top_n=3):
        """
        Rerank top documents using a CrossEncoder.

        Args:
            query (str): Query text.
            candidates (list of dict): Candidate documents from the dense retriever.
            top_n (int): Number of top reranked results to return.

        Returns:
            list of dict: Top N reranked documents.
        """
        if candidates is None:
            candidates = self.dataset.to_dict(orient='records')
        query_document_pairs = [(query, doc['raw_text']) for doc in candidates]
        scores = self.cross_encoder.predict(query_document_pairs)
        top_n_indices = np.argsort(scores)[::-1][:top_n]
        return [candidates[idx] for idx in top_n_indices]
    def hybrid_retrieve(self, query, enable_bm25=True, enable_dense=True, enable_rerank=True, top_k_bm25=60, top_n_dense=30, top_n_rerank=2):
        """
        Perform hybrid retrieval: BM25 followed by dense retrieval and optional reranking.

        Args:
            query (str): Query text.
            enable_bm25 (bool): Whether BM25 retrieval should be enabled.
            enable_dense (bool): Whether dense retrieval should be enabled.
            enable_rerank (bool): Whether reranking should be enabled.
            top_k_bm25 (int): Number of top BM25 documents to retrieve.
            top_n_dense (int): Number of top dense results to retrieve.
            top_n_rerank (int): Number of top reranked documents to return.

        Returns:
            list of dict: Final top results after hybrid retrieval and reranking.
        """
        if enable_bm25:
            bm25_results = self.bm25_retrieve(query, top_k=top_k_bm25)
        else:
            bm25_results = None
        if enable_dense:
            dense_results = self.dense_retrieve(query, bm25_results, top_n=top_n_dense)
        else:
            dense_results = bm25_results
        if enable_rerank:
            final_results = self.rerank(query, dense_results, top_n=top_n_rerank)
        else:
            final_results = dense_results
        return final_results
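

# A minimal usage sketch (not part of the original module). It assumes a pandas
# DataFrame with the 'cleaned_text' and 'raw_text' columns the class expects;
# the two rows below are illustrative. Dense retrieval is disabled here because
# it would additionally require a stringified 'chunk_embedding' column per chunk.
if __name__ == "__main__":
    import pandas as pd

    raw_chunks = [
        "BM25 is a sparse lexical ranking function.",
        "Cross-encoders rerank candidate passages against the query.",
    ]
    df = pd.DataFrame({
        'raw_text': raw_chunks,
        'cleaned_text': [clean_text(c) for c in raw_chunks],
    })

    retriever = HybridRetrieverReranker(df)
    # BM25 retrieval followed by cross-encoder reranking only.
    results = retriever.hybrid_retrieve(
        "how does reranking work",
        enable_dense=False,
        top_k_bm25=2,
        top_n_rerank=1,
    )
    print(results[0]['raw_text'])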