# app/pipeline.py
import string
from typing import List, Optional, Tuple
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from loguru import logger
from app.chroma import ChromaDenseVectorDB
from app.config.models.configs import (
ResponseModel,
Config, SemanticSearchConfig,
)
from app.ranking import BCEReranker, rerank
from app.splade import SpladeSparseVectorDB
class LLMBundle:
    """Bundle of the components needed to answer a query end-to-end:
    sparse + dense retrieval, cross-encoder reranking, and LLM generation,
    with optional HyDE query expansion.
    """

    def __init__(
        self,
        chain: Chain,
        dense_db: ChromaDenseVectorDB,
        reranker: BCEReranker,
        sparse_db: SpladeSparseVectorDB,
        chunk_sizes: List[int],
        hyde_chain: Optional[LLMChain] = None,
    ) -> None:
        self.chain = chain            # QA chain that produces the final answer
        self.dense_db = dense_db      # Chroma dense vector store
        self.reranker = reranker      # BCE cross-encoder reranker
        self.sparse_db = sparse_db    # SPLADE sparse vector store
        self.chunk_sizes = chunk_sizes  # chunk sizes to retrieve over
        self.hyde_chain = hyde_chain  # optional HyDE expansion chain

    def get_relevant_documents(
        self,
        original_query: str,
        query: str,
        config: SemanticSearchConfig,
        label: str,
    ) -> Tuple[List[str], float]:
        """Retrieve, merge and rerank documents for *query*.

        For every configured chunk size, runs sparse (SPLADE) and dense
        (Chroma) retrieval, reranks the union of hits against
        *original_query*, and keeps the result set of the best-scoring
        chunk size. The final list is truncated to ``config.max_char_size``
        total characters.

        Args:
            original_query: The user's raw query (used for reranking).
            query: The (possibly HyDE-expanded) retrieval query.
            config: Semantic-search settings (max_k, prefixes, char budget).
            label: Optional metadata label to filter retrieval by.

        Returns:
            Tuple of (most relevant documents, best reranker score).
        """
        most_relevant_docs = []
        docs = []
        current_reranker_score = -1e5

        # BUG FIX: the retrieval prefix used to be prepended inside the
        # chunk-size loop, so with multiple chunk sizes it was applied
        # once per iteration. Apply it exactly once, up front.
        if config.query_prefix:
            logger.info(f"Adding query prefix for retrieval: {config.query_prefix}")
            query = config.query_prefix + query

        for chunk_size in self.chunk_sizes:
            all_relevant_docs = []
            all_relevant_doc_ids = set()
            logger.debug("Evaluating query: {}", query)

            # Stage 1: sparse retrieval.
            sparse_search_docs_ids, sparse_scores = self.sparse_db.query(
                search=query, n=config.max_k, label=label, chunk_size=chunk_size
            )
            logger.info(f"Stage 1: Got {len(sparse_search_docs_ids)} documents.")

            # Metadata filter for dense retrieval (renamed from `filter`
            # to avoid shadowing the builtin).
            metadata_filter = (
                {"chunk_size": chunk_size} if len(self.chunk_sizes) > 1 else {}
            )
            if label:
                metadata_filter.update({"label": label})
            if not metadata_filter:
                metadata_filter = None
            logger.info(f"Dense embeddings filter: {metadata_filter}")

            # Stage 2: dense retrieval.
            res = self.dense_db.similarity_search_with_relevance_scores(
                query, filter=metadata_filter
            )
            dense_search_doc_ids = [r[0].metadata["document_id"] for r in res]

            # Union of both retrievers' hits, minus ids already collected.
            all_doc_ids = (
                set(sparse_search_docs_ids).union(dense_search_doc_ids)
            ).difference(all_relevant_doc_ids)
            if all_doc_ids:
                relevant_docs = self.dense_db.get_documents_by_id(
                    document_ids=list(all_doc_ids)
                )
                all_relevant_docs += relevant_docs

            # Rerank against the original (non-expanded) query; keep the
            # result set from the best-scoring chunk size.
            reranker_score, relevant_docs = rerank(
                rerank_model=self.reranker,
                query=original_query,
                docs=all_relevant_docs,
            )
            if reranker_score > current_reranker_score:
                docs = relevant_docs
                current_reranker_score = reranker_score

        # Trim the winning document list to the configured character budget.
        len_ = 0
        for doc in docs:
            doc_length = len(doc.page_content)
            if len_ + doc_length < config.max_char_size:
                most_relevant_docs.append(doc)
                len_ += doc_length
        return most_relevant_docs, current_reranker_score

    def get_and_parse_response(
        self,
        query: str,
        config: Config,
        label: str = "",
    ) -> ResponseModel:
        """Answer *query*: optionally expand it with HyDE, retrieve and
        rerank supporting documents, then run the QA chain.

        Args:
            query: The user's question.
            config: Full application config (semantic_search section used).
            label: Optional metadata label forwarded to retrieval.

        Returns:
            A populated ResponseModel with the answer, score and the
            page content of the documents actually used.
        """
        original_query = query

        # BUG FIX: hyde_chain is Optional — only expand when it was supplied.
        # Retrieval sees the expanded query; reranking and generation use
        # the original one.
        if self.hyde_chain is not None:
            hyde_response = self.hyde_chain.run(query)
            query += hyde_response
        logger.info(f"query: {query}")

        semantic_search_config = config.semantic_search
        most_relevant_docs, score = self.get_relevant_documents(
            original_query, query, semantic_search_config, label
        )
        res = self.chain(
            {"input_documents": most_relevant_docs, "question": original_query},
        )
        out = ResponseModel(
            response=res["output_text"],
            question=query,
            average_score=score,
            hyde_response="",
        )
        for doc in res["input_documents"]:
            out.semantic_search.append(doc.page_content)
        return out
class PartialFormatter(string.Formatter):
    """A lenient string.Formatter.

    Fields that cannot be resolved render as *missing* instead of raising
    KeyError/AttributeError; values that reject their format spec render
    as *bad_fmt* (or re-raise the ValueError when bad_fmt is None).
    """

    def __init__(self, missing: str = "~~", bad_fmt: str = "!!") -> None:
        self.missing = missing
        self.bad_fmt = bad_fmt

    def get_field(self, field_name, args, kwargs):
        # Resolve normally; signal an unresolved field as (None, name) so
        # format_field can substitute the placeholder text.
        try:
            return super().get_field(field_name, args, kwargs)
        except (KeyError, AttributeError):
            return None, field_name

    def format_field(self, value, spec):
        # Unresolved fields arrive here as None (see get_field).
        if value is None:
            return self.missing
        try:
            return super().format_field(value, spec)
        except ValueError:
            if self.bad_fmt is None:
                raise
            return self.bad_fmt