Fix answer latency when multiple sources are selected
Files changed:
- app.py (+68 -52)
- climateqa/engine/chains/graph_retriever.py (+68 -67)
- climateqa/engine/chains/retrieve_documents.py (+115 -30)
- climateqa/engine/graph.py (+12 -3)
- climateqa/engine/graph_retriever.py (+54 -14)
- climateqa/engine/reranker.py (+2 -0)
- climateqa/knowledge/retriever.py (+95 -94)
- sandbox/20241104 - CQA - StepByStep CQA.ipynb (+0 -0)
app.py

@@ -120,7 +120,7 @@ reranker = get_reranker("nano")
 agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
 
 
-async def chat(query,history,audience,sources,reports,current_graphs):
+async def chat(query, history, audience, sources, reports, relevant_content_sources):
     """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
     (messages in gradio format, messages in langchain format, source documents)"""
 
@@ -136,7 +136,7 @@ async def chat(query,history,audience,sources,reports,current_graphs):
     if reports is None or len(reports) == 0:
         reports = []
 
-    inputs = {"user_input": query,"audience": audience_prompt,"sources_input":sources}
+    inputs = {"user_input": query,"audience": audience_prompt,"sources_input":sources, "relevant_content_sources" : relevant_content_sources}
     result = agent.astream_events(inputs,version = "v1")
 
 
@@ -167,7 +167,16 @@ async def chat(query,history,audience,sources,reports,current_graphs):
             if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
                 docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
 
+            elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
+
+                intent = event["data"]["output"]["intent"]
+                if "language" in event["data"]["output"]:
+                    output_language = event["data"]["output"]["language"]
+                else :
+                    output_language = "English"
+                history[-1].content = f"Language identified : {output_language} \n Intent identified : {intent}"
+
             elif event["name"] in steps_display.keys() and event["event"] == "on_chain_start": #display steps
                 event_description, display_output = steps_display[node]
                 if not hasattr(history[-1], 'metadata') or history[-1].metadata["title"] != event_description: # if a new step begins
@@ -260,59 +269,59 @@ papers_cols = list(papers_cols_widths.keys())
 papers_cols_widths = list(papers_cols_widths.values())
 
 
-async def find_papers(query,after):
+async def find_papers(query,after, relevant_content_sources):
+    if "OpenAlex" in relevant_content_sources:
+        summary = ""
+        keywords = generate_keywords(query)
+        df_works = oa.search(keywords,after = after)
+        df_works = df_works.dropna(subset=["abstract"])
+        df_works = oa.rerank(query,df_works,reranker)
+        df_works = df_works.sort_values("rerank_score",ascending=False)
+        docs_html = []
+        for i in range(10):
+            docs_html.append(make_html_df(df_works, i))
+        docs_html = "".join(docs_html)
+        print(docs_html)
+        G = oa.make_network(df_works)
+
+        height = "750px"
+        network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
+        network_html = network.generate_html()
+
+        network_html = network_html.replace("'", "\"")
+        css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
+        network_html = network_html + css_to_inject
+
+        network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
+        display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+        allow-scripts allow-same-origin allow-popups
+        allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+        allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
+
+        docs = df_works["content"].head(10).tolist()
+
+        df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
+        df_works["doc"] = df_works["doc"] + 1
+        df_works = df_works[papers_cols]
+
+        yield docs_html, network_html, summary
+
+        chain = make_rag_papers_chain(llm)
+        result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
+        path_answer = "/logs/StrOutputParser/streamed_output/-"
+
+        async for op in result:
+            op = op.ops[0]
+            if op['path'] == path_answer: # reformulated question
+                new_token = op['value'] # str
+                summary += new_token
+            else:
+                continue
+            yield docs_html, network_html, summary
 
 
@@ -473,7 +482,13 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
             value=["IPCC"],
             interactive=True,
         )
+        dropdown_external_sources = gr.CheckboxGroup(
+            ["IPCC figures","OpenAlex", "OurWorldInData"],
+            label="Select database to search for relevant content",
+            value=["IPCC figures"],
+            interactive=True,
+        )
+
         dropdown_reports = gr.Dropdown(
             POSSIBLE_REPORTS,
             label="Or select specific reports",
@@ -488,9 +503,10 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
             value="Experts",
             interactive=True,
         )
+
-        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
-        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
+        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False, visible= False)
+        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False, visible= False)
 
 
@@ -603,14 +619,14 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 
     (textbox
         .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
-        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports,
+        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
        # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_recommended_content, tab_papers] )
     )
 
     (examples_hidden
        .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
-        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports,
+        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
        # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_recommended_content, tab_papers] )
     )
 
@@ -633,8 +649,8 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 
    dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
 
-    textbox.submit(find_papers,[textbox,after], [papers_html,citations_network,papers_summary])
-    examples_hidden.change(find_papers,[examples_hidden,after], [papers_html,citations_network,papers_summary])
+    textbox.submit(find_papers,[textbox,after, dropdown_external_sources], [papers_html,citations_network,papers_summary])
+    examples_hidden.change(find_papers,[examples_hidden,after,dropdown_external_sources], [papers_html,citations_network,papers_summary])
 
    btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
    btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
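
The latency fix in this file is a gating pattern: find_papers now runs its OpenAlex search, rerank, and citation-network rendering only when "OpenAlex" is among the selected relevant_content_sources, so the default configuration no longer pays for that round-trip. A minimal, self-contained sketch of the idea (names like fetch_papers are illustrative, not the app's code):

import asyncio

async def fetch_papers(query: str) -> list[str]:
    await asyncio.sleep(2)  # stand-in for the OpenAlex search + rerank round-trip
    return [f"paper about {query}"]

async def answer(query: str, relevant_content_sources: list[str]) -> list[str]:
    papers: list[str] = []
    if "OpenAlex" in relevant_content_sources:  # mirrors the guard added to find_papers
        papers = await fetch_papers(query)
    # the main RAG pipeline would run here regardless of the guard
    return papers

# With only "IPCC figures" selected, the expensive branch never runs:
print(asyncio.run(answer("sea level rise", ["IPCC figures"])))  # -> []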
climateqa/engine/chains/graph_retriever.py

@@ -3,7 +3,7 @@ import os
 from contextlib import contextmanager
 
 from ..reranker import rerank_docs
-from ..graph_retriever import GraphRetriever
+from ..graph_retriever import retrieve_graphs # GraphRetriever
 from ...utils import remove_duplicates_keep_highest_score
 
 
@@ -46,82 +46,83 @@ def suppress_output():
 
 def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
 
-                retriever = GraphRetriever(
-                    vectorstore = vectorstore,
-                    sources = sources,
-                    k_total = k_before_reranking,
-                    threshold = 0.5,
-                )
-                docs_question = retriever.get_relevant_documents(question)
+    async def node_retrieve_graphs(state):
+        print("---- Retrieving graphs ----")
+
+        POSSIBLE_SOURCES = ["IEA", "OWID"]
+        questions = state["remaining_questions"] if state["remaining_questions"] is not None and state["remaining_questions"]!=[] else [state["query"]]
+        # sources_input = state["sources_input"]
+        sources_input = ["auto"]
+
+        auto_mode = "auto" in sources_input
+
+        # There are several options to get the final top k
+        # Option 1 - Get 100 documents by question and rerank by question
+        # Option 2 - Get 100/n documents by question and rerank the total
+        if rerank_by_question:
+            k_by_question = divide_into_parts(k_final,len(questions))
+
+        docs = []
+
+        for i,q in enumerate(questions):
+
+            question = q["question"] if isinstance(q, dict) else q
+
+            print(f"Subquestion {i}: {question}")
+
+            # If auto mode, we use all sources
+            if auto_mode:
+                sources = POSSIBLE_SOURCES
+            # Otherwise, we use the config
+            else:
+                sources = sources_input
+
+            if any([x in POSSIBLE_SOURCES for x in sources]):
+
+                sources = [x for x in sources if x in POSSIBLE_SOURCES]
+
+                # Search the document store using the retriever
+                docs_question = await retrieve_graphs(
+                    query = question,
+                    vectorstore = vectorstore,
+                    sources = sources,
+                    k_total = k_before_reranking,
+                    threshold = 0.5,
+                )
+                # docs_question = retriever.get_relevant_documents(question)
+
+                # Rerank
+                if reranker is not None and docs_question!=[]:
+                    with suppress_output():
+                        docs_question = rerank_docs(reranker,docs_question,question)
+                else:
+                    # Add a default reranking score
+                    for doc in docs_question:
+                        doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
+
+                # If rerank by question we select the top documents for each question
+                if rerank_by_question:
+                    docs_question = docs_question[:k_by_question[i]]
+
+                # Add sources used in the metadata
+                for doc in docs_question:
+                    doc.metadata["sources_used"] = sources
+
+                print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
+
+                docs.extend(docs_question)
+
+            else:
+                print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")
+
+        # Remove duplicates and keep the duplicate document with the highest reranking score
+        docs = remove_duplicates_keep_highest_score(docs)
+
+        # Sorting the list in descending order by rerank_score
+        # Then select the top k
+        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
+        docs = docs[:k_final]
+
+        return {"recommended_content": docs}
+
+    return node_retrieve_graphs
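
When reranking per question, the node splits the final budget k_final across subquestions with divide_into_parts, which is imported from elsewhere in the repo and not shown in this diff. A plausible sketch of what such a helper does (an assumption, not the repo's implementation): split k as evenly as possible, giving the remainder to the first questions.

def divide_into_parts(k: int, n: int) -> list[int]:
    # Split k into n near-equal parts; the first k % n parts get one extra.
    base, extra = divmod(k, n)
    return [base + 1 if i < extra else base for i in range(n)]

print(divide_into_parts(15, 4))  # [4, 4, 4, 3]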
climateqa/engine/chains/retrieve_documents.py

@@ -8,10 +8,13 @@ from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_core.runnables import RunnableLambda
 
 from ..reranker import rerank_docs
-from ...knowledge.retriever import ClimateQARetriever
+# from ...knowledge.retriever import ClimateQARetriever
 from ...knowledge.openalex import OpenAlexRetriever
 from .keywords_extraction import make_keywords_extraction_chain
 from ..utils import log_event
+from langchain_core.vectorstores import VectorStore
+from typing import List
+from langchain_core.documents.base import Document
 
 
@@ -76,10 +79,110 @@ def _get_k_summary_by_question(n_questions):
     else:
         return 1
 
+def _get_k_images_by_question(n_questions):
+    if n_questions == 0:
+        return 0
+    elif n_questions == 1:
+        return 5
+    elif n_questions == 2:
+        return 3
+    elif n_questions == 3:
+        return 2
+    else:
+        return 1
+
+def _add_metadata_and_score(docs: List) -> Document:
+    # Add score to metadata
+    docs_with_metadata = []
+    for i,(doc,score) in enumerate(docs):
+        doc.page_content = doc.page_content.replace("\r\n"," ")
+        doc.metadata["similarity_score"] = score
+        doc.metadata["content"] = doc.page_content
+        doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+        # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
+        docs_with_metadata.append(doc)
+    return docs_with_metadata
+
+async def get_IPCC_relevant_documents(
+    query: str,
+    vectorstore:VectorStore,
+    sources:list = ["IPCC","IPBES","IPOS"],
+    search_figures:bool = False,
+    reports:list = [],
+    threshold:float = 0.6,
+    k_summary:int = 3,
+    k_total:int = 10,
+    k_images: int = 5,
+    namespace:str = "vectors",
+    min_size:int = 200,
+) :
+
+    # Check if all elements in the list are either IPCC or IPBES
+    assert isinstance(sources,list)
+    assert sources
+    assert all([x in ["IPCC","IPBES","IPOS"] for x in sources])
+    assert k_total > k_summary, "k_total should be greater than k_summary"
+
+    # Prepare base search kwargs
+    filters = {}
+
+    if len(reports) > 0:
+        filters["short_name"] = {"$in":reports}
+    else:
+        filters["source"] = { "$in": sources}
+
+    # INIT
+    docs_summaries = []
+    docs_full = []
+    docs_images = []
+
+    # Search for k_summary documents in the summaries dataset
+    filters_summaries = {
+        **filters,
+        "chunk_type":"text",
+        "report_type": { "$in":["SPM"]},
+    }
+
+    docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
+    docs_summaries = [x for x in docs_summaries if x[1] > threshold]
+    # docs_summaries = []
+
+    # Search for k_total - k_summary documents in the full reports dataset
+    filters_full = {
+        **filters,
+        "chunk_type":"text",
+        "report_type": { "$nin":["SPM"]},
+    }
+    k_full = k_total - len(docs_summaries)
+    docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
+
+    if search_figures:
+        # Images
+        filters_image = {
+            **filters,
+            "chunk_type":"image"
+        }
+        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)
+
+    docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
+
+    # Filter if length are below threshold
+    docs_summaries = [x for x in docs_summaries if len(x.page_content) > min_size]
+    docs_full = [x for x in docs_full if len(x.page_content) > min_size]
+
+    return {
+        "docs_summaries" : docs_summaries,
+        "docs_full" : docs_full,
+        "docs_images" : docs_images,
+    }
+
 
 # The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
 # @chain
-async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
+async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5, k_images=5):
     print("---- Retrieve documents ----")
 
     # Get the documents from the state
@@ -93,12 +196,15 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
     else:
         related_content = []
 
+    search_figures = "IPCC figures" in state["relevant_content_sources"]
+
     # Get the current question
     current_question = state["remaining_questions"][0]
     remaining_questions = state["remaining_questions"][1:]
 
     k_by_question = k_final // state["n_questions"]
     k_summary_by_question = _get_k_summary_by_question(state["n_questions"])
+    k_images_by_question = _get_k_images_by_question(state["n_questions"])
 
     sources = current_question["sources"]
     question = current_question["question"]
@@ -108,40 +214,19 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
     await log_event({"question":question,"sources":sources,"index":index},"log_retriever",config)
 
 
-    if index == "Vector":
-
-        retriever = ClimateQARetriever(
+    if index == "Vector": # always true for now
+        docs_question_dict = await get_IPCC_relevant_documents(
+            query = question,
             vectorstore=vectorstore,
+            search_figures = search_figures,
             sources = sources,
             min_size = 200,
             k_summary = k_summary_by_question,
             k_total = k_before_reranking,
+            k_images = k_images_by_question,
             threshold = 0.5,
         )
-        docs_question_dict = await retriever.ainvoke(question,config)
-
-    # elif index == "OpenAlex":
-    #     # keyword extraction
-    #     keywords_extraction = make_keywords_extraction_chain(llm)
-
-    #     keywords = keywords_extraction.invoke(question)["keywords"]
-    #     openalex_query = " AND ".join(keywords)
-
-    #     print(f"... OpenAlex query: {openalex_query}")
-
-    #     retriever_openalex = OpenAlexRetriever(
-    #         min_year = state.get("min_year",1960),
-    #         max_year = state.get("max_year",None),
-    #         k = k_before_reranking
-    #     )
-    #     docs_question = await retriever_openalex.ainvoke(openalex_query,config)
-
-    # else:
-    #     raise Exception(f"Index {index} not found in the routing index")
-
 
     # Rerank
     if reranker is not None:
@@ -161,7 +246,7 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
 
     docs_question = docs_question_summary_reranked + docs_question_fulltext_reranked
     docs_question = docs_question[:k_by_question]
-    images_question = docs_question_images_reranked[:
+    images_question = docs_question_images_reranked[:k_images]
 
     if reranker is not None and rerank_by_question:
         docs_question = sorted(docs_question, key=lambda x: x.metadata["reranking_score"], reverse=True)
@@ -173,7 +258,7 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
     # Add to the list of docs
     docs.extend(docs_question)
     related_content.extend(images_question)
-
+    # related_content=[]
     new_state = {"documents":docs, "related_contents": related_content,"remaining_questions":remaining_questions}
     return new_state
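
The new get_IPCC_relevant_documents queries three pools (SPM summaries, full reports, and, only when search_figures is set, image chunks) by composing metadata filters with MongoDB/Pinecone-style operators. The filter construction can be reproduced standalone; this snippet mirrors the code in the diff:

reports: list = []                      # e.g. a list of report short names; empty means filter by source
sources = ["IPCC", "IPBES"]

filters = {"short_name": {"$in": reports}} if reports else {"source": {"$in": sources}}
filters_summaries = {**filters, "chunk_type": "text", "report_type": {"$in": ["SPM"]}}
filters_full = {**filters, "chunk_type": "text", "report_type": {"$nin": ["SPM"]}}
filters_image = {**filters, "chunk_type": "image"}

print(filters_summaries)
# {'source': {'$in': ['IPCC', 'IPBES']}, 'chunk_type': 'text', 'report_type': {'$in': ['SPM']}}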
climateqa/engine/graph.py

@@ -36,6 +36,7 @@ class GraphState(TypedDict):
     answer: str
     audience: str = "experts"
     sources_input: List[str] = ["IPCC","IPBES"]
+    relevant_content_sources: List[str] = ["IPCC figures"]
     sources_auto: bool = True
     min_year: int = 1960
     max_year: int = None
@@ -153,20 +154,28 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs
         lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
         make_id_dict(["answer_rag","answer_rag_no_docs"])
     )
+    workflow.add_conditional_edges(
+        "transform_query",
+        lambda state : "retrieve_graphs" if "OurWorldInData" in state["relevant_content_sources"] else END,
+        make_id_dict(["retrieve_graphs", END])
+    )
 
     # Define the edges
     # workflow.add_edge("set_defaults", "categorize_intent")
     workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "retrieve_graphs")
+    # workflow.add_edge("transform_query", "retrieve_graphs")
+    workflow.add_edge("transform_query", "retrieve_documents")
+
     # workflow.add_edge("retrieve_graphs", "answer_rag_graph")
-    workflow.add_edge("retrieve_graphs", "retrieve_documents")
+    workflow.add_edge("retrieve_graphs", END)
+    # workflow.add_edge("retrieve_graphs", "retrieve_documents")
     # workflow.add_edge("answer_rag_graph", "retrieve_documents")
     workflow.add_edge("answer_rag", END)
     workflow.add_edge("answer_rag_no_docs", END)
    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
     # workflow.add_edge("answer_chitchat", END)
     # workflow.add_edge("answer_ai_impact", END)
-    workflow.add_edge("retrieve_graphs_chitchat", END)
+    # workflow.add_edge("retrieve_graphs_chitchat", END)
     # workflow.add_edge("answer_ai_impact", "translate_query_ai")
     # workflow.add_edge("translate_query_ai", "transform_query_ai")
     # workflow.add_edge("transform_query_ai", "retrieve_graphs_ai")
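
The key routing change: retrieve_documents now follows transform_query directly, while the graphs branch runs only when "OurWorldInData" is selected, via a conditional edge. A standalone langgraph sketch of that edge, simplified to two stub nodes (the real agent has many more nodes and a parallel documents branch):

from typing import List, TypedDict
from langgraph.graph import StateGraph, END

class State(TypedDict):
    relevant_content_sources: List[str]
    trace: List[str]

def transform_query(state: State) -> dict:
    return {"trace": state["trace"] + ["transform_query"]}

def retrieve_graphs(state: State) -> dict:
    return {"trace": state["trace"] + ["retrieve_graphs"]}

workflow = StateGraph(State)
workflow.add_node("transform_query", transform_query)
workflow.add_node("retrieve_graphs", retrieve_graphs)
workflow.set_entry_point("transform_query")
workflow.add_conditional_edges(
    "transform_query",
    lambda s: "retrieve_graphs" if "OurWorldInData" in s["relevant_content_sources"] else END,
    {"retrieve_graphs": "retrieve_graphs", END: END},
)
workflow.add_edge("retrieve_graphs", END)
app = workflow.compile()

print(app.invoke({"relevant_content_sources": ["IPCC figures"], "trace": []})["trace"])
# ['transform_query'] -- the graph-retrieval branch is skipped entirely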
climateqa/engine/graph_retriever.py

@@ -5,30 +5,70 @@ from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
 
 from typing import List
 
-class GraphRetriever(BaseRetriever):
-    vectorstore:VectorStore
-    sources:list = ["OWID"] # add OurWorldInData later # will need to be integrated with the other retriever
-    threshold:float = 0.5
-    k_total:int = 10
-
-    def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
-    ) -> List[Document]:
-
-        # Check if all elements in the list are IEA or OWID
-        assert isinstance(self.sources,list)
-        assert self.sources
-        assert any([x in ["OWID"] for x in self.sources])
-
-        # Prepare base search kwargs
-        filters = {}
-
-        filters["source"] = {"$in": self.sources}
-
-        docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
-
-        # Filter if scores are below threshold
-        docs = [x for x in docs if x[1] > self.threshold]
+# class GraphRetriever(BaseRetriever):
+#     vectorstore:VectorStore
+#     sources:list = ["OWID"] # add OurWorldInData later # will need to be integrated with the other retriever
+#     threshold:float = 0.5
+#     k_total:int = 10
+
+#     def _get_relevant_documents(
+#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+#     ) -> List[Document]:
+
+#         # Check if all elements in the list are IEA or OWID
+#         assert isinstance(self.sources,list)
+#         assert self.sources
+#         assert any([x in ["OWID"] for x in self.sources])
+
+#         # Prepare base search kwargs
+#         filters = {}
+
+#         filters["source"] = {"$in": self.sources}
+
+#         docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
+
+#         # Filter if scores are below threshold
+#         docs = [x for x in docs if x[1] > self.threshold]
+
+#         # Remove duplicate documents
+#         unique_docs = []
+#         seen_docs = []
+#         for i, doc in enumerate(docs):
+#             if doc[0].page_content not in seen_docs:
+#                 unique_docs.append(doc)
+#                 seen_docs.append(doc[0].page_content)
+
+#         # Add score to metadata
+#         results = []
+#         for i,(doc,score) in enumerate(unique_docs):
+#             doc.metadata["similarity_score"] = score
+#             doc.metadata["content"] = doc.page_content
+#             results.append(doc)
+
+#         return results
+
+async def retrieve_graphs(
+    query: str,
+    vectorstore:VectorStore,
+    sources:list = ["OWID"], # add OurWorldInData later # will need to be integrated with the other retriever
+    threshold:float = 0.5,
+    k_total:int = 10,
+)-> List[Document]:
+
+    # Check if all elements in the list are IEA or OWID
+    assert isinstance(sources,list)
+    assert sources
+    assert any([x in ["OWID"] for x in sources])
+
+    # Prepare base search kwargs
+    filters = {}
+
+    filters["source"] = {"$in": sources}
+
+    docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
+
+    # Filter if scores are below threshold
+    docs = [x for x in docs if x[1] > threshold]
 
     # Remove duplicate documents
     unique_docs = []
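
retrieve_graphs is now a plain async function rather than a BaseRetriever subclass, so the graphs node can await it directly (note the underlying similarity_search_with_score call is still synchronous; the async signature keeps the node interface uniform). A usage sketch, assuming a vectorstore of OWID graph embeddings like the vectorstore_graphs wired up in app.py:

import asyncio
from climateqa.engine.graph_retriever import retrieve_graphs

async def demo(vectorstore_graphs):
    # Returns deduplicated Documents whose similarity score cleared the threshold
    graphs = await retrieve_graphs(
        query="global energy mix over time",
        vectorstore=vectorstore_graphs,
        sources=["OWID"],
        threshold=0.5,
        k_total=10,
    )
    for doc in graphs:
        print(doc.metadata["similarity_score"], doc.page_content[:80])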
climateqa/engine/reranker.py

@@ -30,6 +30,8 @@ def get_reranker(model = "nano", cohere_api_key = None):
 
 
 def rerank_docs(reranker,docs,query):
+    if docs == []:
+        return []
 
     # Get a list of texts from langchain docs
     input_docs = [x.page_content for x in docs]
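
The two added lines guard the new call sites that can legitimately produce zero candidates (for example, a graphs query whose scores all fall below the threshold): rerank providers generally reject an empty document list, so rerank_docs now short-circuits before building its inputs. A quick check of the new behavior:

from climateqa.engine.reranker import rerank_docs

# The guard returns before the reranker is touched, so even None is safe here:
assert rerank_docs(None, [], "co2 emissions") == []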
climateqa/knowledge/retriever.py

@@ -1,101 +1,102 @@
-# https://github.com/langchain-ai/langchain/issues/8623
-
-import pandas as pd
-
-from langchain_core.retrievers import BaseRetriever
-from langchain_core.vectorstores import VectorStoreRetriever
-from langchain_core.documents.base import Document
-from langchain_core.vectorstores import VectorStore
-from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
-
-from typing import List
-from pydantic import Field
-
-def _add_metadata_and_score(docs: List) -> Document:
-
-class ClimateQARetriever(BaseRetriever):
-
-        docs = docs_summaries + docs_full + docs_images
-
-        # Filter if scores are below threshold
-        docs = [x for x in docs if len(x[0].page_content) > self.min_size]
-        # docs = [x for x in docs if x[1] > self.threshold]
-
-        docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
-
-        docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
-
-        return {
-            "docs_summaries" : docs_summaries,
-            "docs_full" : docs_full,
-            "docs_images" : docs_images
-        }
+# # https://github.com/langchain-ai/langchain/issues/8623
+
+# import pandas as pd
+
+# from langchain_core.retrievers import BaseRetriever
+# from langchain_core.vectorstores import VectorStoreRetriever
+# from langchain_core.documents.base import Document
+# from langchain_core.vectorstores import VectorStore
+# from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+
+# from typing import List
+# from pydantic import Field
+
+# def _add_metadata_and_score(docs: List) -> Document:
+#     # Add score to metadata
+#     docs_with_metadata = []
+#     for i,(doc,score) in enumerate(docs):
+#         doc.page_content = doc.page_content.replace("\r\n"," ")
+#         doc.metadata["similarity_score"] = score
+#         doc.metadata["content"] = doc.page_content
+#         doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+#         # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
+#         docs_with_metadata.append(doc)
+#     return docs_with_metadata
+
+# class ClimateQARetriever(BaseRetriever):
+#     vectorstore:VectorStore
+#     sources:list = ["IPCC","IPBES","IPOS"]
+#     reports:list = []
+#     threshold:float = 0.6
+#     k_summary:int = 3
+#     k_total:int = 10
+#     namespace:str = "vectors",
+#     min_size:int = 200,
 
 
+#     def _get_relevant_documents(
+#         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+#     ) -> List[Document]:
+
+#         # Check if all elements in the list are either IPCC or IPBES
+#         assert isinstance(self.sources,list)
+#         assert self.sources
+#         assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
+#         assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
+
+#         # Prepare base search kwargs
+#         filters = {}
+
+#         if len(self.reports) > 0:
+#             filters["short_name"] = {"$in":self.reports}
+#         else:
+#             filters["source"] = { "$in":self.sources}
+
+#         # Search for k_summary documents in the summaries dataset
+#         filters_summaries = {
+#             **filters,
+#             "chunk_type":"text",
+#             "report_type": { "$in":["SPM"]},
+#         }
+
+#         docs_summaries = self.vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = self.k_summary)
+#         docs_summaries = [x for x in docs_summaries if x[1] > self.threshold]
+#         # docs_summaries = []
+
+#         # Search for k_total - k_summary documents in the full reports dataset
+#         filters_full = {
+#             **filters,
+#             "chunk_type":"text",
+#             "report_type": { "$nin":["SPM"]},
+#         }
+#         k_full = self.k_total - len(docs_summaries)
+#         docs_full = self.vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_full)
+
+#         # Images
+#         filters_image = {
+#             **filters,
+#             "chunk_type":"image"
+#         }
+#         docs_images = self.vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_full)
+
+#         # docs_images = []
+
+#         # Concatenate documents
+#         # docs = docs_summaries + docs_full + docs_images
+
+#         # Filter if scores are below threshold
+#         # docs = [x for x in docs if x[1] > self.threshold]
+
+#         docs_summaries, docs_full, docs_images = _add_metadata_and_score(docs_summaries), _add_metadata_and_score(docs_full), _add_metadata_and_score(docs_images)
+
+#         # Filter if length are below threshold
+#         docs_summaries = [x for x in docs_summaries if len(x.page_content) > self.min_size]
+#         docs_full = [x for x in docs_full if len(x.page_content) > self.min_size]
+
+
+#         return {
+#             "docs_summaries" : docs_summaries,
+#             "docs_full" : docs_full,
+#             "docs_images" : docs_images,
+#         }
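
With ClimateQARetriever commented out, document retrieval lives in the plain async function added to climateqa/engine/chains/retrieve_documents.py, which drops the retriever-class indirection (and its ainvoke callback plumbing) from the hot path. A sketch of the replacement call, mirroring the new call site (the surrounding variables here are stand-ins):

from climateqa.engine.chains.retrieve_documents import get_IPCC_relevant_documents

async def fetch(question: str, vectorstore):
    # Replaces: ClimateQARetriever(vectorstore=..., sources=...).ainvoke(question, config)
    return await get_IPCC_relevant_documents(
        query=question,
        vectorstore=vectorstore,
        search_figures=True,
        sources=["IPCC", "IPBES"],
        min_size=200,
        k_summary=3,
        k_total=100,
        k_images=5,
        threshold=0.5,
    )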
sandbox/20241104 - CQA - StepByStep CQA.ipynb

The diff for this notebook is too large to render; see the raw diff.