from langchain_core.retrievers import BaseRetriever | |
from langchain_core.documents.base import Document | |
from langchain_core.vectorstores import VectorStore | |
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun | |
from typing import List | |
# class GraphRetriever(BaseRetriever): | |
# vectorstore:VectorStore | |
# sources:list = ["OWID"] # later: add OurWorldInData; will need to be integrated with the other retriever
# threshold:float = 0.5 | |
# k_total:int = 10 | |
# def _get_relevant_documents( | |
# self, query: str, *, run_manager: CallbackManagerForRetrieverRun | |
# ) -> List[Document]: | |
# # Check that at least one requested source is supported (currently only OWID)
# assert isinstance(self.sources,list) | |
# assert self.sources | |
# assert any([x in ["OWID"] for x in self.sources]) | |
# # Prepare base search kwargs | |
# filters = {} | |
# filters["source"] = {"$in": self.sources} | |
# docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total) | |
# # Filter if scores are below threshold | |
# docs = [x for x in docs if x[1] > self.threshold] | |
# # Remove duplicate documents | |
# unique_docs = [] | |
# seen_docs = [] | |
# for i, doc in enumerate(docs): | |
# if doc[0].page_content not in seen_docs: | |
# unique_docs.append(doc) | |
# seen_docs.append(doc[0].page_content) | |
# # Add score to metadata | |
# results = [] | |
# for i,(doc,score) in enumerate(unique_docs): | |
# doc.metadata["similarity_score"] = score | |
# doc.metadata["content"] = doc.page_content | |
# results.append(doc) | |
# return results | |
async def retrieve_graphs(
    query: str,
    vectorstore: VectorStore,
    sources: list = ["OWID"],  # later: add more sources and integrate with the other retriever
    threshold: float = 0.5,
    k_total: int = 10,
) -> List[Document]:
    """Return de-duplicated graph documents from *vectorstore* that match *query*.

    The vector store is queried with a metadata filter restricting the
    ``source`` field to *sources*. Hits whose score is not strictly greater
    than *threshold* are dropped, then only the first hit for each distinct
    ``page_content`` is kept. Each returned document is annotated in place
    with two metadata keys: ``similarity_score`` (the score reported by the
    store) and ``content`` (a copy of ``page_content``).

    Args:
        query: Text to search for.
        vectorstore: Store exposing ``similarity_search_with_score``.
        sources: Non-empty list of source names; at least one of them must be
            ``"OWID"``.
        threshold: Exclusive lower bound on the score of a kept hit.
        k_total: Number of hits requested from the store (before filtering).

    Returns:
        The surviving documents, in the order the store returned them.

    Raises:
        AssertionError: If *sources* is not a list, is empty, or contains no
            supported source.
    """
    supported_sources = ("OWID",)

    # Validation is kept as assertions so callers that observe AssertionError
    # are unaffected. NOTE(review): assertions are skipped under ``python -O``.
    assert isinstance(sources, list), "sources must be a list"
    assert sources, "sources must not be empty"
    # Only requires that at least one requested source is supported; unknown
    # entries are still forwarded to the filter below.
    assert any(source in supported_sources for source in sources), (
        "sources must contain at least one supported source"
    )

    # Hand the store a copy so the shared default list (or the caller's list)
    # cannot be altered through the filter object.
    filters = {"source": {"$in": list(sources)}}

    # NOTE(review): this is a synchronous call inside a coroutine, so it blocks
    # the event loop for the duration of the search.
    hits = vectorstore.similarity_search_with_score(
        query=query, filter=filters, k=k_total
    )

    # Keep only hits scoring strictly above the threshold.
    # NOTE(review): assumes a higher score means a closer match - confirm for
    # the concrete vector store in use.
    relevant_hits = [(doc, score) for doc, score in hits if score > threshold]

    # Keep the first hit per distinct page content and record its score and
    # content in the document metadata.
    results = []
    seen_contents = set()
    for doc, score in relevant_hits:
        if doc.page_content in seen_contents:
            continue
        seen_contents.add(doc.page_content)
        doc.metadata["similarity_score"] = score
        doc.metadata["content"] = doc.page_content
        results.append(doc)
    return results