# https://github.com/langchain-ai/langchain/issues/8623
import logging
from typing import List

from langchain.schema.retriever import BaseRetriever, Document
from langchain.vectorstores import VectorStore
from langchain.vectorstores import Chroma

logger = logging.getLogger(__name__)

# Report types that are summaries (the idea being that summaries are easier to
# exploit). While this list is empty, no "report_type" condition is added to
# any search filter.
SUMMARY_TYPES = []


class QARetriever(BaseRetriever):
    """Retriever that mixes summary documents and full-report documents.

    Up to ``k_summary`` hits are taken from documents whose ``report_type`` is
    in ``SUMMARY_TYPES``; the remaining slots (up to ``k_total`` overall) are
    filled from documents whose ``report_type`` is not in ``SUMMARY_TYPES``.
    Hits whose score is not strictly greater than ``threshold`` are dropped.
    """

    # Vector store queried through ``similarity_search_with_score``.
    vectorstore: VectorStore
    # When non-empty, only documents whose "domain" metadata is in this list
    # are searched.
    domains: list = []
    # Hits are kept only when ``score > threshold``.
    # NOTE(review): this assumes a higher score means a better match; confirm
    # against the score convention of the vector store in use.
    threshold: float = 22
    # Number of hits requested from the summary documents (0 disables that
    # search entirely).
    k_summary: int = 0
    # Overall number of hits aimed for across both searches.
    k_total: int = 10
    # Forwarded unchanged as the ``namespace`` keyword of the vector store
    # search call.
    namespace: str = "vectors"

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents matching *query*, enriched with metadata.

        Each returned document gets ``similarity_score`` and ``content`` (the
        original page content) added to its metadata, has its ``page_number``
        metadata cast to ``int``, and has its ``page_content`` prefixed with
        ``"Doc <rank> - <short_name>: "``.

        Raises:
            AssertionError: if ``domains`` is not a list or ``k_total`` is not
                strictly greater than ``k_summary``.
            KeyError: if a hit lacks the ``page_number`` or ``short_name``
                metadata keys.
        """
        assert isinstance(self.domains, list)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Base filter shared by both searches.
        filters = {}
        if self.domains:
            filters["domain"] = {"$in": self.domains}

        if self.k_summary > 0:
            # Search for k_summary documents among the summaries.
            filters_summaries = {**filters}
            if SUMMARY_TYPES:
                filters_summaries["report_type"] = {"$in": SUMMARY_TYPES}
            docs_summaries = self._search_above_threshold(
                query, filters_summaries, self.k_summary
            )
        else:
            docs_summaries = []

        # Search the full reports for whatever slots the summaries left free.
        filters_full = {**filters}
        logger.debug("filters: %s", filters)
        if SUMMARY_TYPES:
            filters_full["report_type"] = {"$nin": SUMMARY_TYPES}
        k_full = self.k_total - len(docs_summaries)
        docs_full = self._search_above_threshold(query, filters_full, k_full)

        # Summaries come first, then full reports; ranks follow that order.
        docs = docs_summaries + docs_full

        results = []
        for i, (doc, score) in enumerate(docs):
            doc.metadata["similarity_score"] = score
            # Keep the raw text in metadata before page_content is rewritten.
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
            )
            results.append(doc)

        return results

    def _search_above_threshold(self, query, filters, k):
        """Run one scored search and keep the hits scoring above the threshold.

        Returns the ``(document, score)`` pairs yielded by the vector store
        whose score is strictly greater than ``self.threshold``.
        """
        hits = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=self.format_filter(filters),
            k=k,
        )
        return [hit for hit in hits if hit[1] > self.threshold]

    def format_filter(self, filters):
        """Adapt a flat ``{field: condition}`` filter to the vector store.

        For a Chroma store, a filter with more than one field is rewritten as
        ``{"$and": [{field: condition}, ...]}``; filters with zero or one
        field, and filters for any other store, are returned unchanged.
        """
        # https://docs.trychroma.com/usage-guide#using-logical-operators
        if isinstance(self.vectorstore, Chroma):
            if len(filters) <= 1:
                return filters
            return {
                "$and": [{field: condition} for field, condition in filters.items()]
            }
        return filters