# Copyright (c) OpenMMLab. All rights reserved.
"""Extract features and search with the user query."""
import os
import time

import numpy as np
import pytoml
from BCEmbedding.tools.langchain import BCERerank
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.vectorstores.faiss import FAISS as Vectorstore
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.document_transformers import LongContextReorder
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter
from loguru import logger
from sklearn.metrics import precision_recall_curve

from .file_operation import FileOperation
from .helper import QueryTracker


class Retriever:
    """Tokenize and extract features from the project's documents, for use in
    the reject pipeline and response pipeline."""

    def __init__(self, embeddings, reranker, work_dir: str,
                 reject_throttle: float) -> None:
        """Init with model device type and config."""
        self.reject_throttle = reject_throttle
        # self.rejecter = Vectorstore.load_local(
        #     os.path.join(work_dir, 'db_reject'),
        #     embeddings=embeddings,
        #     allow_dangerous_deserialization=True)
        self.retriever = Vectorstore.load_local(
            os.path.join(work_dir, 'db_response'),
            embeddings=embeddings,
            allow_dangerous_deserialization=True,
            distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT).as_retriever(
                search_type='similarity',
                search_kwargs={
                    'score_threshold': 0.15,
                    'k': 5
                })
        # deduplicate near-identical chunks, reorder for long contexts,
        # then rerank with the cross-encoder
        self.reordering = LongContextReorder()
        redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
        pipeline_compressor = DocumentCompressorPipeline(
            transformers=[redundant_filter, self.reordering, reranker])
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=pipeline_compressor, base_retriever=self.retriever)
        # self.compression_retriever = ContextualCompressionRetriever(
        #     base_compressor=reranker, base_retriever=self.retriever)

    # def is_reject(self, question, k=30, disable_throttle=False):
    #     """If no search results below the threshold can be found from the
    #     database, reject this query."""
    #     if disable_throttle:
    #         # for searching throttle during update sample
    #         docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
    #             question, k=1)
    #         if len(docs_with_score) < 1:
    #             return True, docs_with_score
    #         return False, docs_with_score
    #     else:
    #         # for retrieve result
    #         # if no chunk passed the throttle, give the max
    #         docs_with_score = self.rejecter.similarity_search_with_relevance_scores(
    #             question, k=k)
    #         ret = []
    #         max_score = -1
    #         top1 = None
    #         for (doc, score) in docs_with_score:
    #             if score >= self.reject_throttle:
    #                 ret.append(doc)
    #             if score > max_score:
    #                 max_score = score
    #                 top1 = (doc, score)
    #         reject = False if len(ret) > 0 else True
    #         return reject, [top1]

    # def update_throttle(self,
    #                     config_path: str = 'config.ini',
    #                     good_questions=[],
    #                     bad_questions=[]):
    #     """Update reject throttle based on positive and negative examples."""
    #     if len(good_questions) == 0 or len(bad_questions) == 0:
    #         raise Exception('good and bad question examples cannot be empty.')
    #     questions = good_questions + bad_questions
    #     predictions = []
    #     for question in questions:
    #         self.reject_throttle = -1
    #         _, docs = self.is_reject(question=question, disable_throttle=True)
    #         score = docs[0][1]
    #         predictions.append(max(0, score))
    #     labels = [1 for _ in range(len(good_questions))
    #               ] + [0 for _ in range(len(bad_questions))]
    #     precision, recall, thresholds = precision_recall_curve(
    #         labels, predictions)
    #     # get the best index for sum(precision, recall)
    #     sum_precision_recall = precision[:-1] + recall[:-1]
    #     index_max = np.argmax(sum_precision_recall)
    #     optimal_threshold = max(thresholds[index_max], 0.0)
    #     with open(config_path, encoding='utf8') as f:
    #         config = pytoml.load(f)
    #     config['feature_store']['reject_throttle'] = float(optimal_threshold)
    #     with open(config_path, 'w', encoding='utf8') as f:
    #         pytoml.dump(config, f)
    #     logger.info(
    #         f'The optimal threshold is: {optimal_threshold}, saved it to {config_path}'  # noqa E501
    #     )

    def query(self,
              question: str,
              context_max_length: int = 128000,
              tracker: QueryTracker = None):
        """Process a query and return the best matches from the vector store
        database.

        Args:
            question (str): The question asked by the user.
            context_max_length (int): Only consumed by the commented-out
                context building below; kept for API compatibility.
            tracker (QueryTracker): Optional tracker that logs the retrieved
                sources.

        Returns:
            list: The best matching chunks, or None.
            list: The basenames of the reference files, or an empty list.
        """
        if question is None or len(question) < 1:
            return None, []
        if len(question) > 512:
            logger.warning('input too long, truncate to 512')
            question = question[0:512]

        # reject, docs = self.is_reject(question=question)
        # assert (len(docs) > 0)
        # if reject:
        #     return None, None, [docs[0][0].metadata['source']]

        # switch to the base retriever to get the top 5
        docs = self.compression_retriever.get_relevant_documents(question)
        logger.info('query:{} getting {} references'.format(
            question, len(docs)))
        if tracker is not None:
            tracker.log('retrieve', [doc.metadata['source'] for doc in docs])
        chunks = []
        # context = ''
        references = []

        # add file text to context, until exceed `context_max_length`
        # file_opr = FileOperation()
        for doc in docs:
            chunk = doc.page_content
            chunks.append(chunk)
            # if 'read' not in doc.metadata:
            #     logger.error(
            #         'If you are using the version before 20240319, please rerun `python3 -m huixiangdou.service.feature_store`'  # noqa E501
            #     )
            #     raise Exception('huixiangdou version mismatch')
            # file_text, error = file_opr.read(doc.metadata['read'])
            # if error is not None:
            #     # read file failed, skip
            #     continue
            source = doc.metadata['source']
            # logger.info('target {} file length {}'.format(
            #     source, len(file_text)))
            # if len(file_text) + len(context) > context_max_length:
            #     if source in references:
            #         continue
            #     references.append(source)
            #     # add and break
            #     add_len = context_max_length - len(context)
            #     if add_len <= 0:
            #         break
            #     chunk_index = file_text.find(chunk)
            #     if chunk_index == -1:
            #         # chunk not in file_text
            #         context += chunk
            #         context += '\n'
            #         context += file_text[0:add_len - len(chunk) - 1]
            #     else:
            #         start_index = max(0, chunk_index - (add_len - len(chunk)))
            #         context += file_text[start_index:start_index + add_len]
            #     break
            references.append(source)

        # context = context[0:context_max_length]
        if len(references) > 0:
            logger.debug('query:{} getting {} references, top1 file:{}'.format(
                question, len(references), references[0]))
        return chunks, [os.path.basename(r) for r in references]
        # return '\n'.join(chunks), context, [
        #     os.path.basename(r) for r in references
        # ]
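
# Standalone usage sketch (illustrative only, not part of this module's API:
# it assumes a prebuilt `workdir/db_response` FAISS index and already-loaded
# `embeddings` / `reranker` instances such as the ones `CacheRetriever` below
# creates; the question string and throttle value are made up):
#
#   retriever = Retriever(embeddings=embeddings, reranker=reranker,
#                         work_dir='workdir', reject_throttle=0.35)
#   chunks, references = retriever.query('how to install mmdeploy?')
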

class CacheRetriever:

    def __init__(self, config_path: str, max_len: int = 4):
        self.cache = dict()
        self.max_len = max_len
        with open(config_path, encoding='utf8') as f:
            config = pytoml.load(f)['feature_store']
            embedding_model_path = config['embedding_model_path']
            reranker_model_path = config['reranker_model_path']

        # load text2vec and rerank models
        logger.info('loading text2vec and rerank models')
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_path,
            model_kwargs={'device': 'cuda'},
            encode_kwargs={
                'batch_size': 1024,
                'normalize_embeddings': True
            })
        # run the embedding model in fp16 to cut GPU memory usage
        self.embeddings.client = self.embeddings.client.half()
        reranker_args = {
            'model': reranker_model_path,
            'top_n': 7,
            'device': 'cuda',
            'use_fp16': True
        }
        self.reranker = BCERerank(**reranker_args)
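
    # Expected `config.ini` layout, a minimal sketch inferred from the keys
    # read above (the paths and throttle value are illustrative, not shipped
    # defaults):
    #
    #   [feature_store]
    #   embedding_model_path = "/path/to/text2vec/model"
    #   reranker_model_path = "/path/to/reranker/model"
    #   reject_throttle = 0.35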

    def get(self,
            fs_id: str = 'default',
            config_path='config.ini',
            work_dir: str = 'workdir'):
        """Return the cached retriever for `fs_id`, building and caching a
        new one when absent; evict the least recently used entry when the
        cache is full."""
        if fs_id in self.cache:
            self.cache[fs_id]['time'] = time.time()
            return self.cache[fs_id]['retriever']

        if not os.path.exists(work_dir) or not os.path.exists(config_path):
            # keep the return arity consistent with the success path
            logger.error('workdir or config.ini not exist')
            return None

        with open(config_path, encoding='utf8') as f:
            reject_throttle = pytoml.load(
                f)['feature_store']['reject_throttle']

        if len(self.cache) >= self.max_len:
            # drop the least recently used retriever
            del_key = None
            min_time = time.time()
            for key, value in self.cache.items():
                cur_time = value['time']
                if cur_time < min_time:
                    min_time = cur_time
                    del_key = key

            if del_key is not None:
                del_value = self.cache[del_key]
                self.cache.pop(del_key)
                del del_value['retriever']

        retriever = Retriever(embeddings=self.embeddings,
                              reranker=self.reranker,
                              work_dir=work_dir,
                              reject_throttle=reject_throttle)
        self.cache[fs_id] = {'retriever': retriever, 'time': time.time()}
        return retriever

    def pop(self, fs_id: str):
        """Remove `fs_id` from the cache and release its retriever."""
        if fs_id not in self.cache:
            return
        del_value = self.cache[fs_id]
        self.cache.pop(fs_id)
        # drop the local reference so the retriever can be garbage collected
        del del_value
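

if __name__ == '__main__':
    # Minimal smoke test, a sketch only: it assumes `config.ini` and a
    # prebuilt `workdir/db_response` index already exist (see the feature
    # store tooling for building the index); the question is illustrative.
    cache = CacheRetriever(config_path='config.ini')
    retriever = cache.get(fs_id='default',
                          config_path='config.ini',
                          work_dir='workdir')
    if retriever is not None:
        chunks, references = retriever.query('how to install mmdeploy?')
        logger.info('got {} chunks from {}'.format(len(chunks), references))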