# File size: 3,018 Bytes
# bcc8503
from typing import List, Optional

from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents.base import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
# class GraphRetriever(BaseRetriever):
# vectorstore:VectorStore
# sources:list = ["OWID"] # later, add OurWorldInData; this will need to be integrated with the other retriever
# threshold:float = 0.5
# k_total:int = 10
# def _get_relevant_documents(
# self, query: str, *, run_manager: CallbackManagerForRetrieverRun
# ) -> List[Document]:
# # Check that at least one element of the list is "OWID"
# assert isinstance(self.sources,list)
# assert self.sources
# assert any([x in ["OWID"] for x in self.sources])
# # Prepare base search kwargs
# filters = {}
# filters["source"] = {"$in": self.sources}
# docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
# # Filter if scores are below threshold
# docs = [x for x in docs if x[1] > self.threshold]
# # Remove duplicate documents
# unique_docs = []
# seen_docs = []
# for i, doc in enumerate(docs):
# if doc[0].page_content not in seen_docs:
# unique_docs.append(doc)
# seen_docs.append(doc[0].page_content)
# # Add score to metadata
# results = []
# for i,(doc,score) in enumerate(unique_docs):
# doc.metadata["similarity_score"] = score
# doc.metadata["content"] = doc.page_content
# results.append(doc)
# return results
async def retrieve_graphs(
    query: str,
    vectorstore: VectorStore,
    sources: Optional[list] = None,
    threshold: float = 0.5,
    k_total: int = 10,
) -> List[Document]:
    """Return de-duplicated graph documents scoring above ``threshold``.

    Runs a source-filtered similarity search against ``vectorstore``, keeps
    only hits whose score is strictly greater than ``threshold``, keeps the
    first hit for each distinct ``page_content``, and records the score and
    the content in each retained document's metadata.

    Args:
        query: Text passed to the similarity search.
        vectorstore: Store queried through ``similarity_search_with_score``.
        sources: Values allowed for the ``source`` filter field. Must be a
            non-empty list containing ``"OWID"``. ``None`` (the default)
            means ``["OWID"]``.
        threshold: A hit is kept only when its score is strictly greater
            than this value.
        k_total: Number of hits requested from the store, before
            thresholding and de-duplication.

    Returns:
        The retained documents, in the order the store returned them. Each
        one has ``metadata["similarity_score"]`` set to its score and
        ``metadata["content"]`` set to its ``page_content``.

    Raises:
        AssertionError: If ``sources`` is not a list, is empty, or does not
            contain ``"OWID"``.
    """
    # A None sentinel replaces the former mutable default so that no list
    # object is shared between calls.
    if sources is None:
        sources = ["OWID"]

    # NOTE: these checks are `assert` statements and are therefore skipped
    # when Python runs with -O; kept as-is so the exception type callers may
    # rely on does not change.
    assert isinstance(sources, list), "sources must be a list"
    assert sources, "sources must not be empty"
    # Only requires that "OWID" is present; other entries are not rejected.
    assert "OWID" in sources, 'sources must contain "OWID"'

    # Hand the store a copy so it can never mutate the caller's list.
    filters = {"source": {"$in": list(sources)}}

    # NOTE(review): this call is not awaited, so it runs synchronously inside
    # the coroutine - confirm that blocking here is acceptable.
    scored_docs = vectorstore.similarity_search_with_score(
        query=query, filter=filters, k=k_total
    )

    results = []
    seen_contents = set()  # set membership keeps de-duplication linear
    for doc, score in scored_docs:
        # Skip hits at or below the threshold and repeated contents.
        if not score > threshold or doc.page_content in seen_contents:
            continue
        seen_contents.add(doc.page_content)
        doc.metadata["similarity_score"] = score
        doc.metadata["content"] = doc.page_content
        results.append(doc)
    return results