import sys
import os
from contextlib import contextmanager
from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda
from ..reranker import rerank_docs
from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event


def divide_into_parts(target, parts):
    """Split `target` into `parts` integers that sum to `target`, as evenly as possible."""
    # Base value for each part
    base = target // parts
    # Remainder to distribute
    remainder = target % parts
    # List to hold the result
    result = []
    for i in range(parts):
        if i < remainder:
            # These parts get base value + 1
            result.append(base + 1)
        else:
            # The rest get the base value
            result.append(base)
    return result
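
# Example (illustrative): divide_into_parts(10, 3) -> [4, 3, 3];
# the remainder is spread over the first parts so the pieces always sum to `target`.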


@contextmanager
def suppress_output():
    """Temporarily redirect stdout and stderr to os.devnull (e.g. around noisy third-party calls)."""
    # Open a null device
    with open(os.devnull, 'w') as devnull:
        # Store the original stdout and stderr
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        # Redirect stdout and stderr to the null device
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            # Restore stdout and stderr
            sys.stdout = old_stdout
            sys.stderr = old_stderr
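
# Usage sketch (illustrative; `reranker`, `docs` and `question` are placeholders):
#   with suppress_output():
#       reranked = rerank_docs(reranker, docs, question)  # anything printed here is silenced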


@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    return question


def _add_sources_used_in_metadata(docs,sources,question,index):
    for doc in docs:
        doc.metadata["sources_used"] = sources
        doc.metadata["question_used"] = question
        doc.metadata["index_used"] = index
    return docs


def _get_k_summary_by_question(n_questions):
    """Number of summary documents to retrieve per question; shrinks as more questions share the budget."""
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1


# The @chain decorator is not strictly necessary, but it propagates the LangChain callbacks to the
# astream_events logger so that intermediate results can be displayed.
# @chain
async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Retrieve and rerank documents for the next remaining question, then merge them into the state."""
print("---- Retrieve documents ----")
# Get the documents from the state
if "documents" in state and state["documents"] is not None:
docs = state["documents"]
else:
docs = []
# Get the related_content from the state
if "related_content" in state and state["related_content"] is not None:
related_content = state["related_content"]
else:
related_content = []
# Get the current question
current_question = state["remaining_questions"][0]
remaining_questions = state["remaining_questions"][1:]
k_by_question = k_final // state["n_questions"]
k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
sources = current_question["sources"]
question = current_question["question"]
index = current_question["index"]
print(f"Retrieve documents for question: {question}")
await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
if index == "Vector":
# Search the document store using the retriever
# Configure high top k for further reranking step
retriever = ClimateQARetriever(
vectorstore=vectorstore,
sources = sources,
min_size = 200,
k_summary = k_summary_by_question,
k_total = k_before_reranking,
threshold = 0.5,
)
docs_question_dict = await retriever.ainvoke(question,config)

    # elif index == "OpenAlex":
    #     # keyword extraction
    #     keywords_extraction = make_keywords_extraction_chain(llm)
    #     keywords = keywords_extraction.invoke(question)["keywords"]
    #     openalex_query = " AND ".join(keywords)
    #     print(f"... OpenAlex query: {openalex_query}")
    #     retriever_openalex = OpenAlexRetriever(
    #         min_year = state.get("min_year",1960),
    #         max_year = state.get("max_year",None),
    #         k = k_before_reranking
    #     )
    #     docs_question = await retriever_openalex.ainvoke(openalex_query,config)
    # else:
    #     raise Exception(f"Index {index} not found in the routing index")

    # Rerank
    if reranker is not None:
        with suppress_output():
            docs_question_summary_reranked = rerank_docs(reranker,docs_question_dict["docs_summaries"],question)
            docs_question_fulltext_reranked = rerank_docs(reranker,docs_question_dict["docs_full"],question)
            docs_question_images_reranked = rerank_docs(reranker,docs_question_dict["docs_images"],question)
        if rerank_by_question:
            docs_question_summary_reranked = sorted(docs_question_summary_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
            docs_question_fulltext_reranked = sorted(docs_question_fulltext_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
            docs_question_images_reranked = sorted(docs_question_images_reranked, key=lambda x: x.metadata["reranking_score"], reverse=True)
    else:
        # No reranker: keep the retriever output and use the similarity score as a default reranking score
        docs_question_summary_reranked = docs_question_dict["docs_summaries"]
        docs_question_fulltext_reranked = docs_question_dict["docs_full"]
        docs_question_images_reranked = docs_question_dict["docs_images"]
        for doc in docs_question_summary_reranked + docs_question_fulltext_reranked + docs_question_images_reranked:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
    docs_question = docs_question[:k_by_question]
    images_question = docs_question_images_reranked[:k_by_question]

    if reranker is not None and rerank_by_question:
        docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)

    # Add sources used in the metadata
    docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
    images_question = _add_sources_used_in_metadata(images_question,sources,question,index)

    # Add to the list of docs
    docs.extend(docs_question)
    related_content.extend(images_question)

    new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
    return new_state
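
# Illustrative shape of the incoming state expected above (keys taken from the code, values are examples only):
#   state = {
#       "remaining_questions": [{"question": "...", "sources": ["IPCC"], "index": "Vector"}],
#       "n_questions": 1,
#       "documents": [],
#       "related_content": [],
#   }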


def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    @chain
    async def retrieve_docs(state, config):
        state = await retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question, k_final, k_before_reranking, k_summary)
        return state

    return retrieve_docs
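
# Usage sketch (assumed wiring, for illustration only): the returned runnable can be registered
# as a node in a langgraph StateGraph, e.g.
#   retriever_node = make_retriever_node(vectorstore, reranker, llm)
#   workflow.add_node("retrieve_documents", retriever_node)
# where `workflow` and the node name are hypothetical.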