# Copyright (c) OpenMMLab. All rights reserved.
"""extract feature and search with user query."""
import os
import time

import numpy as np
import pytoml
from BCEmbedding.tools.langchain import BCERerank
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores.faiss import FAISS as Vectorstore
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.document_transformers import LongContextReorder
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter
from loguru import logger
from sklearn.metrics import precision_recall_curve

from .file_operation import FileOperation
from .helper import QueryTracker


class Retriever:
    """Tokenize and extract features from the project's documents, for use in
    the reject pipeline and response pipeline."""

    def __init__(self, embeddings, reranker, work_dir: str,
                 reject_throttle: float) -> None:
        """Init with model device type and config."""
        self.reject_throttle = reject_throttle
        # self.rejecter = Vectorstore.load_local(
        #     os.path.join(work_dir, 'db_reject'),
        #     embeddings=embeddings,
        #     allow_dangerous_deserialization=True)
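        # embeddings are L2-normalized at encode time (see CacheRetriever's
        # encode_kwargs), so max-inner-product search on this index behaves
        # like cosine similarity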
        self.retriever = Vectorstore.load_local(
            os.path.join(work_dir, 'db_response'),
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
            distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT).as_retriever(
                search_type='similarity',
                search_kwargs={
                    'score_threshold': 0.15,
                    'k': 5
                })

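        # post-process retrieved chunks in three stages: drop near-duplicate
        # chunks, reorder so the strongest matches sit at the head and tail
        # of the context ("lost in the middle" mitigation), then apply the
        # reranker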
        self.reordering = LongContextReorder()
        redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[redundant_filter, self.reordering, reranker])
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=self.retriever)
        # self.compression_retriever = ContextualCompressionRetriever(
        #     base_compressor=reranker, base_retriever=self.retriever)

    # def is_reject(self, question, k=30, disable_throttle=False):
    #     """If no search results below the threshold can be found from the
    #     database, reject this query."""
    #     if disable_throttle:
    #         # for searching throttle during update sample
    #         docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
    #             question, k=1)
    #         if len(docs_with_score) < 1:
    #             return True, docs_with_score
    #         return False, docs_with_score
    #     else:
    #         # for retrieve result
    #         # if no chunk passed the throttle, give the max
    #         docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
    #             question, k=k)
    #         ret = []
    #         max_score = -1
    #         top1 = None
    #         for (doc, score) in docs_with_score:
    #             if score >= self.reject_throttle:
    #                 ret.append(doc)
    #             if score > max_score:
    #                 max_score = score
    #                 top1 = (doc, score)
    #         reject = False if len(ret) > 0 else True
    #         return reject, [top1]

    # def update_throttle(self,
    #                     config_path: str = 'config.ini',
    #                     good_questions=[],
    #                     bad_questions=[]):
    #     """Update reject throttle based on positive and negative examples."""

    #     if len(good_questions) == 0 or len(bad_questions) == 0:
    #         raise Exception('good and bad question examples cannot be empty.')
    #     questions = good_questions + bad_questions
    #     predictions = []
    #     for question in questions:
    #         self.reject_throttle = -1
    #         _, docs = self.is_reject(question=question, disable_throttle=True)
    #         score = docs[0][1]
    #         predictions.append(max(0, score))

    #     labels = [1 for _ in range(len(good_questions))
    #               ] + [0 for _ in range(len(bad_questions))]
    #     precision, recall, thresholds = precision_recall_curve(
    #         labels, predictions)

    #     # get the best index for sum(precision, recall)
    #     sum_precision_recall = precision[:-1] + recall[:-1]
    #     index_max = np.argmax(sum_precision_recall)
    #     optimal_threshold = max(thresholds[index_max], 0.0)

    #     with open(config_path, encoding='utf8') as f:
    #         config = pytoml.load(f)
    #     config['feature_store']['reject_throttle'] = float(optimal_threshold)
    #     with open(config_path, 'w', encoding='utf8') as f:
    #         pytoml.dump(config, f)
    #     logger.info(
    #         f'The optimal threshold is: {optimal_threshold}, saved it to {config_path}'  # noqa E501
    #     )

    def query(self,
              question: str,
              context_max_length: int = 128000,
              tracker: QueryTracker = None):
        """Process a query and return the best-matching chunks from the
        vector store, together with the files they came from.

        Args:
            question (str): The question asked by the user.
            context_max_length (int): Kept for API compatibility; the
                context-assembly code that used it is currently disabled.
            tracker (QueryTracker): Optional tracker that records the sources
                of retrieved chunks.

        Returns:
            list: The best-matching chunks, or None if the question is empty.
            list: Basenames of the source files the chunks came from.
        """
        if question is None or len(question) < 1:
            return None, []

        if len(question) > 512:
            logger.warning('input too long, truncating to 512 characters')
            question = question[0:512]

        # reject, docs = self.is_reject(question=question)
        # assert (len(docs) > 0)
        # if reject:
        #     return None, None, [docs[0][0].metadata['source']]

        # use `self.retriever` here instead to skip the compression pipeline
        # and get the raw top-k results
        docs = self.compression_retriever.get_relevant_documents(question)
        logger.info('query: {} retrieved {} documents'.format(
            question, len(docs)))

        if tracker is not None:
            tracker.log('retrieve', [doc.metadata['source'] for doc in docs])
        chunks = []
        # context = ''
        references = []

        # add file text to context until it exceeds `context_max_length`

        # file_opr = FileOperation()
        for doc in docs:
            chunk = doc.page_content
            chunks.append(chunk)

            # if 'read' not in doc.metadata:
            #     logger.error(
            #         'If you are using the version before 20240319, please rerun `python3 -m huixiangdou.service.feature_store`'
            #     )
            #     raise Exception('huixiangdou version mismatch')
            # file_text, error = file_opr.read(doc.metadata['read'])
            # if error is not None:
            #     # read file failed, skip
            #     continue

            source = doc.metadata['source']
            # logger.info('target {} file length {}'.format(
            #     source, len(file_text)))
            # if len(file_text) + len(context) > context_max_length:
            #     if source in references:
            #         continue
            #     references.append(source)
            #     # add and break
            #     add_len = context_max_length - len(context)
            #     if add_len <= 0:
            #         break
            #     chunk_index = file_text.find(chunk)
            #     if chunk_index == -1:
            #         # chunk not in file_text
            #         context += chunk
            #         context += '\n'
            #         context += file_text[0:add_len - len(chunk) - 1]
            #     else:
            #         start_index = max(0, chunk_index - (add_len - len(chunk)))
            #         context += file_text[start_index:start_index + add_len]
            #     break

            if source not in references:
                references.append(source)

        # context = context[0:context_max_length]
        if references:
            logger.debug('query: {} top1 file: {}'.format(
                question, references[0]))
        return chunks, [os.path.basename(r) for r in references]
        # return '\n'.join(chunks), context, [
        #     os.path.basename(r) for r in references
        # ]


class CacheRetriever:
    """Keep an LRU-style cache of `Retriever` instances keyed by feature
    store id, so repeated queries reuse already-loaded indices and models."""

    def __init__(self, config_path: str, max_len: int = 4):
        self.cache = dict()
        self.max_len = max_len
        with open(config_path, encoding='utf8') as f:
            config = pytoml.load(f)['feature_store']
            embedding_model_path = config['embedding_model_path']
            reranker_model_path = config['reranker_model_path']

        # load text2vec and rerank model
        logger.info('loading text2vec and rerank models')
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_path,
            model_kwargs={'device': 'cuda'},
            encode_kwargs={
                'batch_size': 1024,
                'normalize_embeddings': True
            })
        # run the embedding model in fp16 to halve GPU memory usage
        self.embeddings.client = self.embeddings.client.half()
        reranker_args = {
            'model': reranker_model_path,
            'top_n': 7,
            'device': 'cuda',
            'use_fp16': True
        }
        self.reranker = BCERerank(**reranker_args)

    def get(self,
            fs_id: str = 'default',
            config_path='config.ini',
            work_dir: str = 'workdir'):
        if fs_id in self.cache:
            self.cache[fs_id]['time'] = time.time()
            return self.cache[fs_id]['retriever']

        if not os.path.exists(work_dir) or not os.path.exists(config_path):
            logger.error('workdir or config file does not exist')
            return None

        with open(config_path, encoding='utf8') as f:
            reject_throttle = pytoml.load(
                f)['feature_store']['reject_throttle']

        if len(self.cache) >= self.max_len:
            # drop the oldest one
            del_key = None
            min_time = time.time()
            for key, value in self.cache.items():
                cur_time = value['time']
                if cur_time < min_time:
                    min_time = cur_time
                    del_key = key

            if del_key is not None:
                del_value = self.cache[del_key]
                self.cache.pop(del_key)
                del del_value['retriever']

        retriever = Retriever(embeddings=self.embeddings,
                              reranker=self.reranker,
                              work_dir=work_dir,
                              reject_throttle=reject_throttle)
        self.cache[fs_id] = {'retriever': retriever, 'time': time.time()}
        return retriever

    def pop(self, fs_id: str):
        if fs_id not in self.cache:
            return
        del_value = self.cache.pop(fs_id)
        # drop the last reference so the retriever can be garbage collected
        del del_value
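

# A minimal usage sketch, assuming a `config.ini` with a `feature_store`
# section and a `workdir` built beforehand by the feature store pipeline
# (both paths are illustrative, not part of this module's contract):
if __name__ == '__main__':
    cache = CacheRetriever(config_path='config.ini')
    retriever = cache.get(config_path='config.ini', work_dir='workdir')
    if retriever is not None:
        chunks, references = retriever.query('how to install this project?')
        logger.info('got {} chunks from files: {}'.format(
            len(chunks), references))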