Spaces:
Runtime error
Runtime error
File size: 3,308 Bytes
f0fc5f8 cc2ce8c f0fc5f8 6e28a81 cc2ce8c 6e28a81 cc2ce8c 6e28a81 cc2ce8c 6e28a81 f0fc5f8 6e28a81 f0fc5f8 cc2ce8c f0fc5f8 cc2ce8c c6c35dc cc2ce8c f0fc5f8 cc2ce8c f0fc5f8 6e28a81 cc2ce8c c6c35dc cc2ce8c 6e28a81 cc2ce8c f0fc5f8 6e28a81 f0fc5f8 6e28a81 f0fc5f8 cc2ce8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
# https://github.com/langchain-ai/langchain/issues/8623
from langchain.schema.retriever import BaseRetriever, Document
from langchain.vectorstores import VectorStore
from langchain.vectorstores import Chroma
from typing import List
# Values of the "report_type" metadata field that mark a document as a summary.
# Summaries are searched separately from full reports because they are easier to
# exploit. While this list is empty, no "report_type" filter is applied.
SUMMARY_TYPES = []
class QARetriever(BaseRetriever):
    """Retriever that combines summary hits and full-report hits from one vector store.

    Two scored similarity searches are issued against ``vectorstore``: one
    restricted to documents whose ``report_type`` is in ``SUMMARY_TYPES`` and
    one restricted to the remaining documents. Both searches honour the
    optional ``sources`` restriction, and hits whose score does not exceed
    ``threshold`` are dropped.
    """

    # Vector store queried through ``similarity_search_with_score``.
    vectorstore: VectorStore
    # Allowed values for the "source" metadata field; empty means no restriction.
    sources: list = []
    # A hit is kept only when its score is strictly greater than this value.
    # NOTE(review): whether a higher score means "more similar" depends on the
    # vector store backend — confirm the comparison direction for the one in use.
    threshold: float = 22
    # Number of hits requested from the summary documents (0 disables that search).
    k_summary: int = 0
    # Target number of hits across both searches.
    k_total: int = 10
    # Namespace forwarded to the vector store on every search.
    namespace: str = "vectors"

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents matching ``query``, summaries first.

        Args:
            query: Text to search for.

        Returns:
            The retained documents, each mutated in place: ``metadata`` gains
            ``similarity_score`` and ``content`` (the untouched page text),
            ``page_number`` is coerced to ``int``, and ``page_content`` is
            prefixed with ``"Doc <rank> - <short_name>: "``.

        Raises:
            AssertionError: If ``sources`` is not a list or ``k_total`` is not
                greater than ``k_summary``.
            KeyError: If a retained document lacks the ``page_number`` or
                ``short_name`` metadata key.
        """
        # Only the container type is validated; individual source names are not.
        assert isinstance(self.sources, list)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Base filter shared by both searches.
        filters = {}
        if self.sources:
            filters["source"] = {"$in": self.sources}

        # Summary search. Skipped when no summary types are configured, because
        # there is then nothing to tell summaries apart from full reports and
        # the same documents would be fetched twice.
        docs_summaries = []
        if self.k_summary > 0 and SUMMARY_TYPES:
            filters_summaries = {
                **filters,
                "report_type": {"$in": SUMMARY_TYPES},
            }
            docs_summaries = self.vectorstore.similarity_search_with_score(
                query=query,
                namespace=self.namespace,
                filter=self.format_filter(filters_summaries),
                k=self.k_summary,
            )
            docs_summaries = [hit for hit in docs_summaries if hit[1] > self.threshold]

        # Full-report search: fill the slots the retained summaries left free.
        filters_full = dict(filters)
        if SUMMARY_TYPES:
            filters_full["report_type"] = {"$nin": SUMMARY_TYPES}
        k_full = self.k_total - len(docs_summaries)
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=self.format_filter(filters_full),
            k=k_full,
        )
        docs_full = [hit for hit in docs_full if hit[1] > self.threshold]

        # Summaries come first so they receive the lowest "Doc N" ranks.
        docs = docs_summaries + docs_full

        # Attach the score and the raw text to the metadata, then label the text.
        results = []
        for i, (doc, score) in enumerate(docs):
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
            )
            results.append(doc)
        return results

    def format_filter(self, filters):
        """Adapt a ``{field: condition}`` filter to the active vector store.

        For a ``Chroma`` store, a filter with more than one field is rewritten
        as ``{"$and": [{field: condition}, ...]}``; see
        https://docs.trychroma.com/usage-guide#using-logical-operators.
        Filters with zero or one field, and filters for any other store, are
        returned unchanged.

        Args:
            filters: Mapping from metadata field name to its condition.

        Returns:
            The filter to pass to the vector store's search call.
        """
        if isinstance(self.vectorstore, Chroma) and len(filters) > 1:
            return {
                "$and": [{field: condition} for field, condition in filters.items()]
            }
        return filters
|