# Source: rag-bajaj/RAG/rag_modules/reranking_manager.py
# Uploaded by quantumbit ("Upload 39 files", commit e8051be, verified)
"""
Reranking Module for Advanced RAG
Handles result reranking using cross-encoder models.
"""
from typing import List, Dict
from sentence_transformers import CrossEncoder
from config.config import ENABLE_RERANKING, RERANKER_MODEL, RERANK_TOP_K
class RerankingManager:
    """Manage reranking of search results with a cross-encoder model.

    When ``ENABLE_RERANKING`` is truthy the model named by ``RERANKER_MODEL``
    is loaded at construction time. Otherwise no model is loaded and
    :meth:`rerank_results` returns its input unchanged.
    """

    # Maximum number of characters of each document handed to the model.
    MAX_DOC_CHARS = 512
    # Weights used to blend the retrieval score with the cross-encoder score.
    ORIGINAL_SCORE_WEIGHT = 0.3
    RERANK_SCORE_WEIGHT = 0.7

    def __init__(self):
        """Initialize the manager, loading the reranker model if enabled."""
        self.reranker_model = None

        if ENABLE_RERANKING:
            self._init_reranker_model()

        print("✅ Reranking Manager initialized")

    def _init_reranker_model(self):
        """Load the cross-encoder named by ``RERANKER_MODEL``.

        Any exception raised by ``CrossEncoder`` propagates to the caller.
        """
        print(f"🔄 Loading reranker model: {RERANKER_MODEL}")
        self.reranker_model = CrossEncoder(RERANKER_MODEL)
        print("✅ Reranker model loaded successfully")

    @classmethod
    def _document_text(cls, result: Dict) -> str:
        """Return the (truncated) text of one search result, or ``''``.

        Tolerates a missing or ``None`` payload and a missing, ``None`` or
        non-string ``text`` so that one malformed hit cannot abort reranking
        for the whole batch.
        """
        payload = result.get('payload') or {}
        text = payload.get('text')
        if text is None:
            return ''
        if not isinstance(text, str):
            text = str(text)
        return text[:cls.MAX_DOC_CHARS]

    async def rerank_results(self, query: str, search_results: List[Dict]) -> List[Dict]:
        """Rerank search results using the cross-encoder.

        Args:
            query: The user query each document is scored against.
            search_results: Result dicts. Each may carry a ``'score'`` and a
                ``'payload'`` dict with a ``'text'`` entry; missing or ``None``
                values are treated as ``0`` and ``''`` respectively.

        Returns:
            If reranking is disabled, no model is loaded, or there is at most
            one result: ``search_results`` unchanged (not truncated).
            Otherwise the results sorted by descending ``'final_score'`` and
            truncated to ``RERANK_TOP_K``. Each input dict is updated in place
            with ``'rerank_score'`` and ``'final_score'``.
            If scoring fails, the first ``RERANK_TOP_K`` results in their
            original order, with no score keys added.
        """
        if not ENABLE_RERANKING or self.reranker_model is None or len(search_results) <= 1:
            return search_results

        try:
            # Build [query, document] pairs for the cross-encoder.
            query_doc_pairs = [
                [query, self._document_text(result)] for result in search_results
            ]

            # NOTE(review): predict() is called synchronously here, so the
            # coroutine does not yield while the model runs. Consider
            # offloading to an executor if this blocks other tasks.
            raw_scores = self.reranker_model.predict(query_doc_pairs)
            rerank_scores = [float(score) for score in raw_scores]
            if len(rerank_scores) < len(search_results):
                raise ValueError(
                    f"expected {len(search_results)} scores, got {len(rerank_scores)}"
                )

            # Compute every blended score before touching the inputs, so a
            # failure leaves search_results unmodified for the fallback path.
            # NOTE(review): cross-encoder outputs may be on a different scale
            # than the retrieval scores -- confirm the weights are meaningful.
            blended = []
            for result, rerank_score in zip(search_results, rerank_scores):
                original_score = result.get('score')
                if original_score is None:
                    original_score = 0
                final_score = (
                    self.ORIGINAL_SCORE_WEIGHT * original_score
                    + self.RERANK_SCORE_WEIGHT * rerank_score
                )
                blended.append((rerank_score, final_score))

            for result, (rerank_score, final_score) in zip(search_results, blended):
                result['rerank_score'] = rerank_score
                result['final_score'] = final_score

            # sorted() is stable, so ties keep their original relative order.
            reranked_results = sorted(
                search_results,
                key=lambda hit: hit['final_score'],
                reverse=True,
            )

            print(f"🎯 Reranked {len(search_results)} results")
            return reranked_results[:RERANK_TOP_K]

        except Exception as e:
            # Best-effort: reranking is an optimisation, so fall back to the
            # retrieval order instead of failing the request.
            print(f"⚠️ Reranking failed: {e}")
            return search_results[:RERANK_TOP_K]