File size: 6,948 Bytes
99e91d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d562d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e91d8
d562d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9609df9
d562d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e91d8
eee8932
d562d38
 
 
99e91d8
d562d38
 
99e91d8
d562d38
99e91d8
d562d38
 
 
 
 
 
99e91d8
d562d38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e91d8
 
d562d38
 
 
 
 
 
 
99e91d8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import sys
import os
from contextlib import contextmanager

from langchain_core.tools import tool
from langchain_core.runnables import chain
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.runnables import RunnableLambda

from ..reranker import rerank_docs
from ...knowledge.retriever import ClimateQARetriever
from ...knowledge.openalex import OpenAlexRetriever
from .keywords_extraction import make_keywords_extraction_chain
from ..utils import log_event



def divide_into_parts(target, parts):
    """Split ``target`` into ``parts`` integers that differ by at most one.

    The first ``target % parts`` entries get one extra unit on top of
    ``target // parts``, so for a positive ``parts`` the entries sum to
    ``target``.

    Raises:
        ZeroDivisionError: if ``parts`` is 0.
    """
    quotient, leftover = divmod(target, parts)
    return [quotient + 1 if position < leftover else quotient for position in range(parts)]


@contextmanager
def suppress_output():
    """Silence ``sys.stdout`` and ``sys.stderr`` for the duration of a ``with`` block.

    Both streams are pointed at ``os.devnull``. Their previous values are put
    back when the block exits, including when it exits via an exception.
    """
    saved_streams = (sys.stdout, sys.stderr)
    with open(os.devnull, 'w') as sink:
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            sys.stdout, sys.stderr = saved_streams


# Placeholder LangChain tool: it returns the incoming question unchanged.
# NOTE(review): @tool presumably derives the tool description from the docstring
# and the argument schema from the (unannotated) signature — confirm before
# editing either, since both may be visible to the LLM at runtime.
@tool
def query_retriever(question):
    """Just a dummy tool to simulate the retriever query"""
    return question

def _add_sources_used_in_metadata(docs,sources,question,index):
    for doc in docs:
        doc.metadata["sources_used"] = sources
        doc.metadata["question_used"] = question
        doc.metadata["index_used"] = index
    return docs

def _get_k_summary_by_question(n_questions):
    if n_questions == 0:
        return 0
    elif n_questions == 1:
        return 5
    elif n_questions == 2:
        return 3
    elif n_questions == 3:
        return 2
    else:
        return 1
    

def _sort_by_reranking_score(docs):
    """Return ``docs`` as a new list ordered by ``metadata["reranking_score"]``, highest first."""
    return sorted(docs, key=lambda doc: doc.metadata["reranking_score"], reverse=True)


# Not decorated with @chain here: the chain callback is not necessary for this
# function itself, but wrapping it in a @chain function (as make_retriever_node
# does) propagates the LangChain callbacks to the astream_events logger so
# intermediate results can be displayed.
async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Retrieve (and optionally rerank) documents for the next pending question.

    Takes the first entry of ``state["remaining_questions"]``, a mapping with
    ``"question"``, ``"sources"`` and ``"index"`` keys. It then:

    - queries the vector store through ``ClimateQARetriever``;
    - reranks the summary, full-text and image results with ``reranker`` when
      one is given;
    - keeps at most ``k_final // state["n_questions"]`` text documents and the
      same number of images;
    - appends them to the documents and related content already in ``state``.

    Args:
        state: Graph state. Reads ``remaining_questions`` and ``n_questions``,
            and optionally ``documents`` and ``related_content``. The incoming
            lists are copied, not mutated.
        config: LangChain runnable config, forwarded to the retriever and to
            ``log_event``.
        vectorstore: Vector store backing ``ClimateQARetriever``.
        reranker: Passed to ``rerank_docs``. When ``None``, each text
            document's ``similarity_score`` is copied into ``reranking_score``.
        llm: Currently unused (only needed by the disabled OpenAlex route).
        rerank_by_question: When a reranker is given, sort results by
            ``reranking_score``.
        k_final: Total number of text documents to keep across all questions.
        k_before_reranking: Number of candidates fetched before reranking.
        k_summary: Currently unused. The per-question summary count comes
            from ``_get_k_summary_by_question``.

    Returns:
        A partial state update with ``documents``, ``related_contents`` and the
        still-pending ``remaining_questions``.

    Raises:
        ValueError: if the question's ``index`` is not ``"Vector"``.
    """
    print("---- Retrieve documents ----")

    # Copy the accumulated lists so the incoming state is not mutated in place.
    docs = list(state.get("documents") or [])
    related_content = list(state.get("related_content") or [])

    # Get the current question
    current_question = state["remaining_questions"][0]
    remaining_questions = state["remaining_questions"][1:]

    k_by_question = k_final // state["n_questions"]
    k_summary_by_question = _get_k_summary_by_question(state["n_questions"])

    sources = current_question["sources"]
    question = current_question["question"]
    index = current_question["index"]

    print(f"Retrieve documents for question: {question}")
    await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)

    if index != "Vector":
        # Only the vector index is wired up. The OpenAlex route
        # (OpenAlexRetriever + keyword extraction) is currently disabled, so
        # fail explicitly instead of reaching an unbound variable below.
        raise ValueError(f"Index {index} not found in the routing index")

    # Search the document store; fetch many candidates (k_total) so the
    # reranking step has room to work.
    retriever = ClimateQARetriever(
        vectorstore=vectorstore,
        sources = sources,
        min_size = 200,
        k_summary = k_summary_by_question,
        k_total = k_before_reranking,
        threshold = 0.5,
    )
    docs_question_dict = await retriever.ainvoke(question,config)

    summaries = docs_question_dict["docs_summaries"]
    fulltext = docs_question_dict["docs_full"]
    images = docs_question_dict["docs_images"]

    if reranker is not None:
        # rerank_docs may print progress output; keep it off the console.
        with suppress_output():
            summaries = rerank_docs(reranker, summaries, question)
            fulltext = rerank_docs(reranker, fulltext, question)
            images = rerank_docs(reranker, images, question)
        if rerank_by_question:
            summaries = _sort_by_reranking_score(summaries)
            fulltext = _sort_by_reranking_score(fulltext)
            images = _sort_by_reranking_score(images)
    else:
        # No reranker: use the retriever's similarity score as the default
        # reranking score, so every text document carries one.
        for doc in summaries + fulltext:
            doc.metadata["reranking_score"] = doc.metadata["similarity_score"]

    # Summaries are concatenated first, so they take priority when truncating.
    docs_question = (summaries + fulltext)[:k_by_question]
    images_question = images[:k_by_question]

    if reranker is not None and rerank_by_question:
        docs_question = _sort_by_reranking_score(docs_question)

    # Add sources used in the metadata
    docs_question = _add_sources_used_in_metadata(docs_question,sources,question,index)
    images_question = _add_sources_used_in_metadata(images_question,sources,question,index)

    # Add to the accumulated lists
    docs.extend(docs_question)
    related_content.extend(images_question)

    # NOTE(review): related content is read from "related_content" but written
    # under "related_contents". Confirm which key (and reducer) the graph state
    # declares before unifying them; changing either side alters accumulation.
    new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
    return new_state
    


def make_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
    """Build a ``@chain`` runnable that runs ``retrieve_documents`` on a graph state.

    The retrieval dependencies and tuning knobs are captured in a closure, so
    the returned runnable only needs ``(state, config)``.
    """
    # The inner name ``retrieve_docs`` is kept as-is: @chain may use the
    # function name to label the runnable.
    @chain
    async def retrieve_docs(state, config):
        return await retrieve_documents(
            state,
            config,
            vectorstore,
            reranker,
            llm,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
            k_summary=k_summary,
        )

    return retrieve_docs