Clara / climateqa /engine /text_retriever.py
Samiraxio's picture
Upload folder using huggingface_hub
35fb63f verified
raw
history blame contribute delete
No virus
1.52 kB
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents.base import Document
from langchain_core.vectorstores import VectorStore
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from typing import List
class ClimateQARetriever(BaseRetriever):
    """Retriever that runs a scored similarity search and drops low-scoring hits.

    The query is sent to ``vectorstore.similarity_search_with_score`` with
    ``k=k_total`` and ``filter=filter``. Every hit whose score is at least
    ``threshold`` is returned, with a few extra metadata keys attached.
    """

    # Backend queried through `similarity_search_with_score`.
    vectorstore: VectorStore
    # Only type-checked at query time; it is not applied to the search.
    sources: list = []
    # `reports`, `k_summary` and `min_size` are not read by
    # `_get_relevant_documents`.
    # NOTE(review): presumably kept for compatibility with callers - confirm.
    reports: list = []
    # Hits with a score strictly below this value are discarded.
    threshold: float = 0.01
    k_summary: int = 3
    # Number of hits requested from the vector store.
    k_total: int = 7
    min_size: int = 200
    # Forwarded unchanged as the `filter` argument of the vector store search.
    filter: dict = None

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents matching ``query`` whose score passes the threshold.

        Args:
            query: Text handed to the vector store's similarity search.
            run_manager: Callback manager supplied by the caller; it is not
                used in this method.

        Returns:
            The retained documents, in the order the vector store returned
            them. Each one is mutated in place: its ``metadata`` gains
            ``similarity_score``, ``content`` (a copy of ``page_content``),
            ``chunk_type`` (always ``"text"``) and ``page_number``
            (always ``1``).

        Raises:
            AssertionError: If ``self.sources`` is not a list.
        """
        assert isinstance(self.sources, list), "sources must be a list"

        # NOTE(review): `sources` is validated above but never used to
        # restrict the search; only `self.filter` is sent to the store.
        scored_docs = self.vectorstore.similarity_search_with_score(
            query=query, k=self.k_total, filter=self.filter
        )

        results = []
        for doc, score in scored_docs:
            # Skip hits under the threshold. This treats a higher score as a
            # better match - TODO confirm for the vector store in use.
            if score < self.threshold:
                continue
            doc.metadata["similarity_score"] = score
            doc.metadata["content"] = doc.page_content
            doc.metadata["chunk_type"] = "text"
            doc.metadata["page_number"] = 1
            results.append(doc)
        return results