Spaces:
Runtime error
Runtime error
# https://github.com/langchain-ai/langchain/issues/8623
import logging
from typing import List

from langchain.schema.retriever import BaseRetriever, Document
from langchain.vectorstores import VectorStore
from langchain.vectorstores import Chroma
# ``report_type`` metadata values that mark a document as a summary.
# Summaries are assumed to be easier to exploit than full reports, so
# QARetriever can search them separately (``$in``) from the remaining
# reports (``$nin``). With an empty list, no such split is possible.
SUMMARY_TYPES: list = []
class QARetriever(BaseRetriever):
    """Retriever that mixes summary documents with full-report documents.

    Up to ``k_summary`` hits are taken from documents whose ``report_type``
    is in ``SUMMARY_TYPES``. The remaining slots, up to ``k_total`` in total,
    are filled from the other documents. Hits whose score is not strictly
    greater than ``threshold`` are dropped.
    """

    # Store queried through ``similarity_search_with_score``.
    vectorstore: VectorStore
    # When non-empty, only documents whose ``domain`` metadata is in this
    # list are searched.
    domains: list = []
    # Hits must score strictly above this value to be kept.
    # NOTE(review): ``>`` assumes higher score == more similar; some stores
    # return distances (lower == closer) -- confirm for the store in use.
    threshold: float = 22
    # Maximum number of hits taken from summary documents.
    k_summary: int = 0
    # Maximum number of hits requested overall (summaries + full reports).
    k_total: int = 10
    # Namespace forwarded to every similarity search.
    # NOTE(review): also forwarded to Chroma, which has no namespace concept;
    # presumably ignored by the installed wrapper -- verify.
    namespace: str = "vectors"

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return documents relevant to ``query``, labelled for citation.

        Args:
            query: Search text passed to the vector store.

        Returns:
            Summary hits first, then full-report hits. Each returned document
            is mutated in place:

            - ``metadata["similarity_score"]`` is set to the search score.
            - ``metadata["content"]`` receives the original text.
            - ``metadata["page_number"]`` is cast to ``int``.
            - ``page_content`` is prefixed with ``"Doc {i} - {short_name}: "``,
              where ``i`` is 1-based.

        Raises:
            AssertionError: If ``domains`` is not a list or
                ``k_total <= k_summary``.
            KeyError: If a hit lacks ``page_number`` or ``short_name`` metadata.
        """
        assert isinstance(self.domains, list)
        assert (
            self.k_total > self.k_summary
        ), "k_total should be greater than k_summary"

        # Base filter shared by both searches.
        filters = {}
        if self.domains:
            filters["domain"] = {"$in": self.domains}

        # Summaries can only be told apart from full reports via SUMMARY_TYPES.
        # Without it, both searches would use identical filters and the full
        # search would return the summary hits a second time (duplicates).
        if self.k_summary > 0 and SUMMARY_TYPES:
            summary_filters = {**filters, "report_type": {"$in": SUMMARY_TYPES}}
            docs_summaries = self._search_with_threshold(
                query, summary_filters, self.k_summary
            )
        else:
            docs_summaries = []

        full_filters = dict(filters)
        if SUMMARY_TYPES:
            full_filters["report_type"] = {"$nin": SUMMARY_TYPES}
        logging.getLogger(__name__).debug("filters %s", filters)

        # Slots not filled by (threshold-surviving) summaries go to full reports.
        k_full = self.k_total - len(docs_summaries)
        docs_full = self._search_with_threshold(query, full_filters, k_full)

        results = []
        for i, (doc, score) in enumerate(docs_summaries + docs_full, start=1):
            doc.metadata["similarity_score"] = score
            # Keep the raw text, since page_content is rewritten just below.
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"Doc {i} - {doc.metadata['short_name']}: {doc.page_content}"
            )
            results.append(doc)
        return results

    def _search_with_threshold(self, query: str, filters: dict, k: int) -> list:
        """Run one similarity search and keep hits scoring above ``threshold``.

        Returns the ``(document, score)`` pairs from the store, in store order.
        """
        hits = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=self.format_filter(filters),
            k=k,
        )
        return [hit for hit in hits if hit[1] > self.threshold]

    def format_filter(self, filters):
        """Adapt a flat ``{field: condition}`` filter to the vector store's syntax.

        Chroma requires multiple conditions to be combined under an explicit
        ``$and`` (https://docs.trychroma.com/usage-guide#using-logical-operators).
        A filter with zero or one condition is valid as-is. Other stores
        receive ``filters`` unchanged.
        """
        if isinstance(self.vectorstore, Chroma) and len(filters) > 1:
            return {
                "$and": [
                    {field: condition} for field, condition in filters.items()
                ]
            }
        return filters