# See https://github.com/langchain-ai/langchain/issues/8623
import pandas as pd | |
from langchain.schema.retriever import BaseRetriever, Document | |
from langchain.vectorstores.base import VectorStoreRetriever | |
from langchain.vectorstores import VectorStore | |
from langchain.callbacks.manager import CallbackManagerForRetrieverRun | |
from typing import List | |
from pydantic import Field | |
class ClimateQARetriever(BaseRetriever):
    """Retriever that combines summary passages with full-report passages.

    Two similarity searches are issued against ``vectorstore``:

    1. up to ``k_summary`` documents whose ``report_type`` is ``SPM`` or ``TS``;
    2. documents of every other ``report_type``, with ``k`` chosen so that the
       two searches together request at most ``k_total`` documents.

    Both searches are restricted to the configured ``sources`` and only hits
    whose score is strictly greater than ``threshold`` are kept.
    """

    # Vector store queried through ``similarity_search_with_score``.
    vectorstore: VectorStore
    # Values allowed for the ``source`` metadata filter; only "IPCC" and
    # "IPBES" are accepted (checked on every query).
    sources: list = ["IPCC", "IPBES"]
    # Hits must score strictly above this value to be returned.
    # NOTE(review): assumes a higher score means a closer match -- confirm
    # against the metric used by the vector store.
    threshold: float = 22
    # Maximum number of summary (SPM / TS) documents requested.
    k_summary: int = 3
    # Upper bound on the number of documents requested over both searches.
    k_total: int = 10
    # Forwarded as the ``namespace`` keyword of the similarity search.
    namespace: str = "vectors"

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Return the documents matching ``query``, summaries first.

        Each returned document is modified in place:

        * ``metadata["similarity_score"]`` receives the search score;
        * ``metadata["content"]`` receives the original ``page_content``;
        * ``metadata["page_number"]`` is converted with ``int``;
        * ``page_content`` is prefixed with ``"Doc <rank> - <short_name>: "``,
          where ``<rank>`` starts at 1.

        The metadata of every hit must therefore already contain the
        ``page_number`` and ``short_name`` keys.

        Args:
            query: Text used for the similarity search.

        Returns:
            Documents above the score threshold: summary hits followed by
            full-report hits, each group in the order returned by the store.

        Raises:
            AssertionError: If ``sources`` is not a list, contains a value
                other than ``"IPCC"`` / ``"IPBES"``, or if ``k_total`` is not
                strictly greater than ``k_summary``.
        """
        # Raised explicitly (rather than with ``assert``) so the checks are
        # still enforced under ``python -O``; the exception type is kept so
        # existing callers see the same error.
        if not isinstance(self.sources, list):
            raise AssertionError("sources should be a list")
        if not all(source in ("IPCC", "IPBES") for source in self.sources):
            raise AssertionError("sources should only contain IPCC or IPBES")
        if self.k_total <= self.k_summary:
            raise AssertionError("k_total should be greater than k_summary")

        # Filter shared by both searches.
        source_filter = {
            "source": {"$in": self.sources},
        }

        # First search: summary documents only.
        summary_filter = {
            **source_filter,
            "report_type": {"$in": ["SPM", "TS"]},
        }
        summary_hits = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=summary_filter,
            k=self.k_summary,
        )
        summary_hits = [
            (doc, score) for doc, score in summary_hits if score > self.threshold
        ]

        # Second search: everything that is not a summary. The budget is
        # computed after thresholding, so summaries that were dropped free up
        # slots for full-report passages.
        full_filter = {
            **source_filter,
            "report_type": {"$nin": ["SPM", "TS"]},
        }
        k_full = self.k_total - len(summary_hits)
        full_hits = self.vectorstore.similarity_search_with_score(
            query=query,
            namespace=self.namespace,
            filter=full_filter,
            k=k_full,
        )
        full_hits = [
            (doc, score) for doc, score in full_hits if score > self.threshold
        ]

        # Annotate each document and prefix its content with its 1-based rank.
        results = []
        for rank, (doc, score) in enumerate(summary_hits + full_hits, start=1):
            doc.metadata["similarity_score"] = score
            # Keep the raw text before page_content is rewritten below.
            doc.metadata["content"] = doc.page_content
            doc.metadata["page_number"] = int(doc.metadata["page_number"])
            doc.page_content = (
                f"""Doc {rank} - {doc.metadata['short_name']}: {doc.page_content}"""
            )
            results.append(doc)
        return results
# def filter_summaries(df,k_summary = 3,k_total = 10): | |
# # assert source in ["IPCC","IPBES","ALL"], "source arg should be in (IPCC,IPBES,ALL)" | |
# # # Filter by source | |
# # if source == "IPCC": | |
# # df = df.loc[df["source"]=="IPCC"] | |
# # elif source == "IPBES": | |
# # df = df.loc[df["source"]=="IPBES"] | |
# # else: | |
# # pass | |
# # Separate summaries and full reports | |
# df_summaries = df.loc[df["report_type"].isin(["SPM","TS"])] | |
# df_full = df.loc[~df["report_type"].isin(["SPM","TS"])] | |
# # Find passages from summaries dataset | |
# passages_summaries = df_summaries.head(k_summary) | |
# # Find passages from full reports dataset | |
# passages_fullreports = df_full.head(k_total - len(passages_summaries)) | |
# # Concatenate passages | |
# passages = pd.concat([passages_summaries,passages_fullreports],axis = 0,ignore_index = True) | |
# return passages | |
# def retrieve_with_summaries(query,retriever,k_summary = 3,k_total = 10,sources = ["IPCC","IPBES"],max_k = 100,threshold = 0.555,as_dict = True,min_length = 300): | |
# assert max_k > k_total | |
# validated_sources = ["IPCC","IPBES"] | |
# sources = [x for x in sources if x in validated_sources] | |
# filters = { | |
# "source": { "$in": sources }, | |
# } | |
# print(filters) | |
# # Retrieve documents | |
# docs = retriever.retrieve(query,top_k = max_k,filters = filters) | |
# # Filter by score | |
# docs = [{**x.meta,"score":x.score,"content":x.content} for x in docs if x.score > threshold] | |
# if len(docs) == 0: | |
# return [] | |
# res = pd.DataFrame(docs) | |
# passages_df = filter_summaries(res,k_summary,k_total) | |
# if as_dict: | |
# contents = passages_df["content"].tolist() | |
# meta = passages_df.drop(columns = ["content"]).to_dict(orient = "records") | |
# passages = [] | |
# for i in range(len(contents)): | |
# passages.append({"content":contents[i],"meta":meta[i]}) | |
# return passages | |
# else: | |
# return passages_df | |
# def retrieve(query,sources = ["IPCC"],threshold = 0.555,k = 10): | |
# print("hellooooo") | |
# # Reformulate queries | |
# reformulated_query,language = reformulate(query) | |
# print(reformulated_query) | |
# # Retrieve documents | |
# passages = retrieve_with_summaries(reformulated_query,retriever,k_total = k,k_summary = 3,as_dict = True,sources = sources,threshold = threshold) | |
# response = { | |
# "query":query, | |
# "reformulated_query":reformulated_query, | |
# "language":language, | |
# "sources":passages, | |
# "prompts":{"init_prompt":init_prompt,"sources_prompt":sources_prompt}, | |
# } | |
# return response | |