File size: 2,504 Bytes
e8051be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""

Reranking Module for Advanced RAG

Handles result reranking using cross-encoder models.

"""

from typing import List, Dict
from sentence_transformers import CrossEncoder
from config.config import ENABLE_RERANKING, RERANKER_MODEL, RERANK_TOP_K


class RerankingManager:
    """Rerank search results with a cross-encoder model.

    The model named by ``RERANKER_MODEL`` is loaded at construction time when
    ``ENABLE_RERANKING`` is truthy; otherwise ``reranker_model`` stays ``None``
    and :meth:`rerank_results` returns its input unchanged.

    The class attributes below hold the tuning constants. They can be
    overridden on a subclass or on an instance without touching the code.
    """

    # Weights of the blended ``final_score``: retrieval score vs. model score.
    ORIGINAL_SCORE_WEIGHT = 0.3
    RERANK_SCORE_WEIGHT = 0.7
    # Maximum number of characters of each document handed to the model.
    MAX_DOC_CHARS = 512

    def __init__(self):
        """Initialize the manager, loading the model if reranking is enabled."""
        self.reranker_model = None
        if ENABLE_RERANKING:
            self._init_reranker_model()
        print("✅ Reranking Manager initialized")

    def _init_reranker_model(self):
        """Load the cross-encoder named by ``RERANKER_MODEL``."""
        print(f"🔄 Loading reranker model: {RERANKER_MODEL}")
        self.reranker_model = CrossEncoder(RERANKER_MODEL)
        print("✅ Reranker model loaded successfully")

    @staticmethod
    def _document_text(result: Dict, max_chars: int = 512) -> str:
        """Return the document text of one search result, truncated.

        Args:
            result: A search result; the text is read from
                ``result['payload']['text']``.
            max_chars: Maximum number of characters to keep.

        Returns:
            The truncated text, or ``''`` when the payload or the text is
            missing or ``None``. Non-string text is converted with ``str``.
        """
        payload = result.get('payload')
        # A missing/None payload must not abort reranking of the whole batch.
        text = payload.get('text') if hasattr(payload, 'get') else None
        if text is None:
            return ''
        if not isinstance(text, str):
            text = str(text)
        return text[:max_chars]

    @staticmethod
    def _blend_scores(original_score, rerank_score: float,
                      original_weight: float = 0.3,
                      rerank_weight: float = 0.7) -> float:
        """Return the weighted sum of the retrieval and reranker scores.

        Args:
            original_score: Score from the initial search; ``None`` counts
                as 0.
            rerank_score: Score produced by the cross-encoder.
            original_weight: Weight applied to ``original_score``.
            rerank_weight: Weight applied to ``rerank_score``.
        """
        base_score = 0.0 if original_score is None else original_score
        return original_weight * base_score + rerank_weight * rerank_score

    async def rerank_results(self, query: str, search_results: List[Dict]) -> List[Dict]:
        """Rerank search results using the cross-encoder.

        Each result dict is updated in place with ``'rerank_score'`` (the
        model output) and ``'final_score'`` (the weighted blend with the
        result's ``'score'``).

        Args:
            query: The user query each document is scored against.
            search_results: Result dicts; each is expected to carry a
                ``'payload'`` mapping with a ``'text'`` entry and optionally
                a ``'score'``.

        Returns:
            The results sorted by ``'final_score'`` in descending order and
            cut to ``RERANK_TOP_K``. The input is returned unchanged (and
            not truncated) when reranking is disabled, no model is loaded,
            or there are fewer than two results. If reranking raises, the
            error is printed and the first ``RERANK_TOP_K`` results are
            returned in their original order.
        """
        if not ENABLE_RERANKING or self.reranker_model is None:
            return search_results
        if not search_results or len(search_results) <= 1:
            return search_results

        try:
            # Prepare (query, document) pairs for the cross-encoder.
            query_doc_pairs = [
                [query, self._document_text(result, self.MAX_DOC_CHARS)]
                for result in search_results
            ]

            # NOTE(review): predict() is called synchronously, so this
            # coroutine does not yield to the event loop while the model
            # runs. Offloading to a worker thread would first require
            # confirming the model is safe for concurrent use.
            raw_scores = self.reranker_model.predict(query_doc_pairs)

            # Convert every score before touching the results so a failure
            # here cannot leave only some of the dicts annotated.
            rerank_scores = [float(score) for score in raw_scores]
            if len(rerank_scores) != len(search_results):
                raise ValueError(
                    f"expected {len(search_results)} scores, "
                    f"got {len(rerank_scores)}"
                )

            # Combine with the original retrieval scores.
            for result, rerank_score in zip(search_results, rerank_scores):
                result['rerank_score'] = rerank_score
                result['final_score'] = self._blend_scores(
                    result.get('score', 0),
                    rerank_score,
                    self.ORIGINAL_SCORE_WEIGHT,
                    self.RERANK_SCORE_WEIGHT,
                )

            # Sort by final score, best first.
            reranked_results = sorted(
                search_results,
                key=lambda item: item['final_score'],
                reverse=True,
            )

            print(f"🎯 Reranked {len(search_results)} results")
            return reranked_results[:RERANK_TOP_K]

        except Exception as e:
            # Deliberate best-effort: reranking must never break retrieval,
            # so fall back to the original ordering.
            print(f"⚠️ Reranking failed: {e}")
            return search_results[:RERANK_TOP_K]