project / app /ranking.py
kabylake's picture
commit
7bd11ed
raw
history blame contribute delete
No virus
1.73 kB
import statistics
from typing import List
from typing import Tuple
import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from app.config.models.configs import Document
class BCEReranker:
    """Score (query, document) pairs with a sequence-classification model.

    The tokenizer and model are loaded once in ``__init__`` and the model is
    put into evaluation mode, so one instance can be reused across queries.
    """

    # Checkpoint used when the caller does not supply one.
    DEFAULT_MODEL_NAME = "maidalun1020/bce-reranker-base_v1"
    # Token budget per (query, document) pair; longer pairs are truncated.
    DEFAULT_MAX_LENGTH = 512

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        max_length: int = DEFAULT_MAX_LENGTH,
    ) -> None:
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hugging Face model identifier or local path passed to
                ``from_pretrained`` for both the tokenizer and the model.
            max_length: Maximum tokenized length of a (query, document) pair.
        """
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()
        logger.info("Initialized BCE Reranker")

    def get_scores(self, query: str, docs: List[Document]) -> List[float]:
        """Return one relevance score per document for ``query``.

        Args:
            query: The search query paired with every document.
            docs: Documents to score; only ``doc.page_content`` is read.

        Returns:
            The model logits flattened into a list of floats, in the same
            order as ``docs``. An empty ``docs`` yields an empty list.
            NOTE(review): assumes the model emits a single logit per pair, so
            the flattened list has ``len(docs)`` entries -- confirm for
            checkpoints other than the default.
        """
        # The tokenizer cannot build a batch out of zero pairs, so short-circuit
        # before touching the tokenizer or the model.
        if not docs:
            return []

        logger.info("Reranking documents ... ")
        features = [[query, doc.page_content] for doc in docs]
        inputs = self.tokenizer(
            features,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(**inputs, return_dict=True).logits
        return logits.view(-1).float().tolist()
def rerank(
    rerank_model: BCEReranker,
    query: str,
    docs: List[Document],
    top_k: int = 10,
) -> Tuple[float, List[Document]]:
    """Score ``docs`` against ``query`` and return them best-first.

    Each document is mutated in place: its score is stored under
    ``doc.metadata["score"]``.

    Args:
        rerank_model: Object whose ``get_scores(query, docs)`` returns one
            float per document.
        query: The search query.
        docs: Documents to rank.
        top_k: Number of highest scores averaged into the summary value.

    Returns:
        A tuple ``(mean_top_score, ranked_docs)`` where ``mean_top_score`` is
        the arithmetic mean of the ``top_k`` highest scores (or of all scores
        when there are fewer) and ``ranked_docs`` is a new list of the same
        document objects sorted by descending score.

    Raises:
        ValueError: If ``top_k`` is less than 1.
        statistics.StatisticsError: If scoring yields no scores (for example
            when ``docs`` is empty), because the mean of nothing is undefined.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    logger.info("Reranking...")
    scores = rerank_model.get_scores(query, docs)
    for score, doc in zip(scores, docs):
        doc.metadata["score"] = score

    sorted_scores = sorted(scores, reverse=True)
    logger.info(sorted_scores)

    # This is a mean (not a median) over the best `top_k` scores.
    mean_top_score = statistics.mean(sorted_scores[:top_k])
    # `sorted` already returns a fresh list; the caller's `docs` order is kept.
    ranked_docs = sorted(docs, key=lambda it: it.metadata["score"], reverse=True)
    return mean_top_score, ranked_docs