import sys
import os
from contextlib import contextmanager

from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda

from ..reranker import rerank_docs
from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event


def divide_into_parts(target, parts):
    # Base value for each part
    base = target // parts
    # Remainder to distribute
    remainder = target % parts
    # List to hold the result
    result = []

    for i in range(parts):
        if i < remainder:
            # These parts get base value + 1
            result.append(base + 1)
        else:
            # The rest get the base value
            result.append(base)

    return result


@contextmanager
def suppress_output():
    # Open a null device
    with open(os.devnull, 'w') as devnull:
        # Store the original stdout and stderr
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        # Redirect stdout and stderr to the null device
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            # Restore stdout and stderr
            sys.stdout = old_stdout
            sys.stderr = old_stderr


@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    return question


def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):

    # The @chain decorator is not strictly necessary, but it propagates the langchain callbacks
    # to the astream_events logger to display intermediate results
    @chain
    async def retrieve_documents(state, config):

        keywords_extraction = make_keywords_extraction_chain(llm)

        current_question = state["remaining_questions"][0]
        remaining_questions = state["remaining_questions"][1:]

        # ToolMessage(f"Retrieving documents for question: {current_question['question']}",tool_call_id = "retriever")

        # # There are several options to get the final top k
        # # Option 1 - Get 100 documents by question and rerank by question
        # # Option 2 - Get 100/n documents by question and rerank the total
        # if rerank_by_question:
        #     k_by_question = divide_into_parts(k_final, len(questions))

        # docs = state["documents"]
        # if docs is None: docs = []

        docs = []
        k_by_question = k_final // state["n_questions"]

        sources = current_question["sources"]
        question = current_question["question"]
        index = current_question["index"]

        await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)

        if index == "Vector":
            # Search the document store using the retriever
            # Configure high top k for further reranking step
            retriever = ClimateQARetriever(
                vectorstore=vectorstore,
                sources=sources,
                min_size=200,
                k_summary=k_summary,
                k_total=k_before_reranking,
                threshold=0.5,
            )
            docs_question = await retriever.ainvoke(question, config)

        elif index == "OpenAlex":
            keywords = keywords_extraction.invoke(question)["keywords"]
            openalex_query = " AND ".join(keywords)
            print(f"... OpenAlex query: {openalex_query}")
            retriever_openalex = OpenAlexRetriever(
                min_year=state.get("min_year", 1960),
                max_year=state.get("max_year", None),
                k=k_before_reranking,
            )
            docs_question = await retriever_openalex.ainvoke(openalex_query, config)

        else:
            raise Exception(f"Index {index} not found in the routing index")

        # Rerank
        if reranker is not None:
            with suppress_output():
                docs_question = rerank_docs(reranker, docs_question, question)
        else:
            # Add a default reranking score
            for doc in docs_question:
                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

        # If rerank by question we select the top documents for each question
        if rerank_by_question:
            docs_question = docs_question[:k_by_question]

        # Add sources used in the metadata
        for doc in docs_question:
            doc.metadata["sources_used"] = sources
            doc.metadata["question_used"] = question
            doc.metadata["index_used"] = index

        # Add to the list of docs
        docs.extend(docs_question)

        # Sort the list in descending order by reranking_score
        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)

        new_state = {"documents": docs, "remaining_questions": remaining_questions}
        return new_state

    return retrieve_documents
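

# Example usage (a minimal sketch, not part of the original module):
# `vectorstore`, `reranker`, and `llm` are assumed to be created elsewhere in the
# application; the state keys below match those read by retrieve_documents above.
#
#   retriever_node = make_retriever_node(vectorstore, reranker, llm)
#   state = {
#       "remaining_questions": [
#           {"question": "How does climate change affect biodiversity?",
#            "sources": ["IPCC"],
#            "index": "Vector"},
#       ],
#       "n_questions": 1,
#   }
#   # Returns {"documents": [...], "remaining_questions": []}
#   new_state = await retriever_node.ainvoke(state)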