Spaces:
Sleeping
Sleeping
""" | |
Reranking Module for Advanced RAG
Handles result reranking using cross-encoder models.
""" | |
from typing import Dict, List, Optional

from sentence_transformers import CrossEncoder

from config.config import ENABLE_RERANKING, RERANKER_MODEL, RERANK_TOP_K
class RerankingManager:
    """Rerank search hits with a cross-encoder model.

    Each hit is scored against the query by the cross-encoder, that score is
    blended with the hit's original retrieval score, and the best ``top_k``
    hits are returned.

    Attributes:
        ORIGINAL_WEIGHT: Weight of the original retrieval score in the blend.
        RERANK_WEIGHT: Weight of the cross-encoder score in the blend.
        MAX_DOC_CHARS: Number of leading characters of each document that is
            handed to the cross-encoder.
        enabled: Whether reranking is active for this instance.
        model_name: Model identifier passed to ``CrossEncoder``.
        top_k: Maximum number of hits returned after reranking.
        reranker_model: The loaded cross-encoder, or ``None`` when reranking
            is disabled.
    """

    ORIGINAL_WEIGHT = 0.3
    RERANK_WEIGHT = 0.7
    MAX_DOC_CHARS = 512

    def __init__(
        self,
        enabled: Optional[bool] = None,
        model_name: Optional[str] = None,
        top_k: Optional[int] = None,
    ):
        """Initialize the reranking manager.

        Every argument is optional; ``None`` falls back to the corresponding
        value from the project configuration, so ``RerankingManager()``
        behaves exactly as before.

        Args:
            enabled: Override for ``ENABLE_RERANKING``.
            model_name: Override for ``RERANKER_MODEL``.
            top_k: Override for ``RERANK_TOP_K``.
        """
        self.enabled = ENABLE_RERANKING if enabled is None else enabled
        self.model_name = RERANKER_MODEL if model_name is None else model_name
        self.top_k = RERANK_TOP_K if top_k is None else top_k
        self.reranker_model = None
        if self.enabled:
            self._init_reranker_model()
        print("β Reranking Manager initialized")

    def _init_reranker_model(self):
        """Load the cross-encoder named by ``self.model_name``."""
        print(f"π Loading reranker model: {self.model_name}")
        self.reranker_model = CrossEncoder(self.model_name)
        print("β Reranker model loaded successfully")

    def _document_text(self, result: Dict) -> str:
        """Return the truncated payload text of one hit, or ``''`` if absent."""
        # A hit without a payload (or with a None text) is scored as an empty
        # document instead of aborting reranking for the whole batch.
        payload = result.get('payload') or {}
        return (payload.get('text') or '')[:self.MAX_DOC_CHARS]

    async def rerank_results(self, query: str, search_results: List[Dict]) -> List[Dict]:
        """Rerank search results using the cross-encoder.

        On success each dict in ``search_results`` gains ``'rerank_score'``
        and ``'final_score'`` keys (the dicts are updated in place) and the
        ``top_k`` best hits by ``'final_score'`` are returned.

        Args:
            query: The user query each document is scored against.
            search_results: Hits as dicts; ``'score'`` and
                ``'payload'['text']`` are read when present.

        Returns:
            ``search_results`` unchanged (and untruncated) when reranking is
            disabled, no model is loaded, or there is at most one hit.
            Otherwise the reranked hits truncated to ``top_k``. If scoring
            fails, the hits in their original order truncated to ``top_k``,
            without any score keys added.
        """
        if not self.enabled or self.reranker_model is None or len(search_results) <= 1:
            return search_results

        try:
            query_doc_pairs = [
                [query, self._document_text(result)] for result in search_results
            ]

            # NOTE(review): predict() is a plain synchronous call, so it runs
            # on the caller's event-loop thread for its whole duration.
            raw_scores = self.reranker_model.predict(query_doc_pairs)
            rerank_scores = [float(score) for score in raw_scores]
            if len(rerank_scores) != len(search_results):
                raise ValueError(
                    f"expected {len(search_results)} scores, got {len(rerank_scores)}"
                )

            # Compute every blended score before touching the hits, so a
            # failure here cannot leave the caller's dicts half-annotated.
            final_scores = [
                self.ORIGINAL_WEIGHT * (result.get('score') or 0)
                + self.RERANK_WEIGHT * rerank_score
                for result, rerank_score in zip(search_results, rerank_scores)
            ]

            for result, rerank_score, final_score in zip(
                search_results, rerank_scores, final_scores
            ):
                result['rerank_score'] = rerank_score
                result['final_score'] = final_score

            reranked_results = sorted(
                search_results,
                key=lambda hit: hit['final_score'],
                reverse=True,
            )
            print(f"π― Reranked {len(search_results)} results")
            return reranked_results[:self.top_k]
        except Exception as e:
            # Deliberate best-effort boundary: a reranker failure must not
            # break retrieval, so fall back to the original ordering.
            print(f"β οΈ Reranking failed: {e}")
            return search_results[:self.top_k]