add dora graph recommendation
- app.py +186 -29
- climateqa/constants.py +22 -1
- climateqa/engine/chains/answer_ai_impact.py +1 -0
- climateqa/engine/chains/answer_rag.py +1 -0
- climateqa/engine/chains/chitchat_categorization.py +43 -0
- climateqa/engine/chains/graph_retriever.py +126 -0
- climateqa/engine/chains/intent_categorization.py +37 -8
- climateqa/engine/chains/prompts.py +24 -1
- climateqa/engine/chains/query_transformation.py +7 -2
- climateqa/engine/chains/retriever.py +1 -1
- climateqa/engine/chains/set_defaults.py +13 -0
- climateqa/engine/graph.py +192 -14
- climateqa/engine/graph_retriever.py +48 -0
- climateqa/engine/reranker.py +11 -2
- climateqa/engine/retriever.py +1 -0
- climateqa/engine/vectorstore.py +6 -0
- climateqa/utils.py +13 -0
- front/utils.py +79 -0
- style.css +30 -2
app.py
CHANGED
@@ -25,7 +25,8 @@ from azure.storage.fileshare import ShareServiceClient
 
 from utils import create_user_id
 
-
+from langchain_chroma import Chroma
+from collections import defaultdict
 
 # ClimateQ&A imports
 from climateqa.engine.llm import get_llm
@@ -35,13 +36,14 @@ from climateqa.engine.reranker import get_reranker
 from climateqa.engine.embeddings import get_embeddings_function
 from climateqa.engine.chains.prompts import audience_prompts
 from climateqa.sample_questions import QUESTIONS
-from climateqa.constants import POSSIBLE_REPORTS
+from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
 from climateqa.utils import get_image_from_azure_blob_storage
 from climateqa.engine.keywords import make_keywords_chain
 # from climateqa.engine.chains.answer_rag import make_rag_papers_chain
 from climateqa.engine.graph import make_graph_agent,display_graph
+from climateqa.engine.embeddings import get_embeddings_function
 
-from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox
+from front.utils import make_html_source,parse_output_llm_with_sources,serialize_docs,make_toolbox,generate_html_graphs
 
 # Load environment variables in local mode
 try:
@@ -50,6 +52,7 @@ try:
 except Exception as e:
     pass
 
+
 # Set up Gradio Theme
 theme = gr.themes.Base(
     primary_hue="blue",
@@ -83,17 +86,18 @@ share_client = service.get_share_client(file_share_name)
 user_id = create_user_id()
 
 
+embeddings_function = get_embeddings_function()
+llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
+reranker = get_reranker("nano")
 
 # Create vectorstore and retriever
 vectorstore = get_pinecone_vectorstore(embeddings_function)
-
-reranker = get_reranker("large")
-agent = make_graph_agent(llm,vectorstore,reranker)
-
-
+vectorstore_graphs = Chroma(persist_directory="/home/tim/ai4s/climate_qa/dora/climate-question-answering-graphs/climate-question-answering-graphs/vectorstore_owid", embedding_function=embeddings_function)
 
+# agent = make_graph_agent(llm,vectorstore,reranker)
+agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
 
-async def chat(query,history,audience,sources,reports):
+async def chat(query,history,audience,sources,reports,current_graphs):
     """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
     (messages in gradio format, messages in langchain format, source documents)"""
 
@@ -110,13 +114,14 @@ async def chat(query,history,audience,sources,reports):
         audience_prompt = audience_prompts["experts"]
 
     # Prepare default values
-    if len(sources) == 0:
-        sources = ["IPCC"]
+    if sources is None or len(sources) == 0:
+        sources = ["IPCC", "IPBES", "IPOS"]
 
-    if len(reports) == 0:
+    if reports is None or len(reports) == 0:
         reports = []
 
     inputs = {"user_input": query,"audience": audience_prompt,"sources":sources}
+    print(f"\n\nInputs:\n {inputs}\n\n")
     result = agent.astream_events(inputs,version = "v1") #{"callbacks":[MyCustomAsyncHandler()]})
     # result = rag_chain.stream(inputs)
 
@@ -126,11 +131,14 @@ async def chat(query,history,audience,sources,reports):
     # path_answer = "/logs/answer/streamed_output_str/-"
 
     docs = []
+    docs_used = True
     docs_html = ""
+    current_graphs = []
     output_query = ""
     output_language = ""
     output_keywords = ""
     gallery = []
+    updates = []
     start_streaming = False
 
     steps_display = {
@@ -142,7 +150,7 @@ async def chat(query,history,audience,sources,reports):
     try:
         async for event in result:
 
-            if event["event"] == "on_chat_model_stream":
+            if event["event"] == "on_chat_model_stream" and event["metadata"]["langgraph_node"] in ["answer_rag", "answer_rag_no_docs", "answer_chitchat", "answer_ai_impact"]:
                 if start_streaming == False:
                     start_streaming = True
                     history[-1] = (query,"")
@@ -155,14 +163,17 @@ async def chat(query,history,audience,sources,reports):
                     answer_yet = parse_output_llm_with_sources(answer_yet)
                     history[-1] = (query,answer_yet)
 
+                if docs_used is True and event["metadata"]["langgraph_node"] in ["answer_rag_no_docs", "answer_chitchat", "answer_ai_impact"]:
+                    docs_used = False
 
-            elif event["name"] == "retrieve_documents" and event["event"] == "on_chain_end":
+            elif docs_used is True and event["name"] == "retrieve_documents" and event["event"] == "on_chain_end":
                 try:
                     docs = event["data"]["output"]["documents"]
                     docs_html = []
                     for i, d in enumerate(docs, 1):
                         docs_html.append(make_html_source(d, i))
                     docs_html = "".join(docs_html)
+
                 except Exception as e:
                     print(f"Error getting documents: {e}")
                     print(event)
@@ -174,6 +185,55 @@ async def chat(query,history,audience,sources,reports):
                 # answer_yet = "🔄️ Searching in the knowledge base\n{questions}"
                 # history[-1] = (query,answer_yet)
 
+            elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
+                try:
+                    recommended_content = event["data"]["output"]["recommended_content"]
+                    # graphs = [
+                    #     {
+                    #         "embedding": x.metadata["returned_content"],
+                    #         "metadata": {
+                    #             "source": x.metadata["source"],
+                    #             "category": x.metadata["category"]
+                    #         }
+                    #     } for x in recommended_content if x.metadata["source"] == "OWID"
+                    # ]
+
+                    unique_graphs = []
+                    seen_embeddings = set()
+
+                    for x in recommended_content:
+                        embedding = x.metadata["returned_content"]
+
+                        # Check if the embedding has already been seen
+                        if embedding not in seen_embeddings:
+                            unique_graphs.append({
+                                "embedding": embedding,
+                                "metadata": {
+                                    "source": x.metadata["source"],
+                                    "category": x.metadata["category"]
+                                }
+                            })
+                            # Add the embedding to the seen set
+                            seen_embeddings.add(embedding)
+
+
+                    categories = {}
+                    for graph in unique_graphs:
+                        category = graph['metadata']['category']
+                        if category not in categories:
+                            categories[category] = []
+                        categories[category].append(graph['embedding'])
+
+                    # graphs_html = ""
+                    for category, embeddings in categories.items():
+                        # graphs_html += f"<h3>{category}</h3>"
+                        # current_graphs.append(f"<h3>{category}</h3>")
+                        for embedding in embeddings:
+                            current_graphs.append([embedding, category])
+                            # graphs_html += f"<div>{embedding}</div>"
+
+                except Exception as e:
+                    print(f"Error getting graphs: {e}")
 
             for event_name,(event_description,display_output) in steps_display.items():
                 if event["name"] == event_name:
@@ -181,6 +241,7 @@ async def chat(query,history,audience,sources,reports):
                     # answer_yet = f"<p><span class='loader'></span>{event_description}</p>"
                     # answer_yet = make_toolbox(event_description, "", checked = False)
                     answer_yet = event_description
+
                     history[-1] = (query,answer_yet)
             # elif event["event"] == "on_chain_end":
             #     answer_yet = ""
@@ -205,7 +266,8 @@ async def chat(query,history,audience,sources,reports):
 
 
             history = [tuple(x) for x in history]
-            yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
+            yield history,docs_html,output_query,output_language,gallery,current_graphs #,output_query,output_keywords
+
 
     except Exception as e:
         raise gr.Error(f"{e}")
@@ -268,12 +330,14 @@ async def chat(query,history,audience,sources,reports):
     history[-1] = (history[-1][0],answer_yet)
     history = [tuple(x) for x in history]
 
+    print(f"\n\nImages:\n{gallery}")
+
     # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])]
     # if len(gallery) > 0:
     #     gallery = list(set("|".join(gallery).split("|")))
     #     gallery = [get_image_from_azure_blob_storage(x) for x in gallery]
 
-    yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
+    yield history,docs_html,output_query,output_language,gallery,current_graphs #,output_query,output_keywords
 
 
 
@@ -405,16 +469,27 @@ def vote(data: gr.LikeData):
     else:
         print(data)
 
+def save_graph(saved_graphs_state, embedding, category):
+    print(f"\nCategory:\n{saved_graphs_state}\n")
+    if category not in saved_graphs_state:
+        saved_graphs_state[category] = []
+    if embedding not in saved_graphs_state[category]:
+        saved_graphs_state[category].append(embedding)
+    return saved_graphs_state, gr.Button("Graph Saved")
 
 
 with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:
-
+    user_id_state = gr.State([user_id])
+
+    chat_completed_state = gr.State(0)
+    current_graphs = gr.State([])
+    saved_graphs = gr.State({})
 
     with gr.Tab("ClimateQ&A"):
 
         with gr.Row(elem_id="chatbot-row"):
             with gr.Column(scale=2):
-
+                state = gr.State([system_template])
                 chatbot = gr.Chatbot(
                     value=[(None,init_prompt)],
                     show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
@@ -468,13 +543,13 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
         # with Modal(visible = False) as config_modal:
         with gr.Tab("Configuration",elem_id = "tab-config",id = 2):
 
-            gr.Markdown("
+            gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
 
 
             dropdown_sources = gr.CheckboxGroup(
                 ["IPCC", "IPBES","IPOS"],
                 label="Select source",
-                value=["IPCC"],
+                value=["IPCC", "IPBES","IPOS"],
                 interactive=True,
             )
 
@@ -495,13 +570,84 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
 
             output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
             output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
-
-
+
+
+        # with gr.Tab("Recommended content", elem_id="tab-recommended_content", id=3) as recommended_content_tab:
+
+        #     @gr.render(inputs=[current_graphs])
+        #     def display_default_recommended(current_graphs):
+        #         if len(current_graphs)==0:
+        #             placeholder_message = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
+
+        #     @gr.render(inputs=[current_graphs],triggers=[chat_completed_state.change])
+        #     def render_graphs(current_graph_list):
+        #         global saved_graphs
+        #         with gr.Column():
+        #             print(f"\ncurrent_graph_list:\n{current_graph_list}")
+        #             for (embedding, category) in current_graph_list:
+        #                 graphs_placeholder = gr.HTML(embedding, elem_id="graphs-placeholder")
+        #                 save_btn = gr.Button("Save Graph")
+        #                 save_btn.click(
+        #                     save_graph,
+        #                     [saved_graphs, gr.State(embedding), gr.State(category)],
+        #                     [saved_graphs, save_btn]
+        #                 )
 
     #---------------------------------------------------------------------------------------
     # OTHER TABS
     #---------------------------------------------------------------------------------------
 
+    # with gr.Tab("Recommended content", elem_id="tab-recommended_content2") as recommended_content_tab2:
+
+    #     @gr.render(inputs=[current_graphs])
+    #     def display_default_recommended_head(current_graphs_list):
+    #         if len(current_graphs_list)==0:
+    #             gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>")
+
+    #     @gr.render(inputs=[current_graphs],triggers=[chat_completed_state.change])
+    #     def render_graphs_head(current_graph_list):
+    #         global saved_graphs
+
+    #         category_dict = defaultdict(list)
+    #         for (embedding, category) in current_graph_list:
+    #             category_dict[category].append(embedding)
+
+    #         for category in category_dict:
+    #             with gr.Tab(category):
+    #                 splits = [category_dict[category][i:i+3] for i in range(0, len(category_dict[category]), 3)]
+    #                 for row in splits:
+    #                     with gr.Row():
+    #                         for embedding in row:
+    #                             with gr.Column():
+    #                                 gr.HTML(embedding, elem_id="graphs-placeholder")
+    #                                 save_btn = gr.Button("Save Graph")
+    #                                 save_btn.click(
+    #                                     save_graph,
+    #                                     [saved_graphs, gr.State(embedding), gr.State(category)],
+    #                                     [saved_graphs, save_btn]
+    #                                 )
+
+
+
+    # with gr.Tab("Saved Graphs", elem_id="tab-saved-graphs") as saved_graphs_tab:
+
+    #     @gr.render(inputs=[saved_graphs])
+    #     def display_default_save(saved):
+    #         if len(saved)==0:
+    #             gr.HTML("<h2>You have not saved any graphs yet</h2>")
+
+    #     @gr.render(inputs=[saved_graphs], triggers=[saved_graphs.change])
+    #     def view_saved_graphs(graphs_list):
+    #         categories = [category for category in graphs_list] # graphs_list.keys()
+    #         for category in categories:
+    #             with gr.Tab(category):
+    #                 splits = [graphs_list[category][i:i+3] for i in range(0, len(graphs_list[category]), 3)]
+    #                 for row in splits:
+    #                     with gr.Row():
+    #                         for graph in row:
+    #                             gr.HTML(graph, elem_id="graphs-placeholder")
+
+
 
     with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
         gallery_component = gr.Gallery()
@@ -526,7 +672,11 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
     # with gr.Tab("Citations network",elem_id="papers-network-tab"):
     #     citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")
 
-
+    # with gr.Tab("Saved Graphs", elem_id="tab-saved-graphs", id=4) as saved_graphs_tab:
+    #     @gr.render(inputs=[saved_graphs], triggers=[saved_graphs.change])
+    #     def view_saved_graphs(graphs_list):
+    #         for graph in graphs_list:
+    #             gr.HTML(graph, elem_id="graphs-placeholder")
 
     with gr.Tab("About",elem_classes = "max-height other-tabs"):
         with gr.Row():
@@ -540,18 +690,25 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
         return (gr.update(interactive = False),gr.update(selected=1),history)
 
     def finish_chat():
-        return (gr.update(interactive = True,value = ""))
+        return (gr.update(interactive = True,value = ""),gr.update(selected=3))
 
+    def change_completion_status(current_state):
+        current_state = 1 - current_state
+        return current_state
+
+
     (textbox
         .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
-        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component],concurrency_limit = 8,api_name = "chat_textbox")
-        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
+        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, current_graphs], [chatbot,sources_textbox,output_query,output_language,gallery_component, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
+        .then(finish_chat, None, [textbox,tabs],api_name = "finish_chat_textbox")
+        .then(change_completion_status, [chat_completed_state], [chat_completed_state])
     )
 
     (examples_hidden
         .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
-        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component],concurrency_limit = 8,api_name = "chat_examples")
-        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
+        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports,current_graphs], [chatbot,sources_textbox,output_query,output_language,gallery_component, current_graphs],concurrency_limit = 8,api_name = "chat_examples")
+        .then(finish_chat, None, [textbox,tabs],api_name = "finish_chat_examples")
+        .then(change_completion_status, [chat_completed_state], [chat_completed_state])
    )
 
 
@@ -570,4 +727,4 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
 
 demo.queue()
 
-demo.launch()
+demo.launch(debug=True)
climateqa/constants.py
CHANGED
@@ -42,4 +42,25 @@ POSSIBLE_REPORTS = [
     "IPBES IAS A C5",
     "IPBES IAS A C6",
     "IPBES IAS A SPM"
-]
+]
+
+OWID_CATEGORIES = ['Access to Energy', 'Agricultural Production',
+       'Agricultural Regulation & Policy', 'Air Pollution',
+       'Animal Welfare', 'Antibiotics', 'Biodiversity', 'Biofuels',
+       'Biological & Chemical Weapons', 'CO2 & Greenhouse Gas Emissions',
+       'COVID-19', 'Clean Water', 'Clean Water & Sanitation',
+       'Climate Change', 'Crop Yields', 'Diet Compositions',
+       'Electricity', 'Electricity Mix', 'Energy', 'Energy Efficiency',
+       'Energy Prices', 'Environmental Impacts of Food Production',
+       'Environmental Protection & Regulation', 'Famines', 'Farm Size',
+       'Fertilizers', 'Fish & Overfishing', 'Food Supply', 'Food Trade',
+       'Food Waste', 'Food and Agriculture', 'Forests & Deforestation',
+       'Fossil Fuels', 'Future Population Growth',
+       'Hunger & Undernourishment', 'Indoor Air Pollution', 'Land Use',
+       'Land Use & Yields in Agriculture', 'Lead Pollution',
+       'Meat & Dairy Production', 'Metals & Minerals',
+       'Natural Disasters', 'Nuclear Energy', 'Nuclear Weapons',
+       'Oil Spills', 'Outdoor Air Pollution', 'Ozone Layer', 'Pandemics',
+       'Pesticides', 'Plastic Pollution', 'Renewable Energy', 'Soil',
+       'Transport', 'Urbanization', 'Waste Management', 'Water Pollution',
+       'Water Use & Stress', 'Wildfires']
climateqa/engine/chains/answer_ai_impact.py
CHANGED
@@ -38,6 +38,7 @@ def make_ai_impact_chain(llm):
 def make_ai_impact_node(llm):
 
     ai_impact_chain = make_ai_impact_chain(llm)
+
 
     async def answer_ai_impact(state,config):
         answer = await ai_impact_chain.ainvoke({"question":state["user_input"]},config)
climateqa/engine/chains/answer_rag.py
CHANGED
@@ -61,6 +61,7 @@ def make_rag_node(llm,with_docs = True):
 
     async def answer_rag(state,config):
         answer = await rag_chain.ainvoke(state,config)
+        print(f"\n\nAnswer:\n{answer}")
         return {"answer":answer}
 
     return answer_rag
|
climateqa/engine/chains/chitchat_categorization.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
3 |
+
from typing import List
|
4 |
+
from typing import Literal
|
5 |
+
from langchain.prompts import ChatPromptTemplate
|
6 |
+
from langchain_core.utils.function_calling import convert_to_openai_function
|
7 |
+
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
8 |
+
|
9 |
+
|
10 |
+
class IntentCategorizer(BaseModel):
|
11 |
+
"""Analyzing the user message input"""
|
12 |
+
|
13 |
+
environment: bool = Field(
|
14 |
+
description="Return 'True' if the question relates to climate change, the environment, nature, etc. (Example: should I eat fish?). Return 'False' if the question is just chit chat or not related to the environment or climate change.",
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
def make_chitchat_intent_categorization_chain(llm):
|
19 |
+
|
20 |
+
openai_functions = [convert_to_openai_function(IntentCategorizer)]
|
21 |
+
llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
|
22 |
+
|
23 |
+
prompt = ChatPromptTemplate.from_messages([
|
24 |
+
("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
|
25 |
+
("user", "input: {input}")
|
26 |
+
])
|
27 |
+
|
28 |
+
chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
|
29 |
+
return chain
|
30 |
+
|
31 |
+
|
32 |
+
def make_chitchat_intent_categorization_node(llm):
|
33 |
+
|
34 |
+
categorization_chain = make_chitchat_intent_categorization_chain(llm)
|
35 |
+
|
36 |
+
def categorize_message(state):
|
37 |
+
output = categorization_chain.invoke({"input": state["user_input"]})
|
38 |
+
print(f"\n\nChit chat output intent categorization: {output}\n")
|
39 |
+
state["search_graphs_chitchat"] = output["environment"]
|
40 |
+
print(f"\n\nChit chat output intent categorization: {state}\n")
|
41 |
+
return state
|
42 |
+
|
43 |
+
return categorize_message
|
climateqa/engine/chains/graph_retriever.py
ADDED
@@ -0,0 +1,126 @@
+import sys
+import os
+from contextlib import contextmanager
+
+from ..reranker import rerank_docs
+from ..graph_retriever import GraphRetriever
+from ...utils import remove_duplicates_keep_highest_score
+
+
+def divide_into_parts(target, parts):
+    # Base value for each part
+    base = target // parts
+    # Remainder to distribute
+    remainder = target % parts
+    # List to hold the result
+    result = []
+
+    for i in range(parts):
+        if i < remainder:
+            # These parts get base value + 1
+            result.append(base + 1)
+        else:
+            # The rest get the base value
+            result.append(base)
+
+    return result
+
+
+@contextmanager
+def suppress_output():
+    # Open a null device
+    with open(os.devnull, 'w') as devnull:
+        # Store the original stdout and stderr
+        old_stdout = sys.stdout
+        old_stderr = sys.stderr
+        # Redirect stdout and stderr to the null device
+        sys.stdout = devnull
+        sys.stderr = devnull
+        try:
+            yield
+        finally:
+            # Restore stdout and stderr
+            sys.stdout = old_stdout
+            sys.stderr = old_stderr
+
+
+def make_graph_retriever_node(vectorstore, reranker, rerank_by_question=True, k_final=15, k_before_reranking=100):
+
+    def retrieve_graphs(state):
+        print("---- Retrieving graphs ----")
+
+        POSSIBLE_SOURCES = ["IEA", "OWID"]
+        questions = state["questions"] if state["questions"] is not None else [state["query"]]
+        sources_input = state["sources_input"]
+
+        auto_mode = "auto" in sources_input
+
+        # There are several options to get the final top k
+        # Option 1 - Get 100 documents by question and rerank by question
+        # Option 2 - Get 100/n documents by question and rerank the total
+        if rerank_by_question:
+            k_by_question = divide_into_parts(k_final,len(questions))
+
+        docs = []
+
+        for i,q in enumerate(questions):
+
+            question = q["question"] if isinstance(q, dict) else q
+
+            print(f"Subquestion {i}: {question}")
+
+            # If auto mode, we use all sources
+            if auto_mode:
+                sources = POSSIBLE_SOURCES
+            # Otherwise, we use the config
+            else:
+                sources = sources_input
+
+            if any([x in POSSIBLE_SOURCES for x in sources]):
+
+                sources = [x for x in sources if x in POSSIBLE_SOURCES]
+
+                # Search the document store using the retriever
+                retriever = GraphRetriever(
+                    vectorstore = vectorstore,
+                    sources = sources,
+                    k_total = k_before_reranking,
+                    threshold = 0.5,
+                )
+                docs_question = retriever.get_relevant_documents(question)
+
+                # Rerank
+                if reranker is not None:
+                    with suppress_output():
+                        docs_question = rerank_docs(reranker,docs_question,question)
+                else:
+                    # Add a default reranking score
+                    for doc in docs_question:
+                        doc.metadata["reranking_score"] = doc.metadata["similarity_score"]
+
+                # If rerank by question we select the top documents for each question
+                if rerank_by_question:
+                    docs_question = docs_question[:k_by_question[i]]
+
+                # Add sources used in the metadata
+                for doc in docs_question:
+                    doc.metadata["sources_used"] = sources
+
+                print(f"{len(docs_question)} graphs retrieved for subquestion {i + 1}: {docs_question}")
+
+                docs.extend(docs_question)
+
+            else:
+                print(f"There are no graphs which match the sources filtered on. Sources filtered on: {sources}. Sources available: {POSSIBLE_SOURCES}.")
+
+        # Remove duplicates and keep the duplicate document with the highest reranking score
+        docs = remove_duplicates_keep_highest_score(docs)
+
+        # Sorting the list in descending order by rerank_score
+        # Then select the top k
+        docs = sorted(docs, key=lambda x: x.metadata["reranking_score"], reverse=True)
+        docs = docs[:k_final]
+
+        return {"recommended_content": docs}
+
+    return retrieve_graphs
climateqa/engine/chains/intent_categorization.py
CHANGED
@@ -7,6 +7,34 @@ from langchain_core.utils.function_calling import convert_to_openai_function
 from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
 
 
+# class IntentCategorizer(BaseModel):
+#     """Analyzing the user message input"""
+
+#     language: str = Field(
+#         description="Find the language of the message input in full words (ex: French, English, Spanish, ...), defaults to English",
+#         default="English",
+#     )
+#     intent: str = Field(
+#         enum=[
+#             "ai",
+#             # "geo_info",
+#             # "esg"
+#             "search",
+#             "chitchat",
+#         ],
+#         description="""
+#             Categorize the user input in one of the following category
+#             Any question
+
+#             Examples:
+#             - ai = any question related to AI: "What are the environmental consequences of AI", "How does AI affect the environment"
+#             - search = Searching for any question about climate change, energy, biodiversity, nature, and everything we can find the IPCC or IPBES reports or scientific papers. Also questions about individual actions or anything loosely related to the environment.
+#             - chitchat = Any chit chat or any question that is not related to the environment or climate change or for which it is not necessary to look for the answer in the IPCC, IPBES, IPOS or scientific reports.
+#         """,
+#             # - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
+#             # - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
+#     )
+
 class IntentCategorizer(BaseModel):
     """Analyzing the user message input"""
 
@@ -16,9 +44,9 @@ class IntentCategorizer(BaseModel):
     )
     intent: str = Field(
         enum=[
-            "
-            "geo_info",
-            "esg"
+            "ai",
+            # "geo_info",
+            # "esg"
             "search",
             "chitchat",
         ],
@@ -27,12 +55,12 @@ class IntentCategorizer(BaseModel):
             Any question
 
             Examples:
-            -
-            - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
-            - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
+            - ai = Any query related to Artificial Intelligence: "What are the environmental consequences of AI", "How does AI affect the environment"
             - search = Searching for any quesiton about climate change, energy, biodiversity, nature, and everything we can find the IPCC or IPBES reports or scientific papers,
             - chitchat = Any general question that is not related to the environment or climate change or just conversational, or if you don't think searching the IPCC or IPBES reports would be relevant
         """,
+            # - geo_info = Geolocated info about climate change: Any question where the user wants to know localized impacts of climate change, eg: "What will be the temperature in Marseille in 2050"
+            # - esg = Any question about the ESG regulation, frameworks and standards like the CSRD, TCFD, SASB, GRI, CDP, etc.
     )
 
 
@@ -43,7 +71,7 @@ def make_intent_categorization_chain(llm):
     llm_with_functions = llm.bind(functions = openai_functions,function_call={"name":"IntentCategorizer"})
 
     prompt = ChatPromptTemplate.from_messages([
-        ("system", "You are a helpful assistant, you will analyze,
+        ("system", "You are a helpful assistant, you will analyze, and categorize the user input message using the function provided. Categorize the user input as ai ONLY if it is related to Artificial Intelligence, search if it is related to the environment, climate change, energy, biodiversity, nature, etc. and chitchat if it is just general conversation."),
         ("user", "input: {input}")
     ])
 
@@ -56,7 +84,8 @@ def make_intent_categorization_node(llm):
     categorization_chain = make_intent_categorization_chain(llm)
 
     def categorize_message(state):
-        output = categorization_chain.invoke({"input":state["user_input"]})
+        output = categorization_chain.invoke({"input": state["user_input"]})
+        print(f"\n\nOutput intent categorization: {output}\n")
        if "language" not in output: output["language"] = "English"
        output["query"] = state["user_input"]
        return output
climateqa/engine/chains/prompts.py
CHANGED
@@ -147,4 +147,27 @@ audience_prompts = {
     "children": "6 year old children that don't know anything about science and climate change and need metaphors to learn",
     "general": "the general public who know the basics in science and climate change and want to learn more about it without technical terms. Still use references to passages.",
     "experts": "expert and climate scientists that are not afraid of technical terms",
-}
+}
+
+
+answer_prompt_graph_template = """
+Given the user question and a list of graphs which are related to the question, rank the graphs based on relevance to the user question. ALWAYS follow the guidelines given below.
+
+### Guidelines ###
+- Keep all the graphs that are given to you.
+- NEVER modify the graph HTML embedding, the category or the source leave them exactly as they are given.
+- Return the ranked graphs as a list of dictionaries with keys 'embedding', 'category', and 'source'.
+- Return a valid JSON output.
+
+-----------------------
+User question:
+{query}
+
+Graphs and their HTML embedding:
+{recommended_content}
+
+-----------------------
+{format_instructions}
+
+Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
+"""
climateqa/engine/chains/query_transformation.py
CHANGED
@@ -62,15 +62,15 @@ class QueryAnalysis(BaseModel):
     # """
     # )
 
-    sources: List[Literal["IPCC", "IPBES", "IPOS"
+    sources: List[Literal["IPCC", "IPBES", "IPOS"]] = Field( #,"OpenAlex"]] = Field(
         ...,
         description="""
             Given a user question choose which documents would be most relevant for answering their question,
             - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
             - IPBES is for questions about biodiversity and nature
             - IPOS is for questions about the ocean and deep sea mining
-            - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
         """,
+        # - OpenAlex is for any other questions that are not in the previous categories but could be found in the scientific litterature
     )
     # topics: List[Literal[
     #     "Climate change",
@@ -143,6 +143,11 @@ def make_query_transform_node(llm):
         for question in new_state["questions"]:
             question_state = {"question":question}
             analysis_output = rewriter_chain.invoke({"input":question})
+
+            # The case when the llm does not return any sources
+            if not analysis_output["sources"] or not all(source in ["IPCC", "IPBS", "IPOS"] for source in analysis_output["sources"]):
+                analysis_output["sources"] = ["IPCC", "IPBES", "IPOS"]
+
             question_state.update(analysis_output)
             questions.append(question_state)
         new_state["questions"] = questions
climateqa/engine/chains/retriever.py
CHANGED
@@ -49,7 +49,7 @@ def make_retriever_node(vectorstore,reranker,rerank_by_question=True, k_final=15
 
     def retrieve_documents(state):
 
-        POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS","OpenAlex"]
+        POSSIBLE_SOURCES = ["IPCC","IPBES","IPOS"] # ,"OpenAlex"]
         questions = state["questions"]
 
         # Use sources from the user input or from the LLM detection
climateqa/engine/chains/set_defaults.py
ADDED
@@ -0,0 +1,13 @@
+def set_defaults(state):
+    print("---- Setting defaults ----")
+
+    if not state["audience"] or state["audience"] is None:
+        state.update({"audience": "experts"})
+
+    sources_input = state["sources_input"] if "sources_input" in state else ["auto"]
+    state.update({"sources_input": sources_input})
+
+    # if not state["sources_input"] or state["sources_input"] is None:
+    #     state.update({"sources_input": ["auto"]})
+
+    return state
climateqa/engine/graph.py
CHANGED
@@ -4,10 +4,10 @@ from contextlib import contextmanager
 
 from langchain.schema import Document
 from langgraph.graph import END, StateGraph
-from langchain_core.runnables.graph import CurveStyle, NodeColors, MermaidDrawMethod
+from langchain_core.runnables.graph import CurveStyle, MermaidDrawMethod
 
 from typing_extensions import TypedDict
-from typing import List
+from typing import List, Dict
 
 from IPython.display import display, HTML, Image
 
@@ -27,12 +27,15 @@ class GraphState(TypedDict):
     user_input : str
     language : str
     intent : str
+    search_graphs_chitchat : bool
     query: str
     questions : List[dict]
     answer: str
     audience: str = "experts"
     sources_input: List[str] = ["auto"]
     documents: List[Document]
+    recommended_content : List[Document]
+    # graphs_returned: Dict[str,str]
 
 def search(state):
     return {}
@@ -46,6 +49,13 @@ def route_intent(state):
     else:
         # Search route
         return "search"
+
+def chitchat_route_intent(state):
+    intent = state["search_graphs_chitchat"]
+    if intent is True:
+        return "retrieve_graphs_chitchat"
+    elif intent is False:
+        return END
 
 def route_translation(state):
     if state["language"].lower() == "english":
@@ -64,7 +74,7 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
 def make_id_dict(values):
     return {k:k for k in values}
 
-def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
+def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, reranker, threshold_docs=0.2):
 
     workflow = StateGraph(GraphState)
 
@@ -74,23 +84,35 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
     translate_query = make_translation_node(llm)
     answer_chitchat = make_chitchat_node(llm)
     answer_ai_impact = make_ai_impact_node(llm)
-    retrieve_documents = make_retriever_node(vectorstore,reranker)
-    answer_rag = make_rag_node(llm,with_docs=True)
-    answer_rag_no_docs = make_rag_node(llm,with_docs=False)
+    retrieve_documents = make_retriever_node(vectorstore_ipcc, reranker)
+    retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
+    # answer_rag_graph = make_rag_graph_node(llm)
+    answer_rag = make_rag_node(llm, with_docs=True)
+    answer_rag_no_docs = make_rag_node(llm, with_docs=False)
+    chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
 
     # Define the nodes
+    workflow.add_node("set_defaults", set_defaults)
     workflow.add_node("categorize_intent", categorize_intent)
     workflow.add_node("search", search)
     workflow.add_node("transform_query", transform_query)
     workflow.add_node("translate_query", translate_query)
+    # workflow.add_node("transform_query_ai", transform_query)
+    # workflow.add_node("translate_query_ai", translate_query)
     workflow.add_node("answer_chitchat", answer_chitchat)
+    workflow.add_node("chitchat_categorize_intent", chitchat_categorize_intent)
     workflow.add_node("answer_ai_impact", answer_ai_impact)
-    workflow.add_node("retrieve_documents",retrieve_documents)
-    workflow.add_node("answer_rag",answer_rag)
-    workflow.add_node("answer_rag_no_docs",answer_rag_no_docs)
+    workflow.add_node("retrieve_graphs", retrieve_graphs)
+    workflow.add_node("retrieve_graphs_chitchat", retrieve_graphs)
+    # workflow.add_node("retrieve_graphs_ai", retrieve_graphs)
+    # workflow.add_node("answer_rag_graph", answer_rag_graph)
+    # workflow.add_node("answer_rag_graph_ai", answer_rag_graph)
+    workflow.add_node("retrieve_documents", retrieve_documents)
+    workflow.add_node("answer_rag", answer_rag)
+    workflow.add_node("answer_rag_no_docs", answer_rag_no_docs)
 
     # Entry point
-    workflow.set_entry_point("categorize_intent")
+    workflow.set_entry_point("set_defaults")
 
     # CONDITIONAL EDGES
     workflow.add_conditional_edges(
@@ -99,6 +121,12 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
         make_id_dict(["answer_chitchat","answer_ai_impact","search"])
     )
 
+    workflow.add_conditional_edges(
+        "chitchat_categorize_intent",
+        chitchat_route_intent,
+        make_id_dict(["retrieve_graphs_chitchat", END])
+    )
+
     workflow.add_conditional_edges(
         "search",
         route_translation,
@@ -112,13 +140,24 @@ def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
     )
 
     # Define the edges
+    workflow.add_edge("set_defaults", "categorize_intent")
     workflow.add_edge("translate_query", "transform_query")
-    workflow.add_edge("transform_query", "
-    workflow.add_edge("
+    workflow.add_edge("transform_query", "retrieve_graphs")
+    # workflow.add_edge("retrieve_graphs", "answer_rag_graph")
+    workflow.add_edge("retrieve_graphs", "retrieve_documents")
+    # workflow.add_edge("answer_rag_graph", "retrieve_documents")
     workflow.add_edge("answer_rag", END)
     workflow.add_edge("answer_rag_no_docs", END)
-    workflow.add_edge("answer_chitchat",
+    workflow.add_edge("answer_chitchat", "chitchat_categorize_intent")
+    # workflow.add_edge("answer_chitchat", END)
     workflow.add_edge("answer_ai_impact", END)
+    workflow.add_edge("retrieve_graphs_chitchat", END)
+    # workflow.add_edge("answer_ai_impact", "translate_query_ai")
+    # workflow.add_edge("translate_query_ai", "transform_query_ai")
+    # workflow.add_edge("transform_query_ai", "retrieve_graphs_ai")
+    # workflow.add_edge("retrieve_graphs_ai", "answer_rag_graph_ai")
+    # workflow.add_edge("answer_rag_graph_ai", END)
+    # workflow.add_edge("retrieve_graphs_ai", END)
 
     # Compile
     app = workflow.compile()
@@ -135,4 +174,143 @@ def display_graph(app):
             draw_method=MermaidDrawMethod.API,
         )
     )
-)
+)
+
+# import sys
+# import os
+# from contextlib import contextmanager
+
+# from langchain.schema import Document
+# from langgraph.graph import END, StateGraph
+# from langchain_core.runnables.graph import CurveStyle, NodeColors, MermaidDrawMethod
+
+# from typing_extensions import TypedDict
+# from typing import List
+
+# from IPython.display import display, HTML, Image
+
+# from .chains.answer_chitchat import make_chitchat_node
+# from .chains.answer_ai_impact import make_ai_impact_node
+# from .chains.query_transformation import make_query_transform_node
+# from .chains.translation import make_translation_node
+# from .chains.intent_categorization import make_intent_categorization_node
+# from .chains.retriever import make_retriever_node
+# from .chains.answer_rag import make_rag_node
+
+
+# class GraphState(TypedDict):
+#     """
+#     Represents the state of our graph.
+#     """
+#     user_input : str
+#     language : str
+#     intent : str
+#     query: str
+#     questions : List[dict]
+#     answer: str
+#     audience: str = "experts"
+#     sources_input: List[str] = ["auto"]
+#     documents: List[Document]
+
+# def search(state):
+#     return {}
+
+# def route_intent(state):
+#     intent = state["intent"]
+#     if intent in ["chitchat","esg"]:
+#         return "answer_chitchat"
+#     elif intent == "ai_impact":
+#         return "answer_ai_impact"
+#     else:
+#         # Search route
+#         return "search"
+
+# def route_translation(state):
+#     if state["language"].lower() == "english":
+#         return "transform_query"
+#     else:
+#         return "translate_query"
+
+# def route_based_on_relevant_docs(state,threshold_docs=0.2):
+#     docs = [x for x in state["documents"] if x.metadata["reranking_score"] > threshold_docs]
+#     if len(docs) > 0:
+#         return "answer_rag"
+#     else:
+#         return "answer_rag_no_docs"
+
+
+# def make_id_dict(values):
+#     return {k:k for k in values}
+
+# def make_graph_agent(llm,vectorstore,reranker,threshold_docs = 0.2):
+
+#     workflow = StateGraph(GraphState)
+
+#     # Define the node functions
+#     categorize_intent = make_intent_categorization_node(llm)
+#     transform_query = make_query_transform_node(llm)
+#     translate_query = make_translation_node(llm)
+#     answer_chitchat = make_chitchat_node(llm)
+#     answer_ai_impact = make_ai_impact_node(llm)
+#     retrieve_documents = make_retriever_node(vectorstore,reranker)
+#     answer_rag = make_rag_node(llm,with_docs=True)
+#     answer_rag_no_docs = make_rag_node(llm,with_docs=False)
+
+#     # Define the nodes
+#     workflow.add_node("categorize_intent", categorize_intent)
+#     workflow.add_node("search", search)
+#     workflow.add_node("transform_query", transform_query)
+#     workflow.add_node("translate_query", translate_query)
+#     workflow.add_node("answer_chitchat", answer_chitchat)
+#     workflow.add_node("answer_ai_impact", answer_ai_impact)
+#     workflow.add_node("retrieve_documents",retrieve_documents)
+#     workflow.add_node("answer_rag",answer_rag)
+#     workflow.add_node("answer_rag_no_docs",answer_rag_no_docs)
+
+#     # Entry point
+#     workflow.set_entry_point("categorize_intent")
+
+#     # CONDITIONAL EDGES
+#     workflow.add_conditional_edges(
+#         "categorize_intent",
+#         route_intent,
+#         make_id_dict(["answer_chitchat","answer_ai_impact","search"])
+#     )
# )
|
279 |
+
|
280 |
+
# workflow.add_conditional_edges(
|
281 |
+
# "search",
|
282 |
+
# route_translation,
|
283 |
+
# make_id_dict(["translate_query","transform_query"])
|
284 |
+
# )
|
285 |
+
|
286 |
+
# workflow.add_conditional_edges(
|
287 |
+
# "retrieve_documents",
|
288 |
+
# lambda x : route_based_on_relevant_docs(x,threshold_docs=threshold_docs),
|
289 |
+
# make_id_dict(["answer_rag","answer_rag_no_docs"])
|
290 |
+
# )
|
291 |
+
|
292 |
+
# # Define the edges
|
293 |
+
# workflow.add_edge("translate_query", "transform_query")
|
294 |
+
# workflow.add_edge("transform_query", "retrieve_documents")
|
295 |
+
# workflow.add_edge("retrieve_documents", "answer_rag")
|
296 |
+
# workflow.add_edge("answer_rag", END)
|
297 |
+
# workflow.add_edge("answer_rag_no_docs", END)
|
298 |
+
# workflow.add_edge("answer_chitchat", END)
|
299 |
+
# workflow.add_edge("answer_ai_impact", END)
|
300 |
+
|
301 |
+
# # Compile
|
302 |
+
# app = workflow.compile()
|
303 |
+
# return app
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
# def display_graph(app):
|
309 |
+
|
310 |
+
# display(
|
311 |
+
# Image(
|
312 |
+
# app.get_graph(xray = True).draw_mermaid_png(
|
313 |
+
# draw_method=MermaidDrawMethod.API,
|
314 |
+
# )
|
315 |
+
# )
|
316 |
+
# )
|
climateqa/engine/graph_retriever.py
ADDED
@@ -0,0 +1,48 @@
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.documents.base import Document
+from langchain_core.vectorstores import VectorStore
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+
+from typing import List
+
+class GraphRetriever(BaseRetriever):
+    vectorstore:VectorStore
+    sources:list = ["OWID"]  # later add OurWorldInData; will need to be integrated with the other retriever
+    threshold:float = 0.5
+    k_total:int = 10
+
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ) -> List[Document]:
+
+        # Check that at least one requested source is supported (currently only OWID)
+        assert isinstance(self.sources,list)
+        assert self.sources
+        assert any([x in ["OWID"] for x in self.sources])
+
+        # Prepare base search kwargs
+        filters = {}
+        filters["source"] = {"$in": self.sources}
+
+        docs = self.vectorstore.similarity_search_with_score(query=query, filter=filters, k=self.k_total)
+
+        # Keep only the results whose score clears the threshold
+        docs = [x for x in docs if x[1] > self.threshold]
+
+        # Remove duplicate documents
+        unique_docs = []
+        seen_docs = []
+        for i, doc in enumerate(docs):
+            if doc[0].page_content not in seen_docs:
+                unique_docs.append(doc)
+                seen_docs.append(doc[0].page_content)
+
+        # Add score to metadata
+        results = []
+        for i,(doc,score) in enumerate(unique_docs):
+            doc.metadata["similarity_score"] = score
+            doc.metadata["content"] = doc.page_content
+            results.append(doc)
+
+        return results
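A minimal usage sketch for the new retriever, assuming the graph vectorstore was built with the get_chroma_vectorstore helper added below and that OWID entries carry the "source" and "category" metadata fields the rest of this commit implies:

# Hypothetical wiring, not part of this commit:
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.vectorstore import get_chroma_vectorstore
from climateqa.engine.graph_retriever import GraphRetriever

vectorstore_graphs = get_chroma_vectorstore(get_embeddings_function())
retriever = GraphRetriever(vectorstore=vectorstore_graphs, sources=["OWID"], threshold=0.5, k_total=10)
graphs = retriever.invoke("How have global CO2 emissions evolved since 1990?")
for doc in graphs:
    print(doc.metadata["similarity_score"], doc.metadata.get("category"))

One caveat worth double-checking: whether `score > threshold` keeps the best matches depends on the vectorstore's score convention. Chroma's similarity_search_with_score returns distances (lower is better) while Pinecone returns similarities (higher is better), so the comparison direction may need to flip depending on the backend.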
climateqa/engine/reranker.py
CHANGED
@@ -1,11 +1,14 @@
 import os
+from dotenv import load_dotenv
 from scipy.special import expit, logit
 from rerankers import Reranker
+from sentence_transformers import CrossEncoder
 
-def get_reranker(model = "nano",cohere_api_key = None):
+load_dotenv()
+
+def get_reranker(model = "jina", cohere_api_key = None):
 
-    assert model in ["nano","tiny","small","large"]
+    assert model in ["nano","tiny","small","large", "jina"]
 
     if model == "nano":
         reranker = Reranker('ms-marco-TinyBERT-L-2-v2', model_type='flashrank')
@@ -17,6 +20,11 @@ def get_reranker(model = "nano",cohere_api_key = None):
         if cohere_api_key is None:
            cohere_api_key = os.environ["COHERE_API_KEY"]
         reranker = Reranker("cohere", lang='en', api_key = cohere_api_key)
+    elif model == "jina":
+        # Reached token quota so does not work
+        reranker = Reranker("jina-reranker-v2-base-multilingual", api_key = os.getenv("JINA_RERANKER_API_KEY"))
+        # does not work without a GPU? and it returns a different structure anyway, so the retriever node code would need to change
+        # reranker = CrossEncoder("jinaai/jina-reranker-v2-base-multilingual", automodel_args={"torch_dtype": "auto"}, trust_remote_code=True,)
     return reranker
 
 
@@ -26,6 +34,7 @@ def rerank_docs(reranker,docs,query):
 
     # Get a list of texts from langchain docs
     input_docs = [x.page_content for x in docs]
 
+    print(f"\n\nDOCS:{input_docs}\n\n")
     # Rerank using rerankers library
     results = reranker.rank(query=query, docs=input_docs)
 
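Since the default model changes from "nano" to "jina" (itself noted as broken on quota), callers that need a working reranker should pass a model explicitly. A sketch, assuming rerank_docs returns the documents with a "reranking_score" written into their metadata, which is consistent with how the graph routes on that field but is not shown in this hunk:

# Hypothetical usage:
from langchain_core.documents import Document
from climateqa.engine.reranker import get_reranker, rerank_docs

reranker = get_reranker("nano")  # local flashrank model; avoids the quota-limited Jina API
docs = [Document(page_content="Ocean heat content has increased steadily since the 1970s.")]
for doc in rerank_docs(reranker, docs, query="How is climate change affecting the oceans?"):
    print(doc.metadata["reranking_score"])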
climateqa/engine/retriever.py
CHANGED
@@ -28,6 +28,7 @@ class ClimateQARetriever(BaseRetriever):
 
         # Check that all requested sources are supported (IPCC, IPBES or IPOS)
         assert isinstance(self.sources,list)
+        assert self.sources
        assert all([x in ["IPCC","IPBES","IPOS"] for x in self.sources])
         assert self.k_total > self.k_summary, "k_total should be greater than k_summary"
 
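The added `assert self.sources` makes an empty source list fail fast instead of silently retrieving nothing. A sketch of the behaviour; the retriever's construction details are elided and `vectorstore` is assumed to exist:

from climateqa.engine.retriever import ClimateQARetriever

retriever = ClimateQARetriever(vectorstore=vectorstore, sources=[])
try:
    retriever.invoke("What does the IPCC say about sea level rise?")
except AssertionError:
    # sources must be a non-empty subset of IPCC/IPBES/IPOS
    print("empty source list rejected")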
climateqa/engine/vectorstore.py
CHANGED
@@ -4,6 +4,7 @@
 import os
 from pinecone import Pinecone
 from langchain_community.vectorstores import Pinecone as PineconeVectorstore
+from langchain_chroma import Chroma
 
 # LOAD ENVIRONMENT VARIABLES
 try:
@@ -13,6 +14,11 @@ except:
     pass
 
 
+def get_chroma_vectorstore(embedding_function, persist_directory="/home/dora/climate-question-answering/data/vectorstore"):
+    vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
+    return vectorstore
+
+
 def get_pinecone_vectorstore(embeddings,text_key = "content"):
 
     # # initialize pinecone
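The default persist_directory is an absolute path on one developer's machine, so other environments should pass their own location. For instance, resolving it from the environment; the CHROMA_PERSIST_DIR variable name below is an assumption, not something this commit defines:

import os
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.vectorstore import get_chroma_vectorstore

# Hypothetical: resolve the vectorstore location from the environment instead
# of relying on the hard-coded /home/dora/... default.
persist_directory = os.getenv("CHROMA_PERSIST_DIR", "./data/vectorstore")
vectorstore = get_chroma_vectorstore(get_embeddings_function(), persist_directory=persist_directory)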
climateqa/utils.py
CHANGED
@@ -20,3 +20,16 @@ def get_image_from_azure_blob_storage(path):
     file_object = get_file_from_azure_blob_storage(path)
     image = Image.open(file_object)
     return image
+
+def remove_duplicates_keep_highest_score(documents):
+    unique_docs = {}
+
+    for doc in documents:
+        doc_id = doc.metadata.get('doc_id')
+        if doc_id in unique_docs:
+            if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
+                unique_docs[doc_id] = doc
+        else:
+            unique_docs[doc_id] = doc
+
+    return list(unique_docs.values())
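The helper keeps one chunk per doc_id, preferring the highest reranking_score; a small sanity check with illustrative document contents:

from langchain_core.documents import Document
from climateqa.utils import remove_duplicates_keep_highest_score

docs = [
    Document(page_content="chunk A", metadata={"doc_id": "owid_001", "reranking_score": 0.42}),
    Document(page_content="chunk B", metadata={"doc_id": "owid_001", "reranking_score": 0.87}),
]
kept = remove_duplicates_keep_highest_score(docs)
assert len(kept) == 1 and kept[0].page_content == "chunk B"

Note that documents lacking a doc_id all share the None key and therefore collapse into a single survivor, which may or may not be intended.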
front/utils.py
CHANGED
@@ -33,6 +33,85 @@ def parse_output_llm_with_sources(output):
     return content_parts
 
 
+from collections import defaultdict
+
+def generate_html_graphs(graphs):
+    # Organize graphs by category
+    categories = defaultdict(list)
+    for graph in graphs:
+        category = graph['metadata']['category']
+        categories[category].append(graph['embedding'])
+
+    # Begin constructing the HTML
+    html_code = '''
+    <!DOCTYPE html>
+    <html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <meta name="viewport" content="width=device-width, initial-scale=1.0">
+        <title>Graphs by Category</title>
+        <style>
+            .tab-content {
+                display: none;
+            }
+            .tab-content.active {
+                display: block;
+            }
+            .tabs {
+                margin-bottom: 20px;
+            }
+            .tab-button {
+                background-color: #ddd;
+                border: none;
+                padding: 10px 20px;
+                cursor: pointer;
+                margin-right: 5px;
+            }
+            .tab-button.active {
+                background-color: #ccc;
+            }
+        </style>
+        <script>
+            function showTab(tabId) {
+                var contents = document.getElementsByClassName('tab-content');
+                var buttons = document.getElementsByClassName('tab-button');
+                for (var i = 0; i < contents.length; i++) {
+                    contents[i].classList.remove('active');
+                    buttons[i].classList.remove('active');
+                }
+                document.getElementById(tabId).classList.add('active');
+                document.querySelector('button[data-tab="'+tabId+'"]').classList.add('active');
+            }
+        </script>
+    </head>
+    <body>
+        <div class="tabs">
+    '''
+
+    # Add buttons for each category
+    for i, category in enumerate(categories.keys()):
+        active_class = 'active' if i == 0 else ''
+        html_code += f'<button class="tab-button {active_class}" onclick="showTab(\'tab-{i}\')" data-tab="tab-{i}">{category}</button>'
+
+    html_code += '</div>'
+
+    # Add content for each category
+    for i, (category, embeds) in enumerate(categories.items()):
+        active_class = 'active' if i == 0 else ''
+        html_code += f'<div id="tab-{i}" class="tab-content {active_class}">'
+        for embed in embeds:
+            html_code += embed
+        html_code += '</div>'
+
+    html_code += '''
+    </body>
+    </html>
+    '''
+
+    return html_code
+
+
 def make_html_source(source,i):
     meta = source.metadata
     # content = source.page_content.split(":",1)[1].strip()
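generate_html_graphs expects plain dicts rather than Document objects, with the OWID embed HTML under "embedding" and the tab grouping under metadata "category" (field names inferred from the code above). An illustrative call; the embed URLs are placeholders:

from front.utils import generate_html_graphs

graphs = [
    {"embedding": '<iframe src="https://ourworldindata.org/grapher/annual-co2-emissions"></iframe>',
     "metadata": {"category": "Emissions"}},
    {"embedding": '<iframe src="https://ourworldindata.org/grapher/sea-level-rise"></iframe>',
     "metadata": {"category": "Oceans"}},
]
html = generate_html_graphs(graphs)  # full HTML page with one tab per category
# html can then be rendered, e.g. in a gr.HTML component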
style.css
CHANGED
@@ -3,7 +3,7 @@
     --user-image: url('https://ih1.redbubble.net/image.4776899543.6215/st,small,507x507-pad,600x600,f8f8f8.jpg');
 } */
 
-.warning-box {
+.warning-box {
     background-color: #fff3cd;
     border: 1px solid #ffeeba;
     border-radius: 4px;
@@ -464,4 +464,32 @@ span.chatbot > p > img{
 
 .score-orange{
     color:red !important;
-}
+}
+
+/* Additional style for scrollable tab content */
+div#tab-recommended_content {
+    overflow-y: auto; /* Enable vertical scrolling */
+    max-height: 80vh; /* Adjust height as needed */
+}
+
+/* Mobile specific adjustments */
+@media screen and (max-width: 767px) {
+    div#tab-recommended_content {
+        max-height: 50vh; /* Reduce height for smaller screens */
+        overflow-y: auto;
+    }
+}
+
+/* Additional style for scrollable tab content */
+div#tab-saved-graphs {
+    overflow-y: auto; /* Enable vertical scrolling */
+    max-height: 80vh; /* Adjust height as needed */
+}
+
+/* Mobile specific adjustments */
+@media screen and (max-width: 767px) {
+    div#tab-saved-graphs {
+        max-height: 50vh; /* Reduce height for smaller screens */
+        overflow-y: auto;
+    }
+}