Spaces:
Sleeping
Sleeping
File size: 2,504 Bytes
e8051be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
"""
Reranking Module for Advanced RAG
Handles result reranking using cross-encoder models.
"""
from typing import List, Dict
from sentence_transformers import CrossEncoder
from config.config import ENABLE_RERANKING, RERANKER_MODEL, RERANK_TOP_K
class RerankingManager:
    """Rerank search results with a cross-encoder model.

    When ``ENABLE_RERANKING`` is truthy the cross-encoder named by
    ``RERANKER_MODEL`` is loaded at construction time. Otherwise no model is
    loaded and ``rerank_results`` returns its input unchanged.
    """

    # Weights used to blend the incoming retrieval score with the
    # cross-encoder score into ``final_score``.
    ORIGINAL_SCORE_WEIGHT = 0.3
    RERANK_SCORE_WEIGHT = 0.7
    # Only this many leading characters of each document are sent to the model.
    MAX_DOC_CHARS = 512

    def __init__(self):
        """Initialize the manager, loading the model if reranking is enabled."""
        self.reranker_model = None
        if ENABLE_RERANKING:
            self._init_reranker_model()
        print("✅ Reranking Manager initialized")

    def _init_reranker_model(self):
        """Load the cross-encoder named by ``RERANKER_MODEL``."""
        print(f"🔄 Loading reranker model: {RERANKER_MODEL}")
        self.reranker_model = CrossEncoder(RERANKER_MODEL)
        print("✅ Reranker model loaded successfully")

    async def rerank_results(self, query: str, search_results: List[Dict]) -> List[Dict]:
        """Rerank search results using the cross-encoder.

        Each result is scored against ``query``; the model score is stored
        under ``'rerank_score'`` and a weighted blend with the result's
        ``'score'`` is stored under ``'final_score'``. The result dicts are
        updated in place.

        Args:
            query: The user query the results were retrieved for.
            search_results: Result dicts. The document text is read from
                ``result['payload']['text']`` and the retrieval score from
                ``result['score']``; missing or ``None`` values are treated
                as an empty string and ``0`` respectively.

        Returns:
            The results sorted by ``'final_score'`` (highest first) and cut
            to ``RERANK_TOP_K``. If reranking is disabled, no model is
            loaded, or there are fewer than two results, the input list is
            returned as-is (not truncated). If reranking fails, the first
            ``RERANK_TOP_K`` results are returned in their original order.
        """
        if (
            self.reranker_model is None
            or len(search_results) <= 1
            or not ENABLE_RERANKING
        ):
            return search_results

        try:
            # Build one [query, document] pair per result.
            query_doc_pairs = []
            for result in search_results:
                payload = result.get('payload') or {}
                doc_text = (payload.get('text') or '')[:self.MAX_DOC_CHARS]
                query_doc_pairs.append([query, doc_text])

            # NOTE(review): predict() is called without await, so it runs on
            # the event-loop thread for its whole duration; consider
            # offloading it if request latency matters.
            raw_scores = self.reranker_model.predict(query_doc_pairs)

            # Compute every score before touching the result dicts, so that a
            # failure here leaves the caller's data unmodified.
            rerank_scores = [float(score) for score in raw_scores]
            if len(rerank_scores) != len(search_results):
                raise ValueError(
                    f"expected {len(search_results)} scores, "
                    f"got {len(rerank_scores)}"
                )
            final_scores = [
                self.ORIGINAL_SCORE_WEIGHT * (result.get('score') or 0)
                + self.RERANK_SCORE_WEIGHT * rerank_score
                for result, rerank_score in zip(search_results, rerank_scores)
            ]

            for result, rerank_score, final_score in zip(
                search_results, rerank_scores, final_scores
            ):
                result['rerank_score'] = rerank_score
                result['final_score'] = final_score

            # sorted() is stable, so ties keep their original relative order.
            reranked_results = sorted(
                search_results,
                key=lambda x: x['final_score'],
                reverse=True,
            )
            print(f"🎯 Reranked {len(search_results)} results")
            return reranked_results[:RERANK_TOP_K]
        except Exception as e:
            # Best effort: a reranking failure must not break the search
            # pipeline, so fall back to the original ordering.
            print(f"⚠️ Reranking failed: {e}")
            return search_results[:RERANK_TOP_K]
|