Spaces:
Configuration error
Configuration error
| import statistics | |
| from typing import List | |
| from typing import Tuple | |
| import torch | |
| from loguru import logger | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from app.config.models.configs import Document | |
class BCEReranker:
    """Cross-encoder reranker backed by a Hugging Face sequence-classification model.

    The tokenizer and model are loaded once at construction time; use
    :meth:`get_scores` to score ``(query, document)`` pairs.
    """

    def __init__(
        self,
        model_name: str = "maidalun1020/bce-reranker-base_v1",
        max_length: int = 512,
    ) -> None:
        """Load the tokenizer and model and put the model in eval mode.

        Args:
            model_name: Hugging Face model id (or local path) passed to both
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForSequenceClassification.from_pretrained``.
                Defaults to the BCE reranker base model.
            max_length: Maximum token length of each encoded
                ``(query, document)`` pair; longer pairs are truncated.
        """
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Inference only: switch off training-time behaviour such as dropout.
        self.model.eval()
        logger.info("Initialized BCE Reranker")

    def get_scores(self, query: str, docs: List[Document]) -> List[float]:
        """Score each document's ``page_content`` against *query*.

        Args:
            query: Query text, paired with every document.
            docs: Documents to score; only ``page_content`` is read.

        Returns:
            One float per document, in the same order as *docs*. These are the
            model's raw logits (``logits.view(-1)``). An empty *docs* list
            yields ``[]``.
        """
        logger.info("Reranking documents ... ")
        if not docs:
            # An empty batch cannot be tokenized/padded into tensors or run
            # through the model, so short-circuit instead of failing.
            return []
        pairs = [[query, doc.page_content] for doc in docs]
        inputs = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Only the forward pass needs autograd disabled.
        with torch.no_grad():
            logits = self.model(**inputs, return_dict=True).logits
        # Flatten to one score per pair. NOTE(review): the model card applies
        # torch.sigmoid to map these into [0, 1]; callers here consume the raw
        # logits, so confirm thresholds before normalizing.
        return logits.view(-1).float().tolist()
def rerank(
    rerank_model: BCEReranker,
    query: str,
    docs: List[Document],
    top_k: int = 10,
) -> Tuple[float, List[Document]]:
    """Score *docs* against *query* and return them ordered best-first.

    Side effect: each document's ``metadata["score"]`` is set, in place, to
    the score returned by ``rerank_model.get_scores``.

    Args:
        rerank_model: Reranker used to score the ``(query, document)`` pairs.
        query: Query text.
        docs: Candidate documents; their ``metadata`` dicts are mutated.
        top_k: Number of highest scores averaged into the summary value
            (defaults to 10).

    Returns:
        ``(mean_top_score, ranked_docs)`` where ``mean_top_score`` is the
        arithmetic mean (not the median) of the ``top_k`` highest scores and
        ``ranked_docs`` is a new list of *docs* sorted by descending score
        (ties keep their input order). For an empty *docs* list this returns
        ``(float("-inf"), [])`` so the summary compares below any threshold.

    Raises:
        ValueError: If ``top_k`` is less than 1, or if the reranker returns a
            different number of scores than there are documents.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    logger.info("Reranking...")
    if not docs:
        # Nothing to score; the reranker cannot process an empty batch and
        # statistics.mean would raise on an empty slice.
        return float("-inf"), []

    scores = rerank_model.get_scores(query, docs)
    if len(scores) != len(docs):
        # zip() would silently drop the extras and leave some docs unscored.
        raise ValueError(
            f"reranker returned {len(scores)} scores for {len(docs)} documents"
        )
    for score, doc in zip(scores, docs):
        doc.metadata["score"] = score

    sorted_scores = sorted(scores, reverse=True)
    logger.info(sorted_scores)
    mean_top_score = statistics.mean(sorted_scores[:top_k])
    ranked_docs = sorted(docs, key=lambda d: d.metadata["score"], reverse=True)
    return mean_top_score, ranked_docs