Spaces:
Runtime error
Runtime error
| # models/reranker.py | |
| from typing import Any, Dict, List | |
| import torch | |
| from FlagEmbedding import FlagReranker | |
| from core.exceptions import ModelLoadError | |
| from core.logger import setup_logger | |
# Module-level logger shared by everything in this file; created through the
# project's setup_logger helper under the name "reranker".
logger = setup_logger("reranker")
class TextReranker:
    """
    Using the BGE-Reranker model, the documents retrieved in the first search are
    reordered (Cross-Encoding) by comparing them with the query.
    """

    def __init__(self, model_name: str = "BAAI/bge-reranker-v2-m3", use_fp16: bool = False):
        """
        Load the reranker model and run a single warm-up scoring pass.

        :param model_name: Model identifier handed to FlagReranker
        :param use_fp16: Request half precision; only honoured when the detected device is CUDA
        :raises ModelLoadError: If the model cannot be loaded or the warm-up pass fails
        """
        self.model_name = model_name
        self.device = self._get_device()
        try:
            logger.info(f"⏳ Loading Reranker Model: {self.model_name} on {self.device}")
            self.reranker = FlagReranker(
                self.model_name,
                use_fp16=(use_fp16 and self.device.startswith("cuda")),
            )
            logger.info("✅ Reranker Model loaded successfully.")
            # The warm-up goes through rerank(), which reads self.reranker, so it
            # must only run once the model object has been assigned above.
            self._warmup()
        except Exception as e:
            logger.critical(f"❌ Failed to load Reranker Model: {e}", exc_info=True)
            raise ModelLoadError(f"Reranker initialization failed: {e}") from e

    def _get_device(self) -> str:
        """
        Pick the device name to report/use: "cuda" if available, else "mps", else "cpu".
        """
        if torch.cuda.is_available():
            return "cuda"
        # torch builds without the MPS backend do not expose torch.backends.mps,
        # so look it up defensively instead of assuming the attribute exists.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"

    def _warmup(self):
        """
        Score one dummy query/document pair so the first real request does not pay
        the first-call cost. Requires self.reranker to be loaded already.
        """
        logger.info("Warming up reranker model with a dummy input.")
        self.rerank(query="Hello world", documents=[{"text": "Hello world"}])

    def rerank(self, query: str, documents: List[Dict[str, Any]], text_key: str = "text") -> List[Dict[str, Any]]:
        """
        Takes a list of documents as input, recalculates their similarity to the query,
        and returns the results sorted by score.

        Each input dictionary is updated in place with a "rerank_score" float; the
        returned list is a new list holding those same dictionaries, highest score first.

        :param query: The original search query string
        :param documents: A list of dictionaries in the form [{'chunk_id': 1, 'text': '...'}, ...]
        :param text_key: The key name in the document dictionary containing the body text
        :return: The documents sorted by "rerank_score" in descending order ([] for empty input)
        :raises KeyError: If a document does not contain text_key
        :raises RuntimeError: If scoring fails or returns an unexpected number of scores
        """
        if not documents:
            return []

        # Generate pairs for Cross-Encoder input: [[query, doc1], [query, doc2], ...]
        sentence_pairs = [[query, doc[text_key]] for doc in documents]

        try:
            # 1. Batch score calculation
            scores = self.reranker.compute_score(sentence_pairs, normalize=True)

            # compute_score can return a bare scalar when there is only one input
            # pair. Wrap anything unsized (Python float, NumPy scalar, ...) in a list.
            if not hasattr(scores, "__len__"):
                scores = [scores]

            if len(scores) != len(documents):
                raise ValueError(
                    f"expected {len(documents)} scores, got {len(scores)}"
                )

            # 2. Inject rerank_score into the source document dictionaries
            for doc, score in zip(documents, scores):
                doc["rerank_score"] = float(score)

            # 3. Sort by score (descending); sorted() is stable, so ties keep input order
            return sorted(documents, key=lambda x: x["rerank_score"], reverse=True)
        except Exception as e:
            logger.error(f"Reranking failed for query '{query}': {e}")
            raise RuntimeError(f"Reranking process failed: {e}") from e