File size: 1,727 Bytes
7bd11ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import statistics
from typing import List
from typing import Tuple

import torch
from loguru import logger
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from app.config.models.configs import Document


class BCEReranker:
    """Cross-encoder reranker backed by the BCE reranker-base model.

    Loads the HuggingFace tokenizer and sequence-classification model once
    at construction time and keeps the model in eval mode for inference.
    """

    def __init__(self) -> None:
        model_name = "maidalun1020/bce-reranker-base_v1"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()
        logger.info("Initialized BCE Reranker")

    def get_scores(self, query: str, docs: List[Document]) -> List[float]:
        """Score each document's relevance to *query* with the cross-encoder.

        Args:
            query: the search query.
            docs: documents whose ``page_content`` is paired with the query.

        Returns:
            One raw relevance logit per document, in input order.
        """
        logger.info("Reranking documents ... ")
        query_doc_pairs = [[query, doc.page_content] for doc in docs]
        with torch.no_grad():
            encoded = self.tokenizer(
                query_doc_pairs,
                padding=True,
                truncation=True,
                max_length=512,  # model's maximum supported sequence length
                return_tensors="pt",
            )
            outputs = self.model(**encoded, return_dict=True)
            scores = outputs.logits.view(-1).float().tolist()
        return scores


def rerank(
        rerank_model: BCEReranker, query: str, docs: List[Document]
) -> Tuple[float, List[Document]]:
    """Rerank *docs* by cross-encoder relevance to *query*.

    Each document's score is written into ``doc.metadata["score"]`` as a
    side effect.

    Args:
        rerank_model: reranker exposing ``get_scores(query, docs)``.
        query: the search query.
        docs: candidate documents to reorder.

    Returns:
        A tuple of (mean of the top-10 scores, docs sorted by score in
        descending order). Returns ``(0.0, [])`` for empty input, where the
        original would have raised ``statistics.StatisticsError``.
    """
    logger.info("Reranking...")
    if not docs:
        # Guard: statistics.mean([]) raises StatisticsError.
        return 0.0, []

    scores = rerank_model.get_scores(query, docs)
    for score, doc in zip(scores, docs):
        doc.metadata["score"] = score

    sorted_scores = sorted(scores, reverse=True)
    logger.info(sorted_scores)

    # NOTE: this is the arithmetic mean (not the median) of the top-10
    # scores; the previous local name `median_` was misleading.
    mean_top_score = statistics.mean(sorted_scores[:10])
    ranked_docs = sorted(docs, key=lambda d: d.metadata["score"], reverse=True)
    return mean_top_score, ranked_docs