import statistics
from typing import List, Tuple

import torch
from loguru import logger
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from app.config.models.configs import Document

# Number of top-scoring documents averaged into the summary score that
# `rerank` returns alongside the sorted documents.
_TOP_K_FOR_MEAN = 10


class BCEReranker:
    """Cross-encoder reranker backed by ``maidalun1020/bce-reranker-base_v1``.

    Scores (query, passage) pairs with a sequence-classification head; higher
    logits mean higher relevance.
    """

    def __init__(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            "maidalun1020/bce-reranker-base_v1"
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "maidalun1020/bce-reranker-base_v1"
        )
        # Inference only: disable dropout / put norm layers in eval mode.
        self.model.eval()
        logger.info("Initialized BCE Reranker")

    def get_scores(self, query: str, docs: List[Document]) -> List[float]:
        """Return one relevance score per document, in the same order as `docs`.

        Args:
            query: The user query text.
            docs: Documents whose ``page_content`` is paired with the query.

        Returns:
            A list of float scores, one per input document (empty for no docs).
        """
        logger.info("Reranking documents ... ")
        # Guard: the tokenizer cannot encode an empty batch.
        if not docs:
            return []

        features = [[query, doc.page_content] for doc in docs]
        with torch.no_grad():
            inputs = self.tokenizer(
                features,
                padding=True,
                truncation=True,
                max_length=512,  # model's supported context window for pairs
                return_tensors="pt",
            )
            logits = self.model(**inputs, return_dict=True).logits
            # One logit per pair -> flatten to a 1-D list of python floats.
            scores = logits.view(-1).float().tolist()
        return scores


def rerank(
    rerank_model: BCEReranker, query: str, docs: List[Document]
) -> Tuple[float, List[Document]]:
    """Score `docs` against `query`; return a summary score and sorted docs.

    Side effect: writes each document's score into ``doc.metadata["score"]``.

    Args:
        rerank_model: The reranker used to produce relevance scores.
        query: The user query text.
        docs: Candidate documents to rerank.

    Returns:
        A tuple of (mean of the top-10 scores, documents sorted by score
        descending). Returns ``(0.0, [])``-equivalent behavior for empty input
        instead of raising ``statistics.StatisticsError``.
    """
    logger.info("Reranking...")
    scores = rerank_model.get_scores(query, docs)
    for score, doc in zip(scores, docs):
        doc.metadata["score"] = score

    sorted_scores = sorted(scores, reverse=True)
    logger.info(sorted_scores)

    # NOTE: the original variable was named `median_` but has always computed
    # the *mean* of the top-k scores; the mean is kept to preserve behavior,
    # only the name is corrected.
    top_k_mean = (
        statistics.mean(sorted_scores[:_TOP_K_FOR_MEAN]) if sorted_scores else 0.0
    )
    ranked_docs = sorted(docs, key=lambda d: d.metadata["score"], reverse=True)
    return top_k_mean, ranked_docs