import string
from typing import List, Optional, Tuple

from langchain.chains import LLMChain
from langchain.chains.base import Chain
from loguru import logger

from app.chroma import ChromaDenseVectorDB
from app.config.models.configs import (
    Config,
    ResponseModel,
    SemanticSearchConfig,
)
from app.ranking import BCEReranker, rerank
from app.splade import SpladeSparseVectorDB


class LLMBundle:
    """Bundles an answer-generation chain with hybrid retrieval.

    Holds a dense (Chroma) and a sparse (SPLADE) vector store, a BCE
    re-ranker, and optionally a HyDE chain used to expand the query
    before retrieval.
    """

    def __init__(
        self,
        chain: Chain,
        dense_db: ChromaDenseVectorDB,
        reranker: BCEReranker,
        sparse_db: SpladeSparseVectorDB,
        chunk_sizes: List[int],
        hyde_chain: Optional[LLMChain] = None,
    ) -> None:
        self.chain = chain
        self.dense_db = dense_db
        self.reranker = reranker
        self.sparse_db = sparse_db
        self.chunk_sizes = chunk_sizes
        self.hyde_chain = hyde_chain

    def get_relevant_documents(
        self,
        original_query: str,
        query: str,
        config: SemanticSearchConfig,
        label: str,
    ) -> Tuple[List[str], float]:
        """Retrieve, merge and re-rank documents for *query*.

        For each configured chunk size, candidates are fetched from both
        the sparse and the dense store, merged by document id, and
        re-ranked against *original_query*. The candidate set from the
        chunk size with the best re-ranker score wins; it is then trimmed
        greedily to ``config.max_char_size`` characters.

        NOTE(review): despite the ``List[str]`` annotation, the returned
        list holds document objects exposing ``page_content`` (see the
        char-budget loop below) — confirm and fix the annotation upstream.

        Returns:
            Tuple of (selected documents, best re-ranker score).
        """
        most_relevant_docs = []
        docs = []
        current_reranker_score = -1e5

        # BUG FIX: the prefix used to be (re)applied inside the
        # chunk-size loop, compounding it once per iteration. Apply it
        # exactly once, up front.
        if config.query_prefix:
            logger.info(f"Adding query prefix for retrieval: {config.query_prefix}")
            query = config.query_prefix + query

        for chunk_size in self.chunk_sizes:
            logger.debug("Evaluating query: {}", query)

            # Stage 1: sparse (SPLADE) retrieval.
            sparse_search_docs_ids, _sparse_scores = self.sparse_db.query(
                search=query, n=config.max_k, label=label, chunk_size=chunk_size
            )
            logger.info(f"Stage 1: Got {len(sparse_search_docs_ids)} documents.")

            # Only filter on chunk size when more than one size is indexed;
            # Chroma expects ``filter=None`` rather than an empty dict.
            metadata_filter = (
                {"chunk_size": chunk_size} if len(self.chunk_sizes) > 1 else {}
            )
            if label:
                metadata_filter["label"] = label
            if not metadata_filter:
                metadata_filter = None
            logger.info(f"Dense embeddings filter: {metadata_filter}")

            # Stage 2: dense retrieval.
            res = self.dense_db.similarity_search_with_relevance_scores(
                query, filter=metadata_filter
            )
            dense_search_doc_ids = [r[0].metadata["document_id"] for r in res]

            # Union of sparse and dense hits. (The original code also
            # subtracted an ``all_relevant_doc_ids`` set that was created
            # empty each iteration and never filled — dead code, removed.)
            all_doc_ids = set(sparse_search_docs_ids).union(dense_search_doc_ids)

            all_relevant_docs = []
            if all_doc_ids:
                all_relevant_docs = self.dense_db.get_documents_by_id(
                    document_ids=list(all_doc_ids)
                )

            # Re-rank the merged candidates against the *original* query,
            # not the (possibly HyDE-expanded, prefixed) retrieval query.
            reranker_score, relevant_docs = rerank(
                rerank_model=self.reranker,
                query=original_query,
                docs=all_relevant_docs,
            )

            # Keep the candidate set from the best-scoring chunk size.
            if reranker_score > current_reranker_score:
                docs = relevant_docs
                current_reranker_score = reranker_score

        # Greedy character budget: take documents in rank order, skipping
        # any that would exceed ``max_char_size`` (no early break, so a
        # smaller later document may still fit).
        total_chars = 0
        for doc in docs:
            doc_length = len(doc.page_content)
            if total_chars + doc_length < config.max_char_size:
                most_relevant_docs.append(doc)
                total_chars += doc_length

        return most_relevant_docs, current_reranker_score

    def get_and_parse_response(
        self,
        query: str,
        config: Config,
        label: str = "",
    ) -> ResponseModel:
        """Answer *query*: HyDE-expand, retrieve, and run the QA chain.

        Args:
            query: The user's question.
            config: Application config; ``config.semantic_search`` drives
                retrieval.
            label: Optional label restricting retrieval.

        Returns:
            A ``ResponseModel`` with the chain's answer, the score of the
            retrieved context, and the context snippets used.
        """
        original_query = query

        # BUG FIX: ``hyde_chain`` is Optional (defaults to None) — guard
        # before calling it instead of raising AttributeError.
        if self.hyde_chain is not None:
            hyde_response = self.hyde_chain.run(query)
            query += hyde_response
        logger.info(f"query: {query}")

        semantic_search_config = config.semantic_search
        most_relevant_docs, score = self.get_relevant_documents(
            original_query, query, semantic_search_config, label
        )

        res = self.chain(
            {"input_documents": most_relevant_docs, "question": original_query},
        )

        out = ResponseModel(
            response=res["output_text"],
            question=query,
            average_score=score,
            # NOTE(review): the HyDE text is deliberately kept out of the
            # response here (original passed "" too) — confirm intent.
            hyde_response="",
        )
        for doc in res["input_documents"]:
            out.semantic_search.append(doc.page_content)
        return out


class PartialFormatter(string.Formatter):
    """A ``string.Formatter`` tolerant of missing fields and bad specs.

    Missing fields render as ``missing`` (default ``"~~"``); values that
    fail their format spec render as ``bad_fmt`` (default ``"!!"``)
    instead of raising.
    """

    def __init__(self, missing="~~", bad_fmt="!!"):
        self.missing, self.bad_fmt = missing, bad_fmt

    def get_field(self, field_name, args, kwargs):
        """Resolve a field, returning ``(None, field_name)`` if absent."""
        try:
            val = super().get_field(field_name, args, kwargs)
        except (KeyError, AttributeError):
            # Signal "missing" with a None value; format_field handles it.
            val = None, field_name
        return val

    def format_field(self, value, spec):
        """Format *value*; substitute placeholders on None or bad spec."""
        if value is None:
            return self.missing
        try:
            return super().format_field(value, spec)
        except ValueError:
            if self.bad_fmt is not None:
                return self.bad_fmt
            raise