|
import sys |
|
import os |
|
from contextlib import contextmanager |
|
|
|
from langchain_core.tools import tool |
|
from langchain_core.runnables import chain |
|
from langchain_core.runnables import RunnableParallel, RunnablePassthrough |
|
from langchain_core.runnables import RunnableLambda |
|
|
|
from ..reranker import rerank_docs |
|
from ...knowledge.retriever import ClimateQARetriever |
|
from ...knowledge.openalex import OpenAlexRetriever |
|
from .keywords_extraction import make_keywords_extraction_chain |
|
from ..utils import log_event |
|
|
|
|
|
|
|
def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that differ by at most one.

    The first ``target % parts`` entries receive one extra unit, so for a
    positive ``parts`` the result sums to ``target``.
    Example: ``divide_into_parts(10, 3)`` returns ``[4, 3, 3]``.
    A negative ``parts`` yields an empty list.

    Raises:
        ZeroDivisionError: if ``parts`` is 0.
    """
    share, leftover = divmod(target, parts)
    return [share + 1 if position < leftover else share for position in range(parts)]
|
|
|
|
|
@contextmanager
def suppress_output():
    """Silence everything written to ``sys.stdout`` and ``sys.stderr``.

    Both streams are pointed at ``os.devnull`` for the duration of the
    ``with`` body. The previous stream objects are restored afterwards,
    even if the body raises. The devnull handle is closed on exit.
    """
    with open(os.devnull, 'w') as sink:
        # Remember the current streams so they can be put back verbatim.
        saved = (sys.stdout, sys.stderr)
        # Chained assignment binds stdout first, then stderr.
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved
|
|
|
|
|
# Placeholder LangChain tool: per its docstring it only simulates the
# retriever query, and it returns its input unchanged.
# NOTE(review): LangChain's ``@tool`` typically derives the tool description
# from the docstring and the argument schema from the signature, so both are
# deliberately left exactly as they were (no type annotation added).
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    # Echo the question back untouched; no retrieval happens here.
    return question
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_retriever_node(vectorstore, reranker, llm, rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build the async graph node that retrieves documents for one sub-question.

    On each call, the returned runnable performs these steps:

    1. Takes the first entry of ``state["remaining_questions"]`` (a dict with
       ``"question"``, ``"sources"`` and ``"index"`` keys).
    2. Fetches candidate documents from the index it names:
       - ``"Vector"`` uses a ``ClimateQARetriever`` over ``vectorstore``.
       - ``"OpenAlex"`` uses an ``OpenAlexRetriever`` queried with the
         keywords produced by the keyword-extraction chain, joined with
         ``" AND "``.
    3. Reranks the candidates (or copies their similarity score) and
       optionally truncates them.
    4. Tags each document with the question, sources and index used.
    5. Returns the documents together with the questions still left to
       process.

    Args:
        vectorstore: Vector store handed to ``ClimateQARetriever``.
        reranker: Reranker passed to ``rerank_docs``. When ``None``, each
            document's ``metadata["similarity_score"]`` is copied to
            ``metadata["reranking_score"]`` instead.
        llm: Model used to build the keyword-extraction chain (OpenAlex only).
        rerank_by_question: If true, keep only ``k_final // state["n_questions"]``
            documents for the current question.
        k_final: Total document budget shared across all questions.
        k_before_reranking: Number of candidates requested from the index.
        k_summary: Forwarded to ``ClimateQARetriever``.

    Returns:
        The ``@chain``-wrapped async ``retrieve_documents(state, config)``.
        It yields ``{"documents": [...], "remaining_questions": [...]}``, with
        documents sorted by ``metadata["reranking_score"]``, highest first.

    Raises (from the returned node):
        ValueError: if the question's ``"index"`` is neither ``"Vector"``
            nor ``"OpenAlex"``.
    """
    # Build the keyword chain once per node instead of on every invocation
    # (it was previously rebuilt even for the "Vector" index, which never
    # uses it).
    keywords_extraction = make_keywords_extraction_chain(llm)

    @chain
    async def retrieve_documents(state, config):
        current_question = state["remaining_questions"][0]
        remaining_questions = state["remaining_questions"][1:]

        # NOTE(review): integer division gives 0 when n_questions > k_final,
        # which would drop every document for this question when
        # rerank_by_question is true. Confirm the upstream router caps
        # n_questions.
        k_by_question = k_final // state["n_questions"]

        sources = current_question["sources"]
        question = current_question["question"]
        index = current_question["index"]

        await log_event({"question": question, "sources": sources, "index": index}, "log_retriever", config)

        if index == "Vector":
            retriever = ClimateQARetriever(
                vectorstore=vectorstore,
                sources=sources,
                min_size=200,
                k_summary=k_summary,
                k_total=k_before_reranking,
                threshold=0.5,
            )
            docs_question = await retriever.ainvoke(question, config)

        elif index == "OpenAlex":
            # Await the async API so the LLM call does not block the event
            # loop inside this async node (the original used sync .invoke()).
            keywords = (await keywords_extraction.ainvoke(question))["keywords"]
            openalex_query = " AND ".join(keywords)

            print(f"... OpenAlex query: {openalex_query}")

            retriever_openalex = OpenAlexRetriever(
                min_year=state.get("min_year", 1960),
                max_year=state.get("max_year", None),
                k=k_before_reranking,
            )
            docs_question = await retriever_openalex.ainvoke(openalex_query, config)

        else:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError(f"Index {index} not found in the routing index")

        if reranker is not None:
            # Silence anything rerank_docs writes to stdout/stderr.
            with suppress_output():
                docs_question = rerank_docs(reranker, docs_question, question)
        else:
            # Without a reranker, the retriever's similarity score stands in
            # for the reranking score used for sorting below.
            for doc in docs_question:
                doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

        if rerank_by_question:
            docs_question = docs_question[:k_by_question]

        # Record provenance so downstream steps know how each doc was found.
        for doc in docs_question:
            doc.metadata["sources_used"] = sources
            doc.metadata["question_used"] = question
            doc.metadata["index_used"] = index

        docs = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)

        return {"documents": docs, "remaining_questions": remaining_questions}

    return retrieve_documents
|
|
|
|