TheoLvs's picture
feature/add_agents (#14)
48e003d verified
raw
history blame
5.26 kB
import sys
import os
from contextlib import contextmanager
from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from ..reranker import rerank_docs
from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event
def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that differ by at most one.

    The first ``target % parts`` entries carry the extra unit, so
    ``divide_into_parts(10, 3) == [4, 3, 3]``. A ``parts`` of 0 raises
    ``ZeroDivisionError``.
    """
    base, remainder = divmod(target, parts)
    # Leading positions absorb the remainder, one unit each.
    return [base + 1 if position < remainder else base for position in range(parts)]
@contextmanager
def suppress_output():
    """Route ``sys.stdout`` and ``sys.stderr`` to the null device for the ``with`` body.

    Both original streams are put back when the body exits, whether normally or
    via an exception, and the null-device handle is closed afterwards.
    """
    with open(os.devnull, 'w') as sink:
        saved_streams = (sys.stdout, sys.stderr)
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            # Restore unconditionally so an exception in the body cannot leave output muted.
            sys.stdout, sys.stderr = saved_streams
# Placeholder tool: performs no retrieval and simply returns its input unchanged.
# NOTE(review): langchain's @tool appears to derive the tool description from the
# docstring and the argument schema from the signature, so both are deliberately
# left exactly as-is (no rewording, no type annotations).
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    return question
def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build the graph node that retrieves documents for the next pending question.

    Each invocation takes the first entry of ``state["remaining_questions"]``,
    retrieves candidates from the index named by that entry (``"Vector"`` ->
    ``ClimateQARetriever`` over ``vectorstore``; ``"OpenAlex"`` ->
    ``OpenAlexRetriever`` queried with LLM-extracted keywords), reranks them and
    returns them together with the questions still left to process.

    Args:
        vectorstore: Store handed to ``ClimateQARetriever`` for the "Vector" index.
        reranker: Passed to ``rerank_docs``. When ``None``, each document's
            ``similarity_score`` metadata is copied into ``reranking_score``.
        llm: Model used to build the keyword-extraction chain for OpenAlex queries.
        rerank_by_question: When true, keep only the top
            ``max(1, k_final // state["n_questions"])`` documents for the question.
        k_final: Total document budget shared across all questions.
        k_before_reranking: Number of candidates fetched before reranking.
        k_summary: Forwarded to ``ClimateQARetriever`` as ``k_summary``.

    Returns:
        A runnable (built with langchain's ``chain`` decorator) called with
        ``(state, config)`` that returns ``{"documents": [...],
        "remaining_questions": [...]}``, documents sorted by descending
        ``reranking_score``.
    """
    # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
    @chain
    async def retrieve_documents(state, config):
        """Retrieve and rerank documents for ``state["remaining_questions"][0]``.

        Each question entry is read as a dict with ``"question"``, ``"sources"``
        and ``"index"`` keys. ``state`` must also provide ``"n_questions"`` when
        ``rerank_by_question`` is set, and may provide ``"min_year"`` /
        ``"max_year"`` (used by the OpenAlex index only).

        Raises:
            ValueError: if the question's ``index`` is neither "Vector" nor "OpenAlex".
        """
        current_question = state["remaining_questions"][0]
        remaining_questions = state["remaining_questions"][1:]

        question = current_question["question"]
        sources = current_question["sources"]
        index = current_question["index"]

        await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)

        if index == "Vector":
            # Fetch a large candidate pool (k_before_reranking) for the reranking step.
            retriever = ClimateQARetriever(
                vectorstore=vectorstore,
                sources=sources,
                min_size=200,
                k_summary=k_summary,
                k_total=k_before_reranking,
                threshold=0.5,
            )
            docs_question = await retriever.ainvoke(question, config)
        elif index == "OpenAlex":
            # Only this branch needs the keyword chain, so it is not built for "Vector" queries.
            keywords_extraction = make_keywords_extraction_chain(llm)
            # Await the async variant: a blocking invoke() here would stall the event loop.
            keywords = (await keywords_extraction.ainvoke(question))["keywords"]
            openalex_query = " AND ".join(keywords)
            print(f"... OpenAlex query: {openalex_query}")
            retriever_openalex = OpenAlexRetriever(
                min_year=state.get("min_year", 1960),
                max_year=state.get("max_year", None),
                k=k_before_reranking,
            )
            docs_question = await retriever_openalex.ainvoke(openalex_query, config)
        else:
            # ValueError subclasses Exception, so existing `except Exception` handlers still apply.
            raise ValueError(f"Index {index} not found in the routing index")

        if reranker is not None:
            # Discard anything rerank_docs prints to stdout/stderr.
            with suppress_output():
                docs_question = rerank_docs(reranker, docs_question, question)
        else:
            # No reranker: fall back to the retriever's similarity score.
            for doc in docs_question:
                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

        # Sort before truncating so the slice below really keeps the best-scored
        # documents (stable sort: a no-op when the input is already in order).
        docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)

        if rerank_by_question:
            # Per-question share of k_final. Guard n_questions == 0 and keep at least
            # one document, otherwise integer division drops every result when
            # there are more questions than k_final.
            k_by_question = max(1, k_final // max(1, state["n_questions"]))
            docs_question = docs_question[:k_by_question]

        # Record which sources/question/index produced each document.
        for doc in docs_question:
            doc.metadata["sources_used"] = sources
            doc.metadata["question_used"] = question
            doc.metadata["index_used"] = index

        return {"documents": docs_question, "remaining_questions": remaining_questions}

    return retrieve_documents