# https://github.com/langchain-ai/langchain/issues/8623


from typing import List

from langchain.schema.retriever import BaseRetriever, Document
from langchain.vectorstores import Chroma, VectorStore


## The idea is that some documents are summaries, which are easier to exploit than full reports
SUMMARY_TYPES = []
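# e.g. SUMMARY_TYPES = ["SPM", "TS"]  -- hypothetical values; leave the list empty
# to skip the report_type filters below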


class QARetriever(BaseRetriever):
    vectorstore: VectorStore
    sources: list = []          # keep only documents whose metadata["source"] is in this list (empty = no filter)
    threshold: float = 22       # minimum similarity score for a document to be kept
    k_summary: int = 0          # number of documents to take from summary reports
    k_total: int = 10           # total number of documents to return
    namespace: str = "vectors"  # namespace for stores that support it (see commented-out kwargs below)

    def get_relevant_documents(self, query: str) -> List[Document]:
        # Sanity checks: sources must be a list and k_total must leave room for summary results
        assert isinstance(self.sources, list)
        assert self.k_total > self.k_summary, "k_total should be greater than k_summary"

        # Prepare base search kwargs
        filters = {}
        if len(self.sources):
            filters["source"] = {"$in": self.sources}

        if self.k_summary > 0:
            # Search for k_summary documents in the summaries dataset
            filters_summaries = {**filters}
            if len(SUMMARY_TYPES):
                filters_summaries = {
                    **filters_summaries,
                    "report_type": {"$in": SUMMARY_TYPES},
                }
            docs_summaries = self.vectorstore.similarity_search_with_score(
                query=query,
                # namespace=self.namespace,
                filter=self.format_filter(filters_summaries),
                k=self.k_summary,
            )
            docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
        else:
            docs_summaries = []

        # Search for k_total - k_summary documents in the full reports dataset
        filters_full = {**filters}
        if len(SUMMARY_TYPES):
            filters_full = {**filters_full, "report_type": {"$nin": SUMMARY_TYPES}}

        k_full = self.k_total - len(docs_summaries)
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query,
            # namespace=self.namespace,
            filter=self.format_filter(filters_full),
            k=k_full,
        )
        print("docs_full", docs_full)

        # Concatenate documents
        docs = docs_summaries + docs_full

        # Filter if scores are below threshold
        docs = [x for x in docs if x[1] > self.threshold]

        # Add score to metadata
        results = []
        for i, (doc, score) in enumerate(docs):
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
            )
            results.append(doc)

        return results

    def format_filter(self, filters):
        # Chroma expects a single top-level operator in its "where" clause, so
        # multiple field conditions have to be wrapped in an explicit "$and".
        # https://docs.trychroma.com/usage-guide#using-logical-operators
        if isinstance(self.vectorstore, Chroma):
            if len(filters) <= 1:
                return filters
            and_filters = []
            for field, condition in filters.items():
                and_filters.append({field: condition})
            return {"$and": and_filters}
        return filters
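

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original snippet). It assumes an
    # existing Chroma collection whose document metadata already contains the
    # fields this retriever reads: "source", "report_type", "page_number" and
    # "short_name". The collection name, persist directory, sources and
    # threshold below are placeholder values.
    from langchain.embeddings import OpenAIEmbeddings

    vectorstore = Chroma(
        collection_name="reports",              # hypothetical collection
        embedding_function=OpenAIEmbeddings(),  # requires OPENAI_API_KEY
        persist_directory="./chroma_db",        # hypothetical path
    )
    retriever = QARetriever(
        vectorstore=vectorstore,
        sources=["IPCC", "IPBES"],  # keep only documents from these sources
        k_summary=3,                # take up to 3 documents from summary reports
        k_total=10,                 # return 10 documents in total
        threshold=0.0,              # score cutoff; 0.0 effectively keeps all results
    )
    docs = retriever.get_relevant_documents(
        "What are the main drivers of biodiversity loss?"
    )
    for doc in docs:
        print(doc.metadata["similarity_score"], doc.page_content[:80])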