File size: 5,425 Bytes
f0fc5f8
 
 
 
38ed905
 
 
 
 
 
f0fc5f8
 
 
 
 
 
139fefe
393b23a
f0fc5f8
 
 
 
139fefe
 
 
 
f0fc5f8
 
 
 
 
 
 
139fefe
 
 
 
 
 
f0fc5f8
 
 
 
9a9100e
f0fc5f8
139fefe
 
f0fc5f8
 
 
 
 
9a9100e
f0fc5f8
 
139fefe
f0fc5f8
 
 
 
 
 
 
 
 
 
 
 
 
139fefe
f0fc5f8
 
139fefe
 
f0fc5f8
139fefe
f0fc5f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# https://github.com/langchain-ai/langchain/issues/8623

import pandas as pd

from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun

from typing import List
from pydantic import Field

class ClimateQARetriever(BaseRetriever):
    """Retriever that mixes summary ("SPM") passages with full-report passages.

    For a query it runs two ``similarity_search_with_score`` calls against
    ``vectorstore``: one restricted to ``report_type`` "SPM" (up to
    ``k_summary`` hits) and one over every other report type, sized so the
    total requested across both searches is ``k_total``. Hits whose score is
    not strictly greater than ``threshold`` are discarded.
    """

    # Vector store queried via ``similarity_search_with_score``.
    vectorstore: VectorStore
    # Allowed values are "IPCC" and "IPBES"; applied as a ``source`` filter
    # only when ``reports`` is empty.
    sources: list = ["IPCC", "IPBES"]
    # Optional ``short_name`` values; when non-empty this filter is used
    # instead of the ``source`` filter.
    reports: list = []
    # A hit is kept only if its score is strictly greater than this value.
    threshold: float = 0.6
    # Number of summary ("SPM") hits requested.
    k_summary: int = 3
    # Total number of hits requested across both searches.
    k_total: int = 10
    # NOTE(review): not read anywhere in this class; presumably consumed by
    # code elsewhere that configures the vector store. Confirm.
    namespace: str = "vectors"

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return summary hits followed by full-report hits for ``query``.

        Args:
            query: Text to search for.
            run_manager: LangChain callback manager for this run (not used here).

        Returns:
            Documents whose score is strictly greater than ``threshold``.
            Summary ("SPM") hits come first, then full-report hits. Each group
            is in the order the vector store returned it; the combined list is
            not re-sorted. Each document's metadata is updated in place with
            ``similarity_score``, ``content`` (a copy of ``page_content``) and
            ``page_number`` converted with ``int``.

        Raises:
            TypeError: If ``sources`` is not a list.
            ValueError: If ``sources`` contains anything other than "IPCC" or
                "IPBES", or if ``k_total`` is not greater than ``k_summary``.
        """
        # Explicit checks rather than ``assert``: asserts are stripped when
        # Python runs with -O, which would silently skip this validation.
        allowed_sources = ("IPCC", "IPBES")
        if not isinstance(self.sources, list):
            raise TypeError(
                f"sources should be a list, got {type(self.sources).__name__}"
            )
        invalid_sources = [s for s in self.sources if s not in allowed_sources]
        if invalid_sources:
            raise ValueError(
                f"sources should only contain {allowed_sources}, got {invalid_sources}"
            )
        if self.k_total <= self.k_summary:
            raise ValueError("k_total should be greater than k_summary")

        # A non-empty report list takes precedence over the source filter.
        if self.reports:
            filters = {"short_name": {"$in": self.reports}}
        else:
            filters = {"source": {"$in": self.sources}}

        # Summary passages first, capped at k_summary.
        docs_summaries = self.vectorstore.similarity_search_with_score(
            query=query,
            filter={**filters, "report_type": {"$in": ["SPM"]}},
            k=self.k_summary,
        )
        docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]

        # Fill the remaining budget from the full reports. Summary slots that
        # were dropped by the threshold are handed over to this search.
        k_full = self.k_total - len(docs_summaries)
        docs_full = self.vectorstore.similarity_search_with_score(
            query=query,
            filter={**filters, "report_type": {"$nin": ["SPM"]}},
            k=k_full,
        )
        docs_full = [x for x in docs_full if x[1] > self.threshold]

        # Attach the score and convenience fields to each document's metadata.
        results = []
        for doc, score in docs_summaries + docs_full:
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            results.append(doc)

        return results




# def filter_summaries(df,k_summary = 3,k_total = 10):
#     # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)"

#     # # Filter by source
#     # if source == "IPCC":
#     #     df = df.loc[df["source"]=="IPCC"]
#     # elif source == "IPBES":
#     #     df = df.loc[df["source"]=="IPBES"]
#     # else:
#     #     pass

#     # Separate summaries and full reports
#     df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])]
#     df_full = df.loc[~df["report_type"].isin(["SPM","TS"])]

#     # Find passages from summaries dataset
#     passages_summaries = df_summaries.head(k_summary)

#     # Find passages from full reports dataset
#     passages_fullreports = df_full.head(k_total - len(passages_summaries))

#     # Concatenate passages
#     passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True)
#     return passages




# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300):
#     assert max_k > k_total

#     validated_sources = ["IPCC","IPBES"]
#     sources = [x for x in sources if x in validated_sources]
#     filters = {
#         "source": { "$in": sources },
#     }
#     print(filters)

#     # Retrieve documents
#     docs = retriever.retrieve(query,top_k = max_k,filters = filters)

#     # Filter by score
#     docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold]

#     if len(docs) == 0:
#         return []
#     res = pd.DataFrame(docs)
#     passages_df = filter_summaries(res,k_summary,k_total)
#     if as_dict:
#         contents = passages_df["content"].tolist()
#         meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records")
#         passages = []
#         for i in range(len(contents)):
#             passages.append({"content":contents[i],"meta":meta[i]})
#         return passages
#     else:
#         return passages_df



# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10):


#     print("hellooooo")

#     # Reformulate queries
#     reformulated_query,language = reformulate(query)

#     print(reformulated_query)

#     # Retrieve documents
#     passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold)
#     response = {
#       "query":query,
#       "reformulated_query":reformulated_query,
#       "language":language,
#       "sources":passages,
#       "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt},
#     }
#     return response