File size: 4,924 Bytes
4c65e87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from rank_bm25 import BM25Okapi
import numpy as np
import torch
import re
import string
from sentence_transformers import SentenceTransformer, util, CrossEncoder


# Bi-encoder used for dense retrieval (sentence-transformers model id).
DENSE_RETRIEVER_MODEL_NAME = "all-MiniLM-L6-v2"
# Cross-encoder used for final reranking of query/document pairs.
CROSS_ENCODER_MODEL_NAME = 'cross-encoder/ms-marco-MiniLM-L-12-v2'
# LLM identifier; not used in this file — presumably consumed by a caller. TODO confirm.
LLM_CORE_MODEL_NAME = "groq/llama3-8b-8192"


def clean_text(text):
    """Normalize *text* for lexical matching.

    Strips ASCII punctuation, lowercases, drops any remaining
    non-alphanumeric characters, and collapses runs of whitespace
    into single spaces.

    Args:
        text (str): Raw input text.

    Returns:
        str: Cleaned, stripped text.
    """
    without_punct = text.translate(str.maketrans('', '', string.punctuation))
    lowered = without_punct.lower()
    # Second pass catches non-ASCII symbols that str.punctuation misses.
    alnum_only = re.sub(r'[^a-zA-Z0-9\s]', '', lowered)
    collapsed = re.sub(r'\s+', ' ', alnum_only)
    return collapsed.strip()


class HybridRetrieverReranker:
    def __init__(self, dataset, dense_model_name=DENSE_RETRIEVER_MODEL_NAME, cross_encoder_model=CROSS_ENCODER_MODEL_NAME):
        if 'cleaned_text' not in dataset.columns:
            raise ValueError("Dataset must contain a 'cleaned_text' column.")

        self.dataset = dataset
        self.bm25_corpus = dataset['cleaned_text'].tolist()
        self.tokenized_corpus = [chunk.split() for chunk in self.bm25_corpus]
        self.bm25 = BM25Okapi(self.tokenized_corpus)

        self.dense_model = SentenceTransformer(dense_model_name)
        self.cross_encoder = CrossEncoder(cross_encoder_model)


    def bm25_retrieve(self, query, top_k=70):
        """

        Retrieve top K documents using BM25.



        Args:

            query (str): Query text.

            top_k (int): Number of top BM25 documents to retrieve.



        Returns:

            list of dict: Top K BM25 results.

        """
        cleaned_query = clean_text(query)
        query_tokens = cleaned_query.split()
        bm25_scores = self.bm25.get_scores(query_tokens)
        top_k_indices = np.argsort(bm25_scores)[::-1][:top_k]
        return self.dataset.iloc[top_k_indices].to_dict(orient='records')


    def dense_retrieve(self, query, candidates=None, top_n=35):
        """

        Retrieve top N documents using dense retrieval with LaBSE.



        Args:

            query (str): Query text.

            candidates (list of dict): Candidate documents from BM25.

            top_n (int): Number of top dense results to retrieve.



        Returns:

            list of dict: Top N dense results.

        """
        if candidates is None:
            candidates = self.dataset.to_dict(orient='records')
        query_embedding = self.dense_model.encode(query, convert_to_tensor=True)

        candidate_embeddings = torch.stack([
            eval(doc['chunk_embedding'].replace('tensor', 'torch.tensor')).clone().detach()
            for doc in candidates
        ])

        similarities = util.pytorch_cos_sim(query_embedding, candidate_embeddings).squeeze(0)
        top_n_indices = torch.topk(similarities, top_n).indices
        return [candidates[idx] for idx in top_n_indices]


    def rerank(self, query, candidates=None, top_n=3):
        """

        Rerank top documents using a CrossEncoder.



        Args:

            query (str): Query text.

            candidates (list of dict): Candidate documents from dense retriever.

            top_n (int): Number of top reranked results to return.



        Returns:

            list of dict: Top N reranked documents.

        """
        if candidates is None:
            candidates = self.dataset.to_dict(orient='records')
        query_document_pairs = [(query, doc['raw_text']) for doc in candidates]
        scores = self.cross_encoder.predict(query_document_pairs)
        top_n_indices = np.argsort(scores)[::-1][:top_n]
        return [candidates[idx] for idx in top_n_indices]



    def hybrid_retrieve(self, query, enable_bm25=True, enable_dense=True, enable_rerank=True, top_k_bm25=60, top_n_dense=30, top_n_rerank=2):
        """

        Perform hybrid retrieval: BM25 followed by dense retrieval and optional reranking.



        Args:

            query (str): Query text.

            top_k_bm25 (int): Number of top BM25 documents to retrieve.

            top_n_dense (int): Number of top dense results to retrieve.

            enable_dense (bool): Whether dense retrieval should be enabled.

            enable_rerank (bool): Whether reranking should be enabled.

            top_n_rerank (int): Number of top reranked documents to return.



        Returns:

            list of dict: Final top results after hybrid retrieval and reranking.

        """
        if enable_bm25:
            bm25_results = self.bm25_retrieve(query, top_k=top_k_bm25)
        else:
            bm25_results = None

        if enable_dense:
            dense_results = self.dense_retrieve(query, bm25_results, top_n=top_n_dense)
        else:
            dense_results = bm25_results

        if enable_rerank:
            final_results = self.rerank(query, dense_results, top_n=top_n_rerank)
        else:
            final_results = dense_results

        return final_results