Merged in dev (pull request #27)
Browse files- app.py +6 -12
- climateqa/engine/chains/retrieve_documents.py +25 -11
- climateqa/engine/graph_retriever.py +3 -4
- climateqa/engine/llm/openai.py +0 -1
- climateqa/engine/talk_to_data/input_processing.py +73 -8
- climateqa/engine/talk_to_data/ipcc/config.py +16 -4
- climateqa/engine/talk_to_data/ipcc/plot_informations.py +23 -0
- climateqa/engine/talk_to_data/ipcc/plots.py +81 -3
- climateqa/engine/talk_to_data/ipcc/queries.py +65 -5
- climateqa/engine/talk_to_data/main.py +2 -2
- climateqa/engine/talk_to_data/myVanna.py +0 -13
- climateqa/engine/talk_to_data/plot.py +0 -418
- climateqa/engine/talk_to_data/sql_query.py +0 -114
- climateqa/engine/talk_to_data/talk_to_drias.py +0 -317
- climateqa/engine/talk_to_data/utils.py +0 -281
- climateqa/engine/talk_to_data/vanna_class.py +0 -325
- climateqa/engine/talk_to_data/workflow/drias.py +8 -3
- climateqa/engine/talk_to_data/workflow/ipcc.py +20 -7
- climateqa/engine/vectorstore.py +137 -45
- climateqa/utils.py +1 -1
- front/tabs/tab_ipcc.py +2 -0
- requirements.txt +3 -0
- sandbox/20241104 - CQA - StepByStep CQA.ipynb +0 -0
- style.css +0 -1
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from azure.storage.fileshare import ShareServiceClient
|
|
| 7 |
# Import custom modules
|
| 8 |
from climateqa.engine.embeddings import get_embeddings_function
|
| 9 |
from climateqa.engine.llm import get_llm
|
| 10 |
-
from climateqa.engine.vectorstore import
|
| 11 |
from climateqa.engine.reranker import get_reranker
|
| 12 |
from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
|
| 13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
|
@@ -66,17 +66,11 @@ user_id = create_user_id()
|
|
| 66 |
|
| 67 |
# Create vectorstore and retriever
|
| 68 |
embeddings_function = get_embeddings_function()
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
index_name=os.getenv("PINECONE_API_INDEX_OWID"),
|
| 75 |
-
text_key="description",
|
| 76 |
-
)
|
| 77 |
-
vectorstore_region = get_pinecone_vectorstore(
|
| 78 |
-
embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
|
| 79 |
-
)
|
| 80 |
|
| 81 |
llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
|
| 82 |
if os.environ["GRADIO_ENV"] == "local":
|
|
|
|
| 7 |
# Import custom modules
|
| 8 |
from climateqa.engine.embeddings import get_embeddings_function
|
| 9 |
from climateqa.engine.llm import get_llm
|
| 10 |
+
from climateqa.engine.vectorstore import get_vectorstore
|
| 11 |
from climateqa.engine.reranker import get_reranker
|
| 12 |
from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
|
| 13 |
from climateqa.engine.chains.retrieve_papers import find_papers
|
|
|
|
| 66 |
|
| 67 |
# Create vectorstore and retriever
|
| 68 |
embeddings_function = get_embeddings_function()
|
| 69 |
+
|
| 70 |
+
vectorstore = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-ipx")
|
| 71 |
+
vectorstore_graphs = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-owid", text_key="description")
|
| 72 |
+
vectorstore_region = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-v2")
|
| 73 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
|
| 76 |
if os.environ["GRADIO_ENV"] == "local":
|
climateqa/engine/chains/retrieve_documents.py
CHANGED
|
@@ -19,7 +19,7 @@ from ..llm import get_llm
|
|
| 19 |
from .prompts import retrieve_chapter_prompt_template
|
| 20 |
from langchain_core.prompts import ChatPromptTemplate
|
| 21 |
from langchain_core.output_parsers import StrOutputParser
|
| 22 |
-
from ..vectorstore import
|
| 23 |
from ..embeddings import get_embeddings_function
|
| 24 |
import ast
|
| 25 |
|
|
@@ -134,7 +134,7 @@ def get_ToCs(version: str) :
|
|
| 134 |
"version": version
|
| 135 |
}
|
| 136 |
embeddings_function = get_embeddings_function()
|
| 137 |
-
vectorstore =
|
| 138 |
tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text)
|
| 139 |
|
| 140 |
# remove duplicates or almost duplicates
|
|
@@ -236,7 +236,7 @@ async def get_POC_documents_by_ToC_relevant_documents(
|
|
| 236 |
filters_text_toc = {
|
| 237 |
**filters,
|
| 238 |
"chunk_type":"text",
|
| 239 |
-
"toc_level0": {"$in": toc_filters}
|
| 240 |
"version": version
|
| 241 |
# "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
|
| 242 |
}
|
|
@@ -273,6 +273,22 @@ async def get_POC_documents_by_ToC_relevant_documents(
|
|
| 273 |
"docs_images" : docs_images
|
| 274 |
}
|
| 275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
async def get_IPCC_relevant_documents(
|
| 278 |
query: str,
|
|
@@ -299,9 +315,9 @@ async def get_IPCC_relevant_documents(
|
|
| 299 |
filters = {}
|
| 300 |
|
| 301 |
if len(reports) > 0:
|
| 302 |
-
filters["short_name"] = {"$in":reports}
|
| 303 |
else:
|
| 304 |
-
filters["source"] = {
|
| 305 |
|
| 306 |
# INIT
|
| 307 |
docs_summaries = []
|
|
@@ -323,18 +339,16 @@ async def get_IPCC_relevant_documents(
|
|
| 323 |
filters_summaries = {
|
| 324 |
**filters,
|
| 325 |
"chunk_type":"text",
|
| 326 |
-
"report_type": {
|
| 327 |
}
|
| 328 |
|
| 329 |
docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
|
| 330 |
docs_summaries = [x for x in docs_summaries if x[1] > threshold]
|
| 331 |
|
| 332 |
# Search for k_total - k_summary documents in the full reports dataset
|
| 333 |
-
filters_full =
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
"report_type": { "$nin":["SPM"]},
|
| 337 |
-
}
|
| 338 |
docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_total)
|
| 339 |
|
| 340 |
if search_figures:
|
|
|
|
| 19 |
from .prompts import retrieve_chapter_prompt_template
|
| 20 |
from langchain_core.prompts import ChatPromptTemplate
|
| 21 |
from langchain_core.output_parsers import StrOutputParser
|
| 22 |
+
from ..vectorstore import get_vectorstore
|
| 23 |
from ..embeddings import get_embeddings_function
|
| 24 |
import ast
|
| 25 |
|
|
|
|
| 134 |
"version": version
|
| 135 |
}
|
| 136 |
embeddings_function = get_embeddings_function()
|
| 137 |
+
vectorstore = get_vectorstore(provider="qdrant", embeddings=embeddings_function, index_name="climateqa")
|
| 138 |
tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text)
|
| 139 |
|
| 140 |
# remove duplicates or almost duplicates
|
|
|
|
| 236 |
filters_text_toc = {
|
| 237 |
**filters,
|
| 238 |
"chunk_type":"text",
|
| 239 |
+
"toc_level0": toc_filters, # Changed from {"$in": toc_filters} to direct list
|
| 240 |
"version": version
|
| 241 |
# "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
|
| 242 |
}
|
|
|
|
| 273 |
"docs_images" : docs_images
|
| 274 |
}
|
| 275 |
|
| 276 |
+
def filter_for_full_report_documents(filters: dict, excluded_report_types: tuple = ("SPM",)) -> dict:
    """
    Build the metadata filter used to search the full reports (everything except summaries).

    The input mapping is never mutated: a shallow copy is extended with
    ``chunk_type == "text"`` and a ``report_type_exclude`` entry.

    Args:
        filters (dict): Base filters (e.g. ``short_name`` or ``source``). ``None`` is
            treated as "no base filters".
        excluded_report_types (tuple): Report types to leave out of the search.
            Defaults to ``("SPM",)``, the summaries that are searched separately.

    Returns:
        dict: A new filter dictionary with the extra constraints applied.
    """
    # Copy so the caller can keep reusing its base filters for other searches
    full_filters = dict(filters) if filters else {}

    # Only text chunks are searched here
    full_filters["chunk_type"] = "text"

    # Exclusion is expressed with the "_exclude" key suffix; translating it into a
    # provider-specific filter is left to the vectorstore wrapper
    full_filters["report_type_exclude"] = list(excluded_report_types)

    return full_filters
|
| 292 |
|
| 293 |
async def get_IPCC_relevant_documents(
|
| 294 |
query: str,
|
|
|
|
| 315 |
filters = {}
|
| 316 |
|
| 317 |
if len(reports) > 0:
|
| 318 |
+
filters["short_name"] = reports # Changed from {"$in":reports} to direct list
|
| 319 |
else:
|
| 320 |
+
filters["source"] = sources # Changed from {"$in": sources} to direct list
|
| 321 |
|
| 322 |
# INIT
|
| 323 |
docs_summaries = []
|
|
|
|
| 339 |
filters_summaries = {
|
| 340 |
**filters,
|
| 341 |
"chunk_type":"text",
|
| 342 |
+
"report_type": ["SPM"], # Changed from {"$in":["SPM"]} to direct list
|
| 343 |
}
|
| 344 |
|
| 345 |
docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
|
| 346 |
docs_summaries = [x for x in docs_summaries if x[1] > threshold]
|
| 347 |
|
| 348 |
# Search for k_total - k_summary documents in the full reports dataset
|
| 349 |
+
filters_full = filter_for_full_report_documents(filters)
|
| 350 |
+
|
| 351 |
+
|
|
|
|
|
|
|
| 352 |
docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_total)
|
| 353 |
|
| 354 |
if search_figures:
|
climateqa/engine/graph_retriever.py
CHANGED
|
@@ -60,10 +60,9 @@ async def retrieve_graphs(
|
|
| 60 |
assert sources
|
| 61 |
assert any([x in ["OWID"] for x in sources])
|
| 62 |
|
| 63 |
-
# Prepare base search kwargs
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
filters["source"] = {"$in": sources}
|
| 67 |
|
| 68 |
docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
|
| 69 |
|
|
|
|
| 60 |
assert sources
|
| 61 |
assert any([x in ["OWID"] for x in sources])
|
| 62 |
|
| 63 |
+
# Prepare base search kwargs for Azure AI Search
|
| 64 |
+
# NOTE: the filter is passed as a plain dict (not an OData filter string) and is hard-coded to OWID, so `sources` is only used by the asserts above
|
| 65 |
+
filters = {"source":"OWID"}
|
|
|
|
| 66 |
|
| 67 |
docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
|
| 68 |
|
climateqa/engine/llm/openai.py
CHANGED
|
@@ -8,7 +8,6 @@ except Exception:
|
|
| 8 |
pass
|
| 9 |
|
| 10 |
def get_llm(model="gpt-4o-mini",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
|
| 11 |
-
|
| 12 |
llm = ChatOpenAI(
|
| 13 |
model=model,
|
| 14 |
api_key=os.environ.get("THEO_API_KEY", None),
|
|
|
|
| 8 |
pass
|
| 9 |
|
| 10 |
def get_llm(model="gpt-4o-mini",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
|
|
|
|
| 11 |
llm = ChatOpenAI(
|
| 12 |
model=model,
|
| 13 |
api_key=os.environ.get("THEO_API_KEY", None),
|
climateqa/engine/talk_to_data/input_processing.py
CHANGED
|
@@ -10,6 +10,7 @@ from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
|
|
| 10 |
from climateqa.engine.talk_to_data.objects.location import Location
|
| 11 |
from climateqa.engine.talk_to_data.objects.plot import Plot
|
| 12 |
from climateqa.engine.talk_to_data.objects.states import State
|
|
|
|
| 13 |
|
| 14 |
async def detect_location_with_openai(sentence: str) -> str:
|
| 15 |
"""
|
|
@@ -118,7 +119,7 @@ async def detect_year_with_openai(sentence: str) -> str:
|
|
| 118 |
return years_list[0]
|
| 119 |
else:
|
| 120 |
return ""
|
| 121 |
-
|
| 122 |
|
| 123 |
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
|
| 124 |
"""Identifies relevant tables for a plot based on user input.
|
|
@@ -227,6 +228,55 @@ async def find_year(user_input: str) -> str| None:
|
|
| 227 |
return None
|
| 228 |
return year
|
| 229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
|
| 231 |
print("---- Find relevant plots ----")
|
| 232 |
relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
|
|
@@ -237,16 +287,28 @@ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: l
|
|
| 237 |
relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
|
| 238 |
return relevant_tables
|
| 239 |
|
| 240 |
-
async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
|
| 241 |
-
"""
|
|
|
|
| 242 |
|
| 243 |
Args:
|
| 244 |
-
state (State): state
|
| 245 |
-
param_name (str): name of the
|
| 246 |
-
|
| 247 |
|
| 248 |
Returns:
|
| 249 |
-
dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
"""
|
| 251 |
if param_name == 'location':
|
| 252 |
location = await find_location(state['user_input'], mode)
|
|
@@ -254,4 +316,7 @@ async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'
|
|
| 254 |
if param_name == 'year':
|
| 255 |
year = await find_year(state['user_input'])
|
| 256 |
return {'year': year}
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from climateqa.engine.talk_to_data.objects.location import Location
|
| 11 |
from climateqa.engine.talk_to_data.objects.plot import Plot
|
| 12 |
from climateqa.engine.talk_to_data.objects.states import State
|
| 13 |
+
import calendar
|
| 14 |
|
| 15 |
async def detect_location_with_openai(sentence: str) -> str:
|
| 16 |
"""
|
|
|
|
| 119 |
return years_list[0]
|
| 120 |
else:
|
| 121 |
return ""
|
| 122 |
+
|
| 123 |
|
| 124 |
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
|
| 125 |
"""Identifies relevant tables for a plot based on user input.
|
|
|
|
| 228 |
return None
|
| 229 |
return year
|
| 230 |
|
| 231 |
+
async def find_month(user_input: str) -> dict[str, str|None]:
    """
    Extracts month information from user input using an LLM.

    The LLM is asked for the list of month numbers mentioned in the query; only the
    first one is used. The answer is validated before use: anything that is not an
    integer between 1 and 12 is treated as "no month found".

    Args:
        user_input (str): The user's query text.

    Returns:
        dict[str, str|None]: A dictionary with keys:
            - "month_number": the month number as a string (e.g. '7'), or None if not found
            - "month_name": the full English month name (e.g. 'July'), or None if not found

    Example:
        >>> await find_month("Show me the temperature in Paris in July")
        {'month_number': '7', 'month_name': 'July'}
        >>> await find_month("Show me the temperature in Paris")
        {'month_number': None, 'month_name': None}
    """
    no_month_found = {
        "month_number": None,
        "month_name": None
    }

    llm = get_llm()
    prompt = """
    Extract the month (as a number from 1 to 12) mentioned in the following sentence.
    Return the result as a Python list of integers. If no month is mentioned, return an empty list.

    Sentence: "{sentence}"
    """
    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": user_input})
    months_list = ast.literal_eval(response['array'])

    # The LLM output is untrusted: it may not be a list, or may be empty
    if not isinstance(months_list, (list, tuple)) or len(months_list) == 0:
        return no_month_found

    try:
        month_number = int(months_list[0])
    except (TypeError, ValueError):
        return no_month_found

    # calendar.month_name[0] is '' and negative indexes wrap around (-1 -> 'December'),
    # so out-of-range values must be rejected explicitly
    if not 1 <= month_number <= 12:
        return no_month_found

    return {
        "month_number": str(month_number),
        "month_name": calendar.month_name[month_number]
    }
|
| 278 |
+
|
| 279 |
+
|
| 280 |
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
|
| 281 |
print("---- Find relevant plots ----")
|
| 282 |
relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
|
|
|
|
| 287 |
relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
|
| 288 |
return relevant_tables
|
| 289 |
|
| 290 |
+
async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
|
| 291 |
+
"""
|
| 292 |
+
Retrieves a specific parameter (location, year, month, etc.) from the user's input using the appropriate extraction method.
|
| 293 |
|
| 294 |
Args:
|
| 295 |
+
state (State): The current state containing at least the user's input under 'user_input'.
|
| 296 |
+
param_name (str): The name of the parameter to extract. Supported: 'location', 'year', 'month'.
|
| 297 |
+
mode (Literal['DRIAS', 'IPCC']): The data mode to use for location extraction.
|
| 298 |
|
| 299 |
Returns:
|
| 300 |
+
- For 'location': a Location object (dict with keys like 'location', 'latitude', etc.), or None if not found.
|
| 301 |
+
- For 'year': a dict {'year': year or None}.
|
| 302 |
+
- For 'month': a dict {'month_number': str or None, 'month_name': str or None}.
|
| 303 |
+
- None if the parameter is not recognized or not found.
|
| 304 |
+
|
| 305 |
+
Example:
|
| 306 |
+
>>> await find_param(state, 'location')
|
| 307 |
+
{'location': 'Paris', 'latitude': ..., ...}
|
| 308 |
+
>>> await find_param(state, 'year')
|
| 309 |
+
{'year': '2050'}
|
| 310 |
+
>>> await find_param(state, 'month')
|
| 311 |
+
{'month_number': '7', 'month_name': 'July'}
|
| 312 |
"""
|
| 313 |
if param_name == 'location':
|
| 314 |
location = await find_location(state['user_input'], mode)
|
|
|
|
| 316 |
if param_name == 'year':
|
| 317 |
year = await find_year(state['user_input'])
|
| 318 |
return {'year': year}
|
| 319 |
+
if param_name == 'month':
|
| 320 |
+
month = await find_month(state['user_input'])
|
| 321 |
+
return month
|
| 322 |
+
return None
|
climateqa/engine/talk_to_data/ipcc/config.py
CHANGED
|
@@ -6,16 +6,22 @@ from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
|
|
| 6 |
IPCC_TABLES = [
|
| 7 |
"mean_temperature",
|
| 8 |
"total_precipitation",
|
|
|
|
|
|
|
| 9 |
]
|
| 10 |
|
| 11 |
IPCC_INDICATOR_COLUMNS_PER_TABLE = {
|
| 12 |
"mean_temperature": "mean_temperature",
|
| 13 |
-
"total_precipitation": "total_precipitation"
|
|
|
|
|
|
|
| 14 |
}
|
| 15 |
|
| 16 |
IPCC_INDICATOR_TO_UNIT = {
|
| 17 |
"mean_temperature": "°C",
|
| 18 |
-
"total_precipitation": "mm/day"
|
|
|
|
|
|
|
| 19 |
}
|
| 20 |
|
| 21 |
IPCC_SCENARIO = [
|
|
@@ -30,7 +36,8 @@ IPCC_MODELS = []
|
|
| 30 |
|
| 31 |
IPCC_PLOT_PARAMETERS = [
|
| 32 |
'year',
|
| 33 |
-
'location'
|
|
|
|
| 34 |
]
|
| 35 |
|
| 36 |
MACRO_COUNTRIES = ['JP',
|
|
@@ -63,7 +70,9 @@ HUGE_MACRO_COUNTRIES = ['CL',
|
|
| 63 |
|
| 64 |
IPCC_INDICATOR_TO_COLORSCALE = {
|
| 65 |
"mean_temperature": TEMPERATURE_COLORSCALE,
|
| 66 |
-
"total_precipitation": PRECIPITATION_COLORSCALE
|
|
|
|
|
|
|
| 67 |
}
|
| 68 |
|
| 69 |
IPCC_UI_TEXT = """
|
|
@@ -77,9 +86,12 @@ By default, we take the **mediane of each climate model**.
|
|
| 77 |
Current available charts :
|
| 78 |
- Yearly evolution of an indicator at a specific location (historical + SSP Projections)
|
| 79 |
- Yearly spatial distribution of an indicator in a specific country
|
|
|
|
| 80 |
|
| 81 |
Current available indicators :
|
| 82 |
- Mean temperature
|
|
|
|
|
|
|
| 83 |
- Total precipitation
|
| 84 |
|
| 85 |
For example, you can ask:
|
|
|
|
| 6 |
IPCC_TABLES = [
|
| 7 |
"mean_temperature",
|
| 8 |
"total_precipitation",
|
| 9 |
+
"minimum_temperature",
|
| 10 |
+
"maximum_temperature"
|
| 11 |
]
|
| 12 |
|
| 13 |
IPCC_INDICATOR_COLUMNS_PER_TABLE = {
|
| 14 |
"mean_temperature": "mean_temperature",
|
| 15 |
+
"total_precipitation": "total_precipitation",
|
| 16 |
+
"minimum_temperature": "minimum_temperature",
|
| 17 |
+
"maximum_temperature": "maximum_temperature"
|
| 18 |
}
|
| 19 |
|
| 20 |
IPCC_INDICATOR_TO_UNIT = {
|
| 21 |
"mean_temperature": "°C",
|
| 22 |
+
"total_precipitation": "mm/day",
|
| 23 |
+
"minimum_temperature": "°C",
|
| 24 |
+
"maximum_temperature": "°C"
|
| 25 |
}
|
| 26 |
|
| 27 |
IPCC_SCENARIO = [
|
|
|
|
| 36 |
|
| 37 |
IPCC_PLOT_PARAMETERS = [
|
| 38 |
'year',
|
| 39 |
+
'location',
|
| 40 |
+
'month'
|
| 41 |
]
|
| 42 |
|
| 43 |
MACRO_COUNTRIES = ['JP',
|
|
|
|
| 70 |
|
| 71 |
IPCC_INDICATOR_TO_COLORSCALE = {
|
| 72 |
"mean_temperature": TEMPERATURE_COLORSCALE,
|
| 73 |
+
"total_precipitation": PRECIPITATION_COLORSCALE,
|
| 74 |
+
"minimum_temperature": TEMPERATURE_COLORSCALE,
|
| 75 |
+
"maximum_temperature": TEMPERATURE_COLORSCALE,
|
| 76 |
}
|
| 77 |
|
| 78 |
IPCC_UI_TEXT = """
|
|
|
|
| 86 |
Current available charts :
|
| 87 |
- Yearly evolution of an indicator at a specific location (historical + SSP Projections)
|
| 88 |
- Yearly spatial distribution of an indicator in a specific country
|
| 89 |
+
- Yearly evolution of an indicator in a specific month at a specific location (historical + SSP Projections)
|
| 90 |
|
| 91 |
Current available indicators :
|
| 92 |
- Mean temperature
|
| 93 |
+
- Minimum temperature
|
| 94 |
+
- Maximum temperature
|
| 95 |
- Total precipitation
|
| 96 |
|
| 97 |
For example, you can ask:
|
climateqa/engine/talk_to_data/ipcc/plot_informations.py
CHANGED
|
@@ -47,4 +47,27 @@ Each grid point is colored according to the value of the indicator ({unit}), all
|
|
| 47 |
- For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
|
| 48 |
- The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
|
| 49 |
- The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
"""
|
|
|
|
| 47 |
- For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
|
| 48 |
- The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
|
| 49 |
- The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def indicator_specific_month_evolution_informations(
    indicator: str,
    params: dict[str, str]
) -> str:
    """
    Build the explanatory text shown with the "evolution for a specific month" plot.

    Args:
        indicator (str): Name of the climate indicator; must be a key of IPCC_INDICATOR_TO_UNIT.
        params (dict[str, str]): Plot parameters. Must contain non-empty
            "location" and "month_name" values.

    Returns:
        str: Markdown text describing the plot and its data source.

    Raises:
        ValueError: If "location" or "month_name" is missing, None or empty.
        KeyError: If the indicator has no unit in IPCC_INDICATOR_TO_UNIT.
    """
    # A key can be present with a None value (e.g. when no month was detected upstream),
    # so check the value itself rather than only the key's presence
    location = params.get("location")
    if not location:
        raise ValueError('"location" must be provided in params')
    month = params.get("month_name")
    if not month:
        raise ValueError('"month_name" must be provided in params')
    unit = IPCC_INDICATOR_TO_UNIT[indicator]
    return f"""
This plot shows how the climate indicator **{indicator}** evolves over time in **{location}** for the month of **{month}**.
It combines both historical (from 1950 to 2015) observations and future (from 2016 to 2100) projections for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).
The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of the {indicator} ({unit}) for the selected month.
Each line corresponds to a different scenario, allowing you to compare how {indicator} for month {month} might change under various future conditions.

**Data source:**
- The data comes from the IPCC climate datasets (Parquet files) for the relevant indicator, location, and month.
- For each year and scenario, the value of {indicator} for month {month} is extracted for the selected location.
- The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
"""
|
climateqa/engine/talk_to_data/ipcc/plots.py
CHANGED
|
@@ -5,8 +5,8 @@ import pandas as pd
|
|
| 5 |
import geojson
|
| 6 |
|
| 7 |
from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
|
| 8 |
-
from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
|
| 9 |
-
from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
|
| 10 |
from climateqa.engine.talk_to_data.objects.plot import Plot
|
| 11 |
|
| 12 |
def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
|
|
@@ -102,6 +102,82 @@ indicator_evolution_at_location_historical_and_projections: Plot = {
|
|
| 102 |
"short_name": "Evolution"
|
| 103 |
}
|
| 104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
def plot_choropleth_map_of_country_indicator_for_specific_year(
|
| 106 |
params: dict,
|
| 107 |
) -> Callable[[pd.DataFrame], Figure]:
|
|
@@ -167,6 +243,7 @@ def plot_choropleth_map_of_country_indicator_for_specific_year(
|
|
| 167 |
|
| 168 |
return plot_data
|
| 169 |
|
|
|
|
| 170 |
choropleth_map_of_country_indicator_for_specific_year: Plot = {
|
| 171 |
"name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
|
| 172 |
"description": (
|
|
@@ -185,5 +262,6 @@ choropleth_map_of_country_indicator_for_specific_year: Plot = {
|
|
| 185 |
|
| 186 |
IPCC_PLOTS = [
|
| 187 |
indicator_evolution_at_location_historical_and_projections,
|
| 188 |
-
choropleth_map_of_country_indicator_for_specific_year
|
|
|
|
| 189 |
]
|
|
|
|
| 5 |
import geojson
|
| 6 |
|
| 7 |
from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
|
| 8 |
+
from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations, indicator_specific_month_evolution_informations
|
| 9 |
+
from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_and_specific_month_at_location_query, indicator_per_year_at_location_query
|
| 10 |
from climateqa.engine.talk_to_data.objects.plot import Plot
|
| 11 |
|
| 12 |
def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
|
|
|
|
| 102 |
"short_name": "Evolution"
|
| 103 |
}
|
| 104 |
|
| 105 |
+
def plot_indicator_monthly_evolution_at_location(
    params: dict,
) -> Callable[[pd.DataFrame], Figure]:
    """
    Returns a function that generates a line plot showing the evolution of a climate indicator
    for a specific month over time at a specific location, including both historical data
    and future projections for different climate scenarios.

    Args:
        params (dict): Dictionary with:
            - indicator_column (str): Name of the climate indicator column to plot.
            - location (str): Location (e.g., country, city) for which to plot the indicator.
            - month_name (str): Month name, used in the plot title.

    Returns:
        Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame with
        'year', 'scenario' and the indicator column, and returns a Plotly Figure.

    Raises:
        KeyError: If one of the keys listed above is missing from params.
    """
    indicator = params["indicator_column"]
    location = params["location"]
    month = params["month_name"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")

    def plot_data(df: pd.DataFrame) -> Figure:
        df = df.sort_values(by='year')
        years = df['year'].astype(int).tolist()
        indicators = df[indicator].astype(float).tolist()
        scenarios = df['scenario'].astype(str).tolist()

        # Find last historical value for continuity (rows are sorted by year)
        last_historical = [(y, v) for y, v, s in zip(years, indicators, scenarios) if s == 'historical']
        last_historical_year, last_historical_indicator = last_historical[-1] if last_historical else (None, None)

        fig = go.Figure()
        for scenario in IPCC_SCENARIO:
            x = [y for y, s in zip(years, scenarios) if s == scenario]
            y = [v for v, s in zip(indicators, scenarios) if s == scenario]
            # A scenario absent from the data would otherwise add an empty (or
            # one-point, invisible) trace that only clutters the legend
            if not x:
                continue
            # Connect historical to scenario so projection lines start where history ends
            if scenario != 'historical' and last_historical_indicator is not None:
                x = [last_historical_year] + x
                y = [last_historical_indicator] + y
            fig.add_trace(go.Scatter(
                x=x,
                y=y,
                mode='lines',
                name=scenario
            ))

        fig.update_layout(
            title=f'Evolution of {indicator_label} in {month} in {location} (Historical + SSP Scenarios)',
            xaxis_title='Year',
            yaxis_title=f'{indicator_label} ({unit})',
            legend_title='Scenario',
            height=800,
        )
        return fig

    return plot_data
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
indicator_specific_month_evolution_at_location: Plot = {
|
| 166 |
+
"name": "Indicator specific month Evolution at Location (Historical + Projections)",
|
| 167 |
+
"description": (
|
| 168 |
+
"Shows how a climate indicator (e.g., rainfall, temperature) for a specific month changes over time at a specific location, "
|
| 169 |
+
"including historical data and future projections. "
|
| 170 |
+
"Useful for questions about the value or trend of an indicator for a given month at a location, "
|
| 171 |
+
"such as 'How does July temperature evolve in Paris over time?'. "
|
| 172 |
+
"Parameters: indicator_column (the climate variable), location (e.g., country, city), month (1-12)."
|
| 173 |
+
),
|
| 174 |
+
"params": ["indicator_column", "location", "month"],
|
| 175 |
+
"plot_function": plot_indicator_monthly_evolution_at_location,
|
| 176 |
+
"sql_query": indicator_per_year_and_specific_month_at_location_query,
|
| 177 |
+
"plot_information": indicator_specific_month_evolution_informations,
|
| 178 |
+
"short_name": "Evolution for a specific month"
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
def plot_choropleth_map_of_country_indicator_for_specific_year(
|
| 182 |
params: dict,
|
| 183 |
) -> Callable[[pd.DataFrame], Figure]:
|
|
|
|
| 243 |
|
| 244 |
return plot_data
|
| 245 |
|
| 246 |
+
|
| 247 |
choropleth_map_of_country_indicator_for_specific_year: Plot = {
|
| 248 |
"name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
|
| 249 |
"description": (
|
|
|
|
| 262 |
|
| 263 |
IPCC_PLOTS = [
|
| 264 |
indicator_evolution_at_location_historical_and_projections,
|
| 265 |
+
choropleth_map_of_country_indicator_for_specific_year,
|
| 266 |
+
indicator_specific_month_evolution_at_location
|
| 267 |
]
|
climateqa/engine/talk_to_data/ipcc/queries.py
CHANGED
|
@@ -43,7 +43,7 @@ def indicator_per_year_at_location_query(
|
|
| 43 |
return ""
|
| 44 |
|
| 45 |
if country_code in MACRO_COUNTRIES:
|
| 46 |
-
table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}
|
| 47 |
sql_query = f"""
|
| 48 |
SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
|
| 49 |
FROM {table_path}
|
|
@@ -52,7 +52,7 @@ def indicator_per_year_at_location_query(
|
|
| 52 |
ORDER BY year, scenario
|
| 53 |
"""
|
| 54 |
elif country_code in HUGE_MACRO_COUNTRIES:
|
| 55 |
-
table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}
|
| 56 |
sql_query = f"""
|
| 57 |
SELECT year, scenario, {indicator_column}
|
| 58 |
FROM {table_path}
|
|
@@ -75,6 +75,66 @@ def indicator_per_year_at_location_query(
|
|
| 75 |
"""
|
| 76 |
return sql_query.strip()
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
|
| 79 |
"""
|
| 80 |
Parameters for querying an indicator's values across locations for a specific year.
|
|
@@ -110,7 +170,7 @@ def indicator_for_given_year_query(
|
|
| 110 |
return ""
|
| 111 |
|
| 112 |
if country_code in MACRO_COUNTRIES:
|
| 113 |
-
table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}
|
| 114 |
sql_query = f"""
|
| 115 |
SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
|
| 116 |
FROM {table_path}
|
|
@@ -119,7 +179,7 @@ def indicator_for_given_year_query(
|
|
| 119 |
ORDER BY latitude, longitude, scenario
|
| 120 |
"""
|
| 121 |
elif country_code in HUGE_MACRO_COUNTRIES:
|
| 122 |
-
table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}
|
| 123 |
sql_query = f"""
|
| 124 |
SELECT latitude, longitude, scenario, {indicator_column}
|
| 125 |
FROM {table_path}
|
|
@@ -141,4 +201,4 @@ def indicator_for_given_year_query(
|
|
| 141 |
ORDER BY latitude, longitude, scenario
|
| 142 |
"""
|
| 143 |
|
| 144 |
-
return sql_query.strip()
|
|
|
|
| 43 |
return ""
|
| 44 |
|
| 45 |
if country_code in MACRO_COUNTRIES:
|
| 46 |
+
table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_monthly_macro.parquet'"
|
| 47 |
sql_query = f"""
|
| 48 |
SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
|
| 49 |
FROM {table_path}
|
|
|
|
| 52 |
ORDER BY year, scenario
|
| 53 |
"""
|
| 54 |
elif country_code in HUGE_MACRO_COUNTRIES:
|
| 55 |
+
table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_annualy_macro.parquet'"
|
| 56 |
sql_query = f"""
|
| 57 |
SELECT year, scenario, {indicator_column}
|
| 58 |
FROM {table_path}
|
|
|
|
| 75 |
"""
|
| 76 |
return sql_query.strip()
|
| 77 |
|
| 78 |
+
class IndicatorPerYearAndSpecificMonthAtLocationQueryParams(TypedDict, total=False):
    """
    Parameters for querying the evolution of an indicator per year for a specific month at a specific location.

    Attributes:
        indicator_column (str): Name of the climate indicator column.
        latitude (str): Latitude of the location.
        longitude (str): Longitude of the location.
        country_code (str): Country code.
        month_number (int): Targeted month (1-12). This is the key read by the query builder.
        month (str): Legacy alias of ``month_number``; only used when ``month_number`` is absent.
    """
    indicator_column: str
    latitude: str
    longitude: str
    country_code: str
    month_number: int
    month: str


def indicator_per_year_and_specific_month_at_location_query(
    table: str, params: IndicatorPerYearAndSpecificMonthAtLocationQueryParams
) -> str:
    """
    Builds an SQL query to get the evolution of an indicator per year for a specific month at a specific location.

    Args:
        table (str): SQL table of the indicator (lower-cased to build the parquet path).
        params (dict): Dictionary with required params:
            - indicator_column (str)
            - latitude (str or float)
            - longitude (str or float)
            - country_code (str)
            - month_number (int, 1-12); the legacy key ``month`` is accepted as a fallback

    Returns:
        str: The SQL query string, or an empty string when a required parameter
        is missing or the month is not an integer between 1 and 12.
    """
    indicator_column = params.get("indicator_column")
    latitude = params.get("latitude")
    longitude = params.get("longitude")
    country_code = params.get("country_code")
    # The declared schema historically exposed "month" while this builder read
    # "month_number"; accept both so either caller convention works.
    month = params.get("month_number")
    if month is None:
        month = params.get("month")

    # Only None / empty string count as "missing": a coordinate of 0 (equator or
    # prime meridian) is a legitimate value and must not short-circuit the query.
    required = (indicator_column, latitude, longitude, country_code, month)
    if any(value is None or (isinstance(value, str) and value == "") for value in required):
        return ""

    # The month is interpolated into the SQL text, so make sure it is a plain
    # calendar month number before using it.
    try:
        month = int(month)
    except (TypeError, ValueError):
        return ""
    if not 1 <= month <= 12:
        return ""

    if country_code in MACRO_COUNTRIES or country_code in HUGE_MACRO_COUNTRIES:
        # Macro countries read the pre-aggregated monthly file: one value per
        # (year, scenario), so no aggregation is applied here.
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_monthly_macro.parquet'"
        sql_query = f"""
        SELECT year, scenario, {indicator_column}
        FROM {table_path}
        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950 AND month={month}
        ORDER BY year, scenario
        """
    else:
        # Other countries read the per-country file and take the median per
        # (scenario, year). ORDER BY matches the macro branch so rows always
        # come back chronologically.
        table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
        sql_query = f"""
        SELECT year, scenario, MEDIAN({indicator_column}) AS {indicator_column}
        FROM {table_path}
        WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950 AND month={month}
        GROUP BY scenario, year
        ORDER BY year, scenario
        """
    return sql_query.strip()
|
| 138 |
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
|
| 139 |
"""
|
| 140 |
Parameters for querying an indicator's values across locations for a specific year.
|
|
|
|
| 170 |
return ""
|
| 171 |
|
| 172 |
if country_code in MACRO_COUNTRIES:
|
| 173 |
+
table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_monthly_macro.parquet'"
|
| 174 |
sql_query = f"""
|
| 175 |
SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
|
| 176 |
FROM {table_path}
|
|
|
|
| 179 |
ORDER BY latitude, longitude, scenario
|
| 180 |
"""
|
| 181 |
elif country_code in HUGE_MACRO_COUNTRIES:
|
| 182 |
+
table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_annualy_macro.parquet'"
|
| 183 |
sql_query = f"""
|
| 184 |
SELECT latitude, longitude, scenario, {indicator_column}
|
| 185 |
FROM {table_path}
|
|
|
|
| 201 |
ORDER BY latitude, longitude, scenario
|
| 202 |
"""
|
| 203 |
|
| 204 |
+
return sql_query.strip()
|
climateqa/engine/talk_to_data/main.py
CHANGED
|
@@ -50,7 +50,7 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None
|
|
| 50 |
|
| 51 |
if "error" in final_state and final_state["error"] != "":
|
| 52 |
# No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
|
| 53 |
-
return None, None, None, None, [], [], [], 0, [], final_state["error"]
|
| 54 |
|
| 55 |
sql_query = sql_queries[index_state]
|
| 56 |
dataframe = result_dataframes[index_state]
|
|
@@ -112,7 +112,7 @@ async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None)
|
|
| 112 |
|
| 113 |
if "error" in final_state and final_state["error"] != "":
|
| 114 |
# No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
|
| 115 |
-
return None, None, None, None, [], [], [], 0, [], final_state["error"]
|
| 116 |
|
| 117 |
sql_query = sql_queries[index_state]
|
| 118 |
dataframe = result_dataframes[index_state]
|
|
|
|
| 50 |
|
| 51 |
if "error" in final_state and final_state["error"] != "":
|
| 52 |
# No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
|
| 53 |
+
return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
|
| 54 |
|
| 55 |
sql_query = sql_queries[index_state]
|
| 56 |
dataframe = result_dataframes[index_state]
|
|
|
|
| 112 |
|
| 113 |
if "error" in final_state and final_state["error"] != "":
|
| 114 |
# No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
|
| 115 |
+
return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
|
| 116 |
|
| 117 |
sql_query = sql_queries[index_state]
|
| 118 |
dataframe = result_dataframes[index_state]
|
climateqa/engine/talk_to_data/myVanna.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
from dotenv import load_dotenv
|
| 2 |
-
from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
|
| 3 |
-
from vanna.openai import OpenAI_Chat
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
load_dotenv()
|
| 7 |
-
|
| 8 |
-
OPENAI_API_KEY = os.getenv('THEO_API_KEY')
|
| 9 |
-
|
| 10 |
-
class MyVanna(MyCustomVectorDB, OpenAI_Chat):
|
| 11 |
-
def __init__(self, config=None):
|
| 12 |
-
MyCustomVectorDB.__init__(self, config=config)
|
| 13 |
-
OpenAI_Chat.__init__(self, config=config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/plot.py
DELETED
|
@@ -1,418 +0,0 @@
|
|
| 1 |
-
from typing import Callable, TypedDict
|
| 2 |
-
from matplotlib.figure import figaspect
|
| 3 |
-
import pandas as pd
|
| 4 |
-
from plotly.graph_objects import Figure
|
| 5 |
-
import plotly.graph_objects as go
|
| 6 |
-
import plotly.express as px
|
| 7 |
-
|
| 8 |
-
from climateqa.engine.talk_to_data.sql_query import (
|
| 9 |
-
indicator_for_given_year_query,
|
| 10 |
-
indicator_per_year_at_location_query,
|
| 11 |
-
)
|
| 12 |
-
from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class Plot(TypedDict):
|
| 18 |
-
"""Represents a plot configuration in the DRIAS system.
|
| 19 |
-
|
| 20 |
-
This class defines the structure for configuring different types of plots
|
| 21 |
-
that can be generated from climate data.
|
| 22 |
-
|
| 23 |
-
Attributes:
|
| 24 |
-
name (str): The name of the plot type
|
| 25 |
-
description (str): A description of what the plot shows
|
| 26 |
-
params (list[str]): List of required parameters for the plot
|
| 27 |
-
plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
|
| 28 |
-
sql_query (Callable[..., str]): Function to generate the SQL query for the plot
|
| 29 |
-
"""
|
| 30 |
-
name: str
|
| 31 |
-
description: str
|
| 32 |
-
params: list[str]
|
| 33 |
-
plot_function: Callable[..., Callable[..., Figure]]
|
| 34 |
-
sql_query: Callable[..., str]
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
| 38 |
-
"""Generates a function to plot indicator evolution over time at a location.
|
| 39 |
-
|
| 40 |
-
This function creates a line plot showing how a climate indicator changes
|
| 41 |
-
over time at a specific location. It handles temperature, precipitation,
|
| 42 |
-
and other climate indicators.
|
| 43 |
-
|
| 44 |
-
Args:
|
| 45 |
-
params (dict): Dictionary containing:
|
| 46 |
-
- indicator_column (str): The column name for the indicator
|
| 47 |
-
- location (str): The location to plot
|
| 48 |
-
- model (str): The climate model to use
|
| 49 |
-
|
| 50 |
-
Returns:
|
| 51 |
-
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
| 52 |
-
|
| 53 |
-
Example:
|
| 54 |
-
>>> plot_func = plot_indicator_evolution_at_location({
|
| 55 |
-
... 'indicator_column': 'mean_temperature',
|
| 56 |
-
... 'location': 'Paris',
|
| 57 |
-
... 'model': 'ALL'
|
| 58 |
-
... })
|
| 59 |
-
>>> fig = plot_func(df)
|
| 60 |
-
"""
|
| 61 |
-
indicator = params["indicator_column"]
|
| 62 |
-
location = params["location"]
|
| 63 |
-
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
| 64 |
-
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
| 65 |
-
|
| 66 |
-
def plot_data(df: pd.DataFrame) -> Figure:
|
| 67 |
-
"""Generates the actual plot from the data.
|
| 68 |
-
|
| 69 |
-
Args:
|
| 70 |
-
df (pd.DataFrame): DataFrame containing the data to plot
|
| 71 |
-
|
| 72 |
-
Returns:
|
| 73 |
-
Figure: A plotly Figure object showing the indicator evolution
|
| 74 |
-
"""
|
| 75 |
-
fig = go.Figure()
|
| 76 |
-
if df['model'].nunique() != 1:
|
| 77 |
-
df_avg = df.groupby("year", as_index=False)[indicator].mean()
|
| 78 |
-
|
| 79 |
-
# Transform to list to avoid pandas encoding
|
| 80 |
-
indicators = df_avg[indicator].astype(float).tolist()
|
| 81 |
-
years = df_avg["year"].astype(int).tolist()
|
| 82 |
-
|
| 83 |
-
# Compute the 10-year rolling average
|
| 84 |
-
rolling_window = 10
|
| 85 |
-
sliding_averages = (
|
| 86 |
-
df_avg[indicator]
|
| 87 |
-
.rolling(window=rolling_window, min_periods=rolling_window)
|
| 88 |
-
.mean()
|
| 89 |
-
.astype(float)
|
| 90 |
-
.tolist()
|
| 91 |
-
)
|
| 92 |
-
model_label = "Model Average"
|
| 93 |
-
|
| 94 |
-
# Only add rolling average if we have enough data points
|
| 95 |
-
if len([x for x in sliding_averages if pd.notna(x)]) > 0:
|
| 96 |
-
# Sliding average dashed line
|
| 97 |
-
fig.add_scatter(
|
| 98 |
-
x=years,
|
| 99 |
-
y=sliding_averages,
|
| 100 |
-
mode="lines",
|
| 101 |
-
name="10 years rolling average",
|
| 102 |
-
line=dict(dash="dash"),
|
| 103 |
-
marker=dict(color="#d62728"),
|
| 104 |
-
hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
else:
|
| 108 |
-
df_model = df
|
| 109 |
-
|
| 110 |
-
# Transform to list to avoid pandas encoding
|
| 111 |
-
indicators = df_model[indicator].astype(float).tolist()
|
| 112 |
-
years = df_model["year"].astype(int).tolist()
|
| 113 |
-
|
| 114 |
-
# Compute the 10-year rolling average
|
| 115 |
-
rolling_window = 10
|
| 116 |
-
sliding_averages = (
|
| 117 |
-
df_model[indicator]
|
| 118 |
-
.rolling(window=rolling_window, min_periods=rolling_window)
|
| 119 |
-
.mean()
|
| 120 |
-
.astype(float)
|
| 121 |
-
.tolist()
|
| 122 |
-
)
|
| 123 |
-
model_label = f"Model : {df['model'].unique()[0]}"
|
| 124 |
-
|
| 125 |
-
# Only add rolling average if we have enough data points
|
| 126 |
-
if len([x for x in sliding_averages if pd.notna(x)]) > 0:
|
| 127 |
-
# Sliding average dashed line
|
| 128 |
-
fig.add_scatter(
|
| 129 |
-
x=years,
|
| 130 |
-
y=sliding_averages,
|
| 131 |
-
mode="lines",
|
| 132 |
-
name="10 years rolling average",
|
| 133 |
-
line=dict(dash="dash"),
|
| 134 |
-
marker=dict(color="#d62728"),
|
| 135 |
-
hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
# Indicator per year plot
|
| 139 |
-
fig.add_scatter(
|
| 140 |
-
x=years,
|
| 141 |
-
y=indicators,
|
| 142 |
-
name=f"Yearly {indicator_label}",
|
| 143 |
-
mode="lines",
|
| 144 |
-
marker=dict(color="#1f77b4"),
|
| 145 |
-
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
| 146 |
-
)
|
| 147 |
-
fig.update_layout(
|
| 148 |
-
title=f"Plot of {indicator_label} in {location} ({model_label})",
|
| 149 |
-
xaxis_title="Year",
|
| 150 |
-
yaxis_title=f"{indicator_label} ({unit})",
|
| 151 |
-
template="plotly_white",
|
| 152 |
-
)
|
| 153 |
-
return fig
|
| 154 |
-
|
| 155 |
-
return plot_data
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
indicator_evolution_at_location: Plot = {
|
| 159 |
-
"name": "Indicator evolution at location",
|
| 160 |
-
"description": "Plot an evolution of the indicator at a certain location",
|
| 161 |
-
"params": ["indicator_column", "location", "model"],
|
| 162 |
-
"plot_function": plot_indicator_evolution_at_location,
|
| 163 |
-
"sql_query": indicator_per_year_at_location_query,
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def plot_indicator_number_of_days_per_year_at_location(
|
| 168 |
-
params: dict,
|
| 169 |
-
) -> Callable[..., Figure]:
|
| 170 |
-
"""Generates a function to plot the number of days per year for an indicator.
|
| 171 |
-
|
| 172 |
-
This function creates a bar chart showing the frequency of certain climate
|
| 173 |
-
events (like days above a temperature threshold) per year at a specific location.
|
| 174 |
-
|
| 175 |
-
Args:
|
| 176 |
-
params (dict): Dictionary containing:
|
| 177 |
-
- indicator_column (str): The column name for the indicator
|
| 178 |
-
- location (str): The location to plot
|
| 179 |
-
- model (str): The climate model to use
|
| 180 |
-
|
| 181 |
-
Returns:
|
| 182 |
-
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
| 183 |
-
"""
|
| 184 |
-
indicator = params["indicator_column"]
|
| 185 |
-
location = params["location"]
|
| 186 |
-
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
| 187 |
-
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
| 188 |
-
|
| 189 |
-
def plot_data(df: pd.DataFrame) -> Figure:
|
| 190 |
-
"""Generate the figure thanks to the dataframe
|
| 191 |
-
|
| 192 |
-
Args:
|
| 193 |
-
df (pd.DataFrame): pandas dataframe with the required data
|
| 194 |
-
|
| 195 |
-
Returns:
|
| 196 |
-
Figure: Plotly figure
|
| 197 |
-
"""
|
| 198 |
-
fig = go.Figure()
|
| 199 |
-
if df['model'].nunique() != 1:
|
| 200 |
-
df_avg = df.groupby("year", as_index=False)[indicator].mean()
|
| 201 |
-
|
| 202 |
-
# Transform to list to avoid pandas encoding
|
| 203 |
-
indicators = df_avg[indicator].astype(float).tolist()
|
| 204 |
-
years = df_avg["year"].astype(int).tolist()
|
| 205 |
-
model_label = "Model Average"
|
| 206 |
-
|
| 207 |
-
else:
|
| 208 |
-
df_model = df
|
| 209 |
-
# Transform to list to avoid pandas encoding
|
| 210 |
-
indicators = df_model[indicator].astype(float).tolist()
|
| 211 |
-
years = df_model["year"].astype(int).tolist()
|
| 212 |
-
model_label = f"Model : {df['model'].unique()[0]}"
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
# Bar plot
|
| 216 |
-
fig.add_trace(
|
| 217 |
-
go.Bar(
|
| 218 |
-
x=years,
|
| 219 |
-
y=indicators,
|
| 220 |
-
width=0.5,
|
| 221 |
-
marker=dict(color="#1f77b4"),
|
| 222 |
-
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
| 223 |
-
)
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
fig.update_layout(
|
| 227 |
-
title=f"{indicator_label} in {location} ({model_label})",
|
| 228 |
-
xaxis_title="Year",
|
| 229 |
-
yaxis_title=f"{indicator_label} ({unit})",
|
| 230 |
-
yaxis=dict(range=[0, max(indicators)]),
|
| 231 |
-
bargap=0.5,
|
| 232 |
-
template="plotly_white",
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
return fig
|
| 236 |
-
|
| 237 |
-
return plot_data
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
indicator_number_of_days_per_year_at_location: Plot = {
|
| 241 |
-
"name": "Indicator number of days per year at location",
|
| 242 |
-
"description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
|
| 243 |
-
"params": ["indicator_column", "location", "model"],
|
| 244 |
-
"plot_function": plot_indicator_number_of_days_per_year_at_location,
|
| 245 |
-
"sql_query": indicator_per_year_at_location_query,
|
| 246 |
-
}
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
def plot_distribution_of_indicator_for_given_year(
|
| 250 |
-
params: dict,
|
| 251 |
-
) -> Callable[..., Figure]:
|
| 252 |
-
"""Generates a function to plot the distribution of an indicator for a year.
|
| 253 |
-
|
| 254 |
-
This function creates a histogram showing the distribution of a climate
|
| 255 |
-
indicator across different locations for a specific year.
|
| 256 |
-
|
| 257 |
-
Args:
|
| 258 |
-
params (dict): Dictionary containing:
|
| 259 |
-
- indicator_column (str): The column name for the indicator
|
| 260 |
-
- year (str): The year to plot
|
| 261 |
-
- model (str): The climate model to use
|
| 262 |
-
|
| 263 |
-
Returns:
|
| 264 |
-
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
| 265 |
-
"""
|
| 266 |
-
indicator = params["indicator_column"]
|
| 267 |
-
year = params["year"]
|
| 268 |
-
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
| 269 |
-
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
| 270 |
-
|
| 271 |
-
def plot_data(df: pd.DataFrame) -> Figure:
|
| 272 |
-
"""Generate the figure thanks to the dataframe
|
| 273 |
-
|
| 274 |
-
Args:
|
| 275 |
-
df (pd.DataFrame): pandas dataframe with the required data
|
| 276 |
-
|
| 277 |
-
Returns:
|
| 278 |
-
Figure: Plotly figure
|
| 279 |
-
"""
|
| 280 |
-
fig = go.Figure()
|
| 281 |
-
if df['model'].nunique() != 1:
|
| 282 |
-
df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
|
| 283 |
-
indicator
|
| 284 |
-
].mean()
|
| 285 |
-
|
| 286 |
-
# Transform to list to avoid pandas encoding
|
| 287 |
-
indicators = df_avg[indicator].astype(float).tolist()
|
| 288 |
-
model_label = "Model Average"
|
| 289 |
-
|
| 290 |
-
else:
|
| 291 |
-
df_model = df
|
| 292 |
-
|
| 293 |
-
# Transform to list to avoid pandas encoding
|
| 294 |
-
indicators = df_model[indicator].astype(float).tolist()
|
| 295 |
-
model_label = f"Model : {df['model'].unique()[0]}"
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
fig.add_trace(
|
| 299 |
-
go.Histogram(
|
| 300 |
-
x=indicators,
|
| 301 |
-
opacity=0.8,
|
| 302 |
-
histnorm="percent",
|
| 303 |
-
marker=dict(color="#1f77b4"),
|
| 304 |
-
hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
|
| 305 |
-
)
|
| 306 |
-
)
|
| 307 |
-
|
| 308 |
-
fig.update_layout(
|
| 309 |
-
title=f"Distribution of {indicator_label} in {year} ({model_label})",
|
| 310 |
-
xaxis_title=f"{indicator_label} ({unit})",
|
| 311 |
-
yaxis_title="Frequency (%)",
|
| 312 |
-
plot_bgcolor="rgba(0, 0, 0, 0)",
|
| 313 |
-
showlegend=False,
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
return fig
|
| 317 |
-
|
| 318 |
-
return plot_data
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
distribution_of_indicator_for_given_year: Plot = {
|
| 322 |
-
"name": "Distribution of an indicator for a given year",
|
| 323 |
-
"description": "Plot an histogram of the distribution for a given year of the values of an indicator",
|
| 324 |
-
"params": ["indicator_column", "model", "year"],
|
| 325 |
-
"plot_function": plot_distribution_of_indicator_for_given_year,
|
| 326 |
-
"sql_query": indicator_for_given_year_query,
|
| 327 |
-
}
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
def plot_map_of_france_of_indicator_for_given_year(
|
| 331 |
-
params: dict,
|
| 332 |
-
) -> Callable[..., Figure]:
|
| 333 |
-
"""Generates a function to plot a map of France for an indicator.
|
| 334 |
-
|
| 335 |
-
This function creates a choropleth map of France showing the spatial
|
| 336 |
-
distribution of a climate indicator for a specific year.
|
| 337 |
-
|
| 338 |
-
Args:
|
| 339 |
-
params (dict): Dictionary containing:
|
| 340 |
-
- indicator_column (str): The column name for the indicator
|
| 341 |
-
- year (str): The year to plot
|
| 342 |
-
- model (str): The climate model to use
|
| 343 |
-
|
| 344 |
-
Returns:
|
| 345 |
-
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
| 346 |
-
"""
|
| 347 |
-
indicator = params["indicator_column"]
|
| 348 |
-
year = params["year"]
|
| 349 |
-
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
| 350 |
-
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
| 351 |
-
|
| 352 |
-
def plot_data(df: pd.DataFrame) -> Figure:
|
| 353 |
-
fig = go.Figure()
|
| 354 |
-
if df['model'].nunique() != 1:
|
| 355 |
-
df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
|
| 356 |
-
indicator
|
| 357 |
-
].mean()
|
| 358 |
-
|
| 359 |
-
indicators = df_avg[indicator].astype(float).tolist()
|
| 360 |
-
latitudes = df_avg["latitude"].astype(float).tolist()
|
| 361 |
-
longitudes = df_avg["longitude"].astype(float).tolist()
|
| 362 |
-
model_label = "Model Average"
|
| 363 |
-
|
| 364 |
-
else:
|
| 365 |
-
df_model = df
|
| 366 |
-
|
| 367 |
-
# Transform to list to avoid pandas encoding
|
| 368 |
-
indicators = df_model[indicator].astype(float).tolist()
|
| 369 |
-
latitudes = df_model["latitude"].astype(float).tolist()
|
| 370 |
-
longitudes = df_model["longitude"].astype(float).tolist()
|
| 371 |
-
model_label = f"Model : {df['model'].unique()[0]}"
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
fig.add_trace(
|
| 375 |
-
go.Scattermapbox(
|
| 376 |
-
lat=latitudes,
|
| 377 |
-
lon=longitudes,
|
| 378 |
-
mode="markers",
|
| 379 |
-
marker=dict(
|
| 380 |
-
size=10,
|
| 381 |
-
color=indicators, # Color mapped to values
|
| 382 |
-
colorscale="Turbo", # Color scale (can be 'Plasma', 'Jet', etc.)
|
| 383 |
-
cmin=min(indicators), # Minimum color range
|
| 384 |
-
cmax=max(indicators), # Maximum color range
|
| 385 |
-
showscale=True, # Show colorbar
|
| 386 |
-
),
|
| 387 |
-
text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
|
| 388 |
-
hoverinfo="text" # Only show the custom text on hover
|
| 389 |
-
)
|
| 390 |
-
)
|
| 391 |
-
|
| 392 |
-
fig.update_layout(
|
| 393 |
-
mapbox_style="open-street-map", # Use OpenStreetMap
|
| 394 |
-
mapbox_zoom=3,
|
| 395 |
-
mapbox_center={"lat": 46.6, "lon": 2.0},
|
| 396 |
-
coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
|
| 397 |
-
title=f"{indicator_label} in {year} in France ({model_label}) " # Title
|
| 398 |
-
)
|
| 399 |
-
return fig
|
| 400 |
-
|
| 401 |
-
return plot_data
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
map_of_france_of_indicator_for_given_year: Plot = {
|
| 405 |
-
"name": "Map of France of an indicator for a given year",
|
| 406 |
-
"description": "Heatmap on the map of France of the values of an in indicator for a given year",
|
| 407 |
-
"params": ["indicator_column", "year", "model"],
|
| 408 |
-
"plot_function": plot_map_of_france_of_indicator_for_given_year,
|
| 409 |
-
"sql_query": indicator_for_given_year_query,
|
| 410 |
-
}
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
PLOTS = [
|
| 414 |
-
indicator_evolution_at_location,
|
| 415 |
-
indicator_number_of_days_per_year_at_location,
|
| 416 |
-
distribution_of_indicator_for_given_year,
|
| 417 |
-
map_of_france_of_indicator_for_given_year,
|
| 418 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/sql_query.py
DELETED
|
@@ -1,114 +0,0 @@
|
|
| 1 |
-
import asyncio
|
| 2 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 3 |
-
from typing import TypedDict
|
| 4 |
-
import duckdb
|
| 5 |
-
import pandas as pd
|
| 6 |
-
|
| 7 |
-
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
|
| 8 |
-
"""Executes a SQL query on the DRIAS database and returns the results.
|
| 9 |
-
|
| 10 |
-
This function connects to the DuckDB database containing DRIAS climate data
|
| 11 |
-
and executes the provided SQL query. It handles the database connection and
|
| 12 |
-
returns the results as a pandas DataFrame.
|
| 13 |
-
|
| 14 |
-
Args:
|
| 15 |
-
sql_query (str): The SQL query to execute
|
| 16 |
-
|
| 17 |
-
Returns:
|
| 18 |
-
pd.DataFrame: A DataFrame containing the query results
|
| 19 |
-
|
| 20 |
-
Raises:
|
| 21 |
-
duckdb.Error: If there is an error executing the SQL query
|
| 22 |
-
"""
|
| 23 |
-
def _execute_query():
|
| 24 |
-
# Execute the query
|
| 25 |
-
con = duckdb.connect()
|
| 26 |
-
results = con.sql(sql_query).fetchdf()
|
| 27 |
-
# return fetched data
|
| 28 |
-
return results
|
| 29 |
-
|
| 30 |
-
# Run the query in a thread pool to avoid blocking
|
| 31 |
-
loop = asyncio.get_event_loop()
|
| 32 |
-
with ThreadPoolExecutor() as executor:
|
| 33 |
-
return await loop.run_in_executor(executor, _execute_query)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
|
| 37 |
-
"""Parameters for querying an indicator's values over time at a location.
|
| 38 |
-
|
| 39 |
-
This class defines the parameters needed to query climate indicator data
|
| 40 |
-
for a specific location over multiple years.
|
| 41 |
-
|
| 42 |
-
Attributes:
|
| 43 |
-
indicator_column (str): The column name for the climate indicator
|
| 44 |
-
latitude (str): The latitude coordinate of the location
|
| 45 |
-
longitude (str): The longitude coordinate of the location
|
| 46 |
-
model (str): The climate model to use (optional)
|
| 47 |
-
"""
|
| 48 |
-
indicator_column: str
|
| 49 |
-
latitude: str
|
| 50 |
-
longitude: str
|
| 51 |
-
model: str
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def indicator_per_year_at_location_query(
|
| 55 |
-
table: str, params: IndicatorPerYearAtLocationQueryParams
|
| 56 |
-
) -> str:
|
| 57 |
-
"""SQL Query to get the evolution of an indicator per year at a certain location
|
| 58 |
-
|
| 59 |
-
Args:
|
| 60 |
-
table (str): sql table of the indicator
|
| 61 |
-
params (IndicatorPerYearAtLocationQueryParams) : dictionary with the required params for the query
|
| 62 |
-
|
| 63 |
-
Returns:
|
| 64 |
-
str: the sql query
|
| 65 |
-
"""
|
| 66 |
-
indicator_column = params.get("indicator_column")
|
| 67 |
-
latitude = params.get("latitude")
|
| 68 |
-
longitude = params.get("longitude")
|
| 69 |
-
|
| 70 |
-
if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
|
| 71 |
-
return ""
|
| 72 |
-
|
| 73 |
-
table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
|
| 74 |
-
|
| 75 |
-
sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
|
| 76 |
-
|
| 77 |
-
return sql_query
|
| 78 |
-
|
| 79 |
-
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
    """Inputs of the "indicator across locations for one year" query.

    Describes what is needed to read the values of a climate indicator at
    every location for a single year. The TypedDict is declared with
    ``total=False``, so no key is mandatory at the type level.

    Attributes:
        indicator_column (str): Name of the column holding the indicator
        year (str): Year to query
        model (str): Climate model to use (optional)
    """

    indicator_column: str
    year: str
    model: str
|
| 93 |
-
|
| 94 |
-
def indicator_for_given_year_query(
    table: str, params: IndicatorForGivenYearQueryParams
) -> str:
    """Build the SQL query giving an indicator's values everywhere for one year.

    The query selects the indicator together with the latitude, longitude and
    model columns.

    Args:
        table (str): Name of the indicator table. It is lower-cased and mapped
            to the matching parquet file of the ``timeki/drias_db`` dataset.
        params (IndicatorForGivenYearQueryParams): Query parameters. The keys
            ``indicator_column`` and ``year`` are read.

    Returns:
        str: The SQL query, or an empty string when ``indicator_column`` or
        ``year`` is missing or blank.
    """
    indicator_column = params.get("indicator_column")
    year = params.get("year")

    # A blank string is as unusable as a missing key: interpolating it would
    # produce malformed SQL ("Where year = ").
    for value in (indicator_column, year):
        if value is None or (isinstance(value, str) and not value.strip()):
            return ""

    table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"

    sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
    return sql_query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/talk_to_drias.py
DELETED
|
@@ -1,317 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
from typing import Any, Callable, TypedDict, Optional
|
| 4 |
-
from numpy import sort
|
| 5 |
-
import pandas as pd
|
| 6 |
-
import asyncio
|
| 7 |
-
from plotly.graph_objects import Figure
|
| 8 |
-
from climateqa.engine.llm import get_llm
|
| 9 |
-
from climateqa.engine.talk_to_data import sql_query
|
| 10 |
-
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
|
| 11 |
-
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
| 12 |
-
from climateqa.engine.talk_to_data.sql_query import execute_sql_query
|
| 13 |
-
from climateqa.engine.talk_to_data.utils import (
|
| 14 |
-
detect_relevant_plots,
|
| 15 |
-
detect_year_with_openai,
|
| 16 |
-
loc2coords,
|
| 17 |
-
detect_location_with_openai,
|
| 18 |
-
nearestNeighbourSQL,
|
| 19 |
-
detect_relevant_tables,
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
| 24 |
-
|
| 25 |
-
class TableState(TypedDict):
    """Per-table record kept while the DRIAS workflow runs.

    Groups everything produced for one table: its name, the parameters used
    to query it, the SQL query, the resulting data and the figure builder.

    Attributes:
        table_name (str): Name of the table in the database
        params (dict[str, Any]): Parameters used to query the table
        sql_query (str, optional): SQL query used to fetch the data
        dataframe (pd.DataFrame | None, optional): Data returned by the query
        figure (Callable[..., Figure], optional): Function building the figure
        status (str): Processing status of the table ('OK' or 'ERROR')
    """

    table_name: str
    params: dict[str, Any]
    sql_query: Optional[str]
    dataframe: Optional[pd.DataFrame | None]
    figure: Optional[Callable[..., Figure]]
    status: str
|
| 45 |
-
|
| 46 |
-
class PlotState(TypedDict):
    """Per-plot record kept while the DRIAS workflow runs.

    Holds the name of a plot, the tables selected for it and the state of each
    of those tables.

    Attributes:
        plot_name (str): Name of the plot
        tables (list[str]): Tables used for the plot
        table_states (dict[str, TableState]): State of each table used for the plot
    """

    plot_name: str
    tables: list[str]
    table_states: dict[str, TableState]
|
| 60 |
-
|
| 61 |
-
class State(TypedDict):
    """Overall state of a Talk To Drias workflow run.

    Attributes:
        user_input (str): Question asked by the user
        plots (list[str]): Names of the plots selected for the question
        plot_states (dict[str, PlotState]): State of each processed plot
        error (str, optional): Error message describing why no result was produced
    """

    user_input: str
    plots: list[str]
    plot_states: dict[str, PlotState]
    error: Optional[str]
|
| 66 |
-
|
| 67 |
-
async def find_relevant_plots(state: State, llm) -> list[str]:
    """Ask the LLM which plots can answer the user's question.

    Args:
        state (State): Workflow state; only ``user_input`` is read
        llm: Language model forwarded to ``detect_relevant_plots``

    Returns:
        list[str]: The value returned by ``detect_relevant_plots``
    """
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm)
|
| 71 |
-
|
| 72 |
-
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
    """Ask the LLM which tables are relevant for one plot.

    Args:
        state (State): Workflow state; only ``user_input`` is read
        plot (Plot): Plot configuration; its ``name`` is used in the log line
        llm: Language model forwarded to ``detect_relevant_tables``

    Returns:
        list[str]: The value returned by ``detect_relevant_tables``
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm)
|
| 76 |
-
|
| 77 |
-
async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
    """Dispatch to the extraction routine matching the requested parameter.

    Args:
        state (State): state of the workflow; ``user_input`` is read
        param_name (str): name of the desired parameter ('location' or 'year')
        table (str): name of the table, used when looking up a location

    Returns:
        dict[str, Any] | None: the result of ``find_location`` for 'location',
        ``{'year': ...}`` for 'year', ``None`` for any other parameter name
    """
    if param_name == 'location':
        return await find_location(state['user_input'], table)
    elif param_name == 'year':
        return {'year': await find_year(state['user_input'])}
    else:
        return None
|
| 95 |
-
|
| 96 |
-
class Location(TypedDict):
    """Location extracted from a user question and its coordinates.

    Attributes:
        location (str): Name of the detected location
        latitude (str, optional): Latitude associated with the location
        longitude (str, optional): Longitude associated with the location
    """

    location: str
    latitude: Optional[str]
    longitude: Optional[str]
|
| 100 |
-
|
| 101 |
-
async def find_location(user_input: str, table: str) -> Location:
    """Detect a location in the user's question and snap it to the table's grid.

    Args:
        user_input (str): The user's query text
        table (str): Name of the table searched for a nearby grid point

    Returns:
        Location: Always contains ``location``. ``latitude`` and ``longitude``
        are added only when a location was detected.
    """
    print(f"---- Find location in table {table} ----")
    place = await detect_location_with_openai(user_input)
    result: Location = {'location' : place}
    if not place:
        return result

    # Geocode the place name, then look up a grid point of the table near it.
    nearest = nearestNeighbourSQL(loc2coords(place), table)
    result.update({
        "latitude": nearest[0],
        "longitude": nearest[1],
    })
    return result
|
| 113 |
-
|
| 114 |
-
async def find_year(user_input: str) -> str:
    """Extract the year mentioned in the user's question with the LLM.

    The result is used to filter data in the queries built afterwards.

    Args:
        user_input (str): The user's query text

    Returns:
        str: The value returned by ``detect_year_with_openai`` (an empty
        string when no year is found)
    """
    print("---- Find year ---")
    return await detect_year_with_openai(user_input)
|
| 129 |
-
|
| 130 |
-
def find_indicator_column(table: str) -> str:
    """Look up the indicator column associated with a table.

    The lookup goes through the predefined ``INDICATOR_COLUMNS_PER_TABLE``
    mapping.

    Args:
        table (str): Name of the table in the database

    Returns:
        str: Name of the indicator column of that table

    Raises:
        KeyError: If the table name is not found in the mapping
    """
    print(f"---- Find indicator column in table {table} ----")
    indicator_column = INDICATOR_COLUMNS_PER_TABLE[table]
    return indicator_column
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Build the query for one table, run it and attach the figure builder.

    Args:
        table (str): Name of the table to process
        params (dict[str, Any]): Query parameters; copied, never mutated
        plot (Plot): Plot configuration providing ``sql_query`` and ``plot_function``

    Returns:
        TableState: State of the table. ``status`` is 'ERROR' (and no query is
        run) when the plot's query builder returns an empty string.
    """
    table_state: TableState = {
        'table_name': table,
        'params': params.copy(),
        'status': 'OK',
        'dataframe': None,
        'sql_query': None,
        'figure': None
    }
    query_params = table_state['params']
    query_params['indicator_column'] = find_indicator_column(table)

    built_query = plot['sql_query'](table, query_params)
    if built_query != "":
        table_state['sql_query'] = built_query
        table_state['dataframe'] = await execute_sql_query(built_query)
        table_state['figure'] = plot['plot_function'](query_params)
    else:
        # An empty query means a required parameter was missing.
        table_state['status'] = 'ERROR'

    return table_state
|
| 189 |
-
|
| 190 |
-
async def drias_workflow(user_input: str) -> State:
    """Run the complete Talk To Drias workflow for one user question.

    Steps: select the relevant plots, then for each plot select the relevant
    tables, extract the plot's parameters from the question, and process up to
    three tables concurrently (SQL query, dataframe, figure).

    Args:
        user_input (str): initial user input

    Returns:
        State: Final state with all the results. ``error`` is set when no plot,
        no relevant table, no SQL query or no data could be produced for the
        question as a whole.
    """
    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    state['plots'] = await find_relevant_plots(state, llm)

    if len(state['plots']) < 1:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # These flags describe the whole run: they are only ever switched on, so a
    # successful plot is not masked by a later plot that produced nothing.
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:

        plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
        if plot is None:
            continue

        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': relevant_tables,
            'table_states': {}
        }
        state['plot_states'][plot_name] = plot_state

        # Without any table there is nothing to query; skipping also avoids
        # indexing relevant_tables[0] on an empty list below.
        if len(relevant_tables) == 0:
            continue
        have_relevant_table = True

        # Parameters are detected once per plot, using the first relevant table.
        params = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state
        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
|
| 270 |
-
|
| 271 |
-
# def make_write_query_node():
|
| 272 |
-
|
| 273 |
-
# def write_query(state):
|
| 274 |
-
# print("---- Write query ----")
|
| 275 |
-
# for table in state["tables"]:
|
| 276 |
-
# sql_query = QUERIES[state[table]['query_type']](
|
| 277 |
-
# table=table,
|
| 278 |
-
# indicator_column=state[table]["columns"],
|
| 279 |
-
# longitude=state[table]["longitude"],
|
| 280 |
-
# latitude=state[table]["latitude"],
|
| 281 |
-
# )
|
| 282 |
-
# state[table].update({"sql_query": sql_query})
|
| 283 |
-
|
| 284 |
-
# return state
|
| 285 |
-
|
| 286 |
-
# return write_query
|
| 287 |
-
|
| 288 |
-
# def make_fetch_data_node(db_path):
|
| 289 |
-
|
| 290 |
-
# def fetch_data(state):
|
| 291 |
-
# print("---- Fetch data ----")
|
| 292 |
-
# for table in state["tables"]:
|
| 293 |
-
# results = execute_sql_query(db_path, state[table]['sql_query'])
|
| 294 |
-
# state[table].update(results)
|
| 295 |
-
|
| 296 |
-
# return state
|
| 297 |
-
|
| 298 |
-
# return fetch_data
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
## V2
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
# def make_fetch_data_node(db_path: str, llm):
|
| 306 |
-
# def fetch_data(state):
|
| 307 |
-
# print("---- Fetch data ----")
|
| 308 |
-
# db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
|
| 309 |
-
# output = {}
|
| 310 |
-
# sql_query = write_sql_query(state["query"], db, state["tables"], llm)
|
| 311 |
-
# # TO DO : Add query checker
|
| 312 |
-
# print(f"SQL query : {sql_query}")
|
| 313 |
-
# output["sql_query"] = sql_query
|
| 314 |
-
# output.update(fetch_data_from_sql_query(db_path, sql_query))
|
| 315 |
-
# return output
|
| 316 |
-
|
| 317 |
-
# return fetch_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/utils.py
DELETED
|
@@ -1,281 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
from typing import Annotated, TypedDict
|
| 3 |
-
import duckdb
|
| 4 |
-
from geopy.geocoders import Nominatim
|
| 5 |
-
import ast
|
| 6 |
-
from climateqa.engine.llm import get_llm
|
| 7 |
-
from climateqa.engine.talk_to_data.config import DRIAS_TABLES
|
| 8 |
-
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
| 9 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
async def detect_location_with_openai(sentence):
    """Detect locations in a sentence using OpenAI's API via LangChain.

    Args:
        sentence: Text in which to look for locations.

    Returns:
        The first location of the list returned by the LLM, or an empty string
        when the list is empty or the reply cannot be parsed as a Python list.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)

    # Parse only the bracketed list. str.strip("```python\n") would strip a
    # *set of characters*, not a code-fence prefix, so the list is located
    # explicitly and any surrounding fence or prose is ignored.
    reply = response.content
    start, end = reply.find("["), reply.rfind("]")
    if start == -1 or end < start:
        return ""
    try:
        location_list = ast.literal_eval(reply[start:end + 1])
    except (ValueError, SyntaxError):
        # Malformed reply: report "no location" instead of crashing the workflow.
        return ""

    if location_list:
        return location_list[0]
    return ""
|
| 31 |
-
|
| 32 |
-
class ArrayOutput(TypedDict):
    """Structured output schema for answers that consist of an array.

    Used as the schema of structured LLM calls so that the answer comes back
    in a single, predictable field.

    Attributes:
        array (str): A syntactically valid Python array string
    """

    array: Annotated[str, "Syntactically valid python array."]
|
| 42 |
-
|
| 43 |
-
async def detect_year_with_openai(sentence: str) -> str:
    """Detect years in a sentence using OpenAI's API via LangChain.

    Args:
        sentence (str): Text in which to look for years.

    Returns:
        str: The first year of the list returned by the LLM, or an empty string
        when the list is empty or the reply is not a valid Python list literal.
        NOTE(review): the element is returned as parsed, so it may be an int
        rather than a str depending on the LLM reply; confirm with callers.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})

    # The array string is produced by the LLM from user text: parse it as a
    # literal only. eval() would execute whatever expression it contains.
    try:
        years_list = ast.literal_eval(response['array'])
    except (ValueError, SyntaxError):
        return ""

    if isinstance(years_list, (list, tuple)) and len(years_list) > 0:
        return years_list[0]
    return ""
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def detectTable(sql_query: str) -> list[str]:
    """List the table names referenced after FROM in a SQL query.

    A regular expression looks for every ``FROM`` keyword (case-insensitive)
    and captures the identifier that follows. The identifier may be bare,
    wrapped in backticks, double quotes or single quotes, and may be
    dot-qualified.

    Args:
        sql_query (str): The SQL query to analyze

    Returns:
        list[str]: The table names found, in order of appearance

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    from_clause = re.compile(
        r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    )
    return [found.group(1) for found in from_clause.finditer(sql_query)]
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def loc2coords(location: str) -> tuple[float, float]:
    """Geocode a location name into coordinates.

    The lookup is delegated to the Nominatim geocoding service.

    Args:
        location (str): The name of the location to geocode (e.g. a city name)

    Returns:
        tuple[float, float]: A tuple containing (latitude, longitude)

    Raises:
        AttributeError: If the location cannot be found
    """
    match = Nominatim(user_agent="city_to_latlong").geocode(location)
    latitude, longitude = match.latitude, match.longitude
    return (latitude, longitude)
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
def coords2loc(coords: tuple[float, float]) -> str:
    """Reverse-geocode coordinates into a human-readable address.

    The lookup is delegated to the Nominatim reverse geocoding service. Any
    failure is printed and replaced by a placeholder.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        str: The address of the location, or "Unknown Location" if not found

    Example:
        >>> coords2loc((48.8566, 2.3522))
        'Paris, France'
    """
    service = Nominatim(user_agent="coords_to_city")
    try:
        return service.reverse(coords).address
    except Exception as e:
        print(f"Error: {e}")
        return "Unknown Location"
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
    """Find the grid point of a table closest to a location.

    Candidates are restricted to a box of +/- 0.3 degrees around the location
    (rounded to 3 decimals), and the closest one is kept.

    Args:
        location (tuple): (latitude, longitude) of the point to match
        table (str): Name of the table. It is lower-cased and mapped to the
            matching parquet file of the ``timeki/drias_db`` dataset.

    Returns:
        tuple[str, str]: (latitude, longitude) of the closest grid point as
        read from the table, or ("", "") when the box contains no point.
    """
    lon = round(location[1], 3)
    lat = round(location[0], 3)

    table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"

    # Order by squared distance and keep one row: this returns the nearest
    # point rather than whichever row comes first, and avoids fetching every
    # row of the box. Distance is measured in degrees, which is assumed to be
    # an adequate approximation inside such a small box.
    results = duckdb.sql(
        f"SELECT latitude, longitude FROM {table} "
        f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
        f"AND longitude BETWEEN {lon - 0.3} AND {lon + 0.3} "
        f"ORDER BY POWER(latitude - ({lat}), 2) + POWER(longitude - ({lon}), 2) "
        f"LIMIT 1"
    ).fetchdf()

    if len(results) == 0:
        return "", ""
    return results['latitude'].iloc[0], results['longitude'].iloc[0]
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
    """Identify the tables relevant for a plot based on the user's question.

    The LLM is given the plot description, the list of DRIAS tables and the
    user's question, and is asked for the three most relevant table names.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object; its ``description`` is used
        llm: The language model instance to use for analysis

    Returns:
        list[str]: The table names returned by the LLM, or an empty list when
        the reply cannot be parsed as a Python list

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    # Get all table names
    table_names_list = DRIAS_TABLES

    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    # Parse only the bracketed list. str.strip("```python\n") would strip a
    # *set of characters*, not a code-fence prefix, so the list is located
    # explicitly and any surrounding fence or prose is ignored.
    reply = (await llm.ainvoke(prompt)).content
    start, end = reply.find("["), reply.rfind("]")
    if start == -1 or end < start:
        return []
    try:
        table_names = ast.literal_eval(reply[start:end + 1])
    except (ValueError, SyntaxError):
        # Malformed reply: report "no table" instead of crashing the workflow.
        return []
    return table_names
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def replace_coordonates(coords, query, coords_tables):
    """Substitute each occurrence of a coordinate pair in a query.

    The number of substitutions equals the number of times the first
    coordinate appears in ``query``. The i-th substitution replaces the first
    remaining occurrence of each coordinate by the i-th pair of
    ``coords_tables``.

    Args:
        coords: Pair (first, second) of coordinates currently in the query
        query: Query text to rewrite
        coords_tables: Sequence of replacement pairs, one per occurrence

    Returns:
        The rewritten query
    """
    first_token = str(coords[0])
    occurrences = query.count(first_token)

    rewritten = query
    for position in range(occurrences):
        replacement = coords_tables[position]
        rewritten = rewritten.replace(first_token, str(replacement[0]), 1)
        rewritten = rewritten.replace(str(coords[1]), str(replacement[1]), 1)
    return rewritten
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
async def detect_relevant_plots(user_question: str, llm):
    """Identify the plots that can answer the user's question.

    The LLM is given the name and description of every plot in ``PLOTS`` and
    the user's question, and is asked for a Python list of plot names.

    Args:
        user_question (str): The user's question
        llm: The language model instance to use for analysis

    Returns:
        The list of plot names returned by the LLM, or an empty list when the
        reply cannot be parsed as a Python list
    """
    plots_description = "".join(
        "Name: " + plot["name"] + " - Description: " + plot["description"] + "\n"
        for plot in PLOTS
    )

    # NOTE(review): the prompt says "relevant tables" where "relevant plots"
    # looks intended; left untouched because it changes what the LLM is asked.
    prompt = (
        f"You are helping to answer a quesiton with insightful visualizations."
        f"You are given an user question and a list of plots with their name and description."
        f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
        f"Write the most relevant tables to use. Answer only a python list of plot name."
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}"
        f"### Name of the plot : "
    )

    # Parse only the bracketed list. str.strip("```python\n") would strip a
    # *set of characters*, not a code-fence prefix, so the list is located
    # explicitly and any surrounding fence or prose is ignored.
    reply = (await llm.ainvoke(prompt)).content
    start, end = reply.find("["), reply.rfind("]")
    if start == -1 or end < start:
        return []
    try:
        plot_names = ast.literal_eval(reply[start:end + 1])
    except (ValueError, SyntaxError):
        # Malformed reply: report "no plot" instead of crashing the workflow.
        return []
    return plot_names
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
# Next Version
|
| 232 |
-
# class QueryOutput(TypedDict):
|
| 233 |
-
# """Generated SQL query."""
|
| 234 |
-
|
| 235 |
-
# query: Annotated[str, ..., "Syntactically valid SQL query."]
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
# class PlotlyCodeOutput(TypedDict):
|
| 239 |
-
# """Generated Plotly code"""
|
| 240 |
-
|
| 241 |
-
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
| 242 |
-
# def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
|
| 243 |
-
# """Generate SQL query to fetch information."""
|
| 244 |
-
# prompt_params = {
|
| 245 |
-
# "dialect": db.dialect,
|
| 246 |
-
# "table_info": db.get_table_info(),
|
| 247 |
-
# "input": user_input,
|
| 248 |
-
# "relevant_tables": relevant_tables,
|
| 249 |
-
# "model": "ALADIN63_CNRM-CM5",
|
| 250 |
-
# }
|
| 251 |
-
|
| 252 |
-
# prompt = ChatPromptTemplate.from_template(query_prompt_template)
|
| 253 |
-
# structured_llm = llm.with_structured_output(QueryOutput)
|
| 254 |
-
# chain = prompt | structured_llm
|
| 255 |
-
# result = chain.invoke(prompt_params)
|
| 256 |
-
|
| 257 |
-
# return result["query"]
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
# def fetch_data_from_sql_query(db: str, sql_query: str):
|
| 261 |
-
# conn = sqlite3.connect(db)
|
| 262 |
-
# cursor = conn.cursor()
|
| 263 |
-
# cursor.execute(sql_query)
|
| 264 |
-
# column_names = [desc[0] for desc in cursor.description]
|
| 265 |
-
# values = cursor.fetchall()
|
| 266 |
-
# return {"column_names": column_names, "data": values}
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
# def generate_chart_code(user_input: str, sql_query: list[str], llm):
|
| 270 |
-
# """ "Generate plotly python code for the chart based on the sql query and the user question"""
|
| 271 |
-
|
| 272 |
-
# class PlotlyCodeOutput(TypedDict):
|
| 273 |
-
# """Generated Plotly code"""
|
| 274 |
-
|
| 275 |
-
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
| 276 |
-
|
| 277 |
-
# prompt = ChatPromptTemplate.from_template(plot_prompt_template)
|
| 278 |
-
# structured_llm = llm.with_structured_output(PlotlyCodeOutput)
|
| 279 |
-
# chain = prompt | structured_llm
|
| 280 |
-
# result = chain.invoke({"input": user_input, "sql_query": sql_query})
|
| 281 |
-
# return result["code"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/vanna_class.py
DELETED
|
@@ -1,325 +0,0 @@
|
|
| 1 |
-
from vanna.base import VannaBase
|
| 2 |
-
from pinecone import Pinecone
|
| 3 |
-
from climateqa.engine.embeddings import get_embeddings_function
|
| 4 |
-
import pandas as pd
|
| 5 |
-
import hashlib
|
| 6 |
-
|
| 7 |
-
class MyCustomVectorDB(VannaBase):
    """
    VectorDB class for storing and retrieving vectors from Pinecone.

    Training data is kept in three Pinecone namespaces: ``ddl``,
    ``documentation`` and ``question_sql``. Record ids are the SHA-256 hex
    digest of the source text followed by ``_ddl``, ``_doc`` or ``_sql``.

    args :
        config (dict) : Configuration dictionary containing the Pinecone API key and the index name :
            - pc_api_key (str) : Pinecone API key
            - index_name (str) : Pinecone index name
            - top_k (int) : Number of top results to return (default = 2)

    """

    # Id suffix written by the ``add_*`` methods -> namespace the record is
    # stored in. Used by ``remove_training_data`` to delete from the right place.
    _NAMESPACE_BY_SUFFIX = {
        "_ddl": "ddl",
        "_doc": "documentation",
        "_sql": "question_sql",
    }

    def __init__(self, config):
        super().__init__(config=config)
        try:
            self.api_key = config.get('pc_api_key')
            self.index_name = config.get('index_name')
        except (AttributeError, TypeError) as err:
            # Raised when ``config`` is None or not dict-like.
            raise Exception("Please provide the Pinecone API key and the index name") from err

        self.pc = Pinecone(api_key=self.api_key)
        self.index = self.pc.Index(self.index_name)
        self.top_k = config.get('top_k', 2)
        self.embeddings = get_embeddings_function()

    def check_embedding(self, id, namespace):
        """Return True when a vector with this id already exists in *namespace*."""
        fetched = self.index.fetch(ids=[id], namespace=namespace)
        return bool(fetched['vectors'])

    def generate_hash_id(self, data: str) -> str:
        """
        Generate a unique hash ID for the given data.

        Args:
            data (str): The input data to hash (e.g., a concatenated string of user attributes).

        Returns:
            str: A unique hash ID as a hexadecimal string.
        """
        return hashlib.sha256(data.encode('utf-8')).hexdigest()

    def _store_if_missing(self, record_id, namespace, label, text, metadata):
        """Upsert one record unless its id is already present; return the id."""
        if self.check_embedding(record_id, namespace):
            print(f"{label} having id {record_id} already exists")
            return record_id

        # The embedding is only computed when the record really has to be written.
        self.index.upsert(
            vectors=[(record_id, self.embeddings.embed_query(text), metadata)],
            namespace=namespace,
        )
        return record_id

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Store a DDL statement in the ``ddl`` namespace and return its id."""
        record_id = self.generate_hash_id(ddl) + '_ddl'
        return self._store_if_missing(record_id, 'ddl', "DDL", ddl, {'ddl': ddl})

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Store a documentation string in the ``documentation`` namespace and return its id."""
        record_id = self.generate_hash_id(doc) + '_doc'
        return self._store_if_missing(
            record_id, 'documentation', "Documentation", doc, {'doc': doc}
        )

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Store a question/SQL pair in the ``question_sql`` namespace and return its id.

        The id is derived from the question only; the embedding is computed
        from the question concatenated with the SQL.
        """
        record_id = self.generate_hash_id(question) + '_sql'
        return self._store_if_missing(
            record_id,
            'question_sql',
            "Question-SQL pair",
            question + sql,
            {'question': question, 'sql': sql},
        )

    def _query_namespace(self, question, namespace):
        """Return the ``top_k`` matches for *question* in *namespace*."""
        res = self.index.query(
            vector=self.embeddings.embed_query(question),
            top_k=self.top_k,
            namespace=namespace,
            include_metadata=True,
        )
        return res['matches']

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL statements closest to *question*."""
        return [match['metadata']['ddl'] for match in self._query_namespace(question, 'ddl')]

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation strings closest to *question*."""
        return [
            match['metadata']['doc']
            for match in self._query_namespace(question, 'documentation')
        ]

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return ``(question, sql)`` tuples for the stored pairs closest to *question*."""
        return [
            (match['metadata']['question'], match['metadata']['sql'])
            for match in self._query_namespace(question, 'question_sql')
        ]

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return the metadata of the stored records of all three namespaces as a DataFrame."""
        records = []

        for namespace in ('ddl', 'documentation', 'question_sql'):
            # NOTE(review): this query passes neither a vector nor an id;
            # confirm the Pinecone client accepts such a query.
            data = self.index.query(
                top_k=10000,
                namespace=namespace,
                include_metadata=True,
                include_values=False,
            )
            records.extend(match['metadata'] for match in data['matches'])

        return pd.DataFrame(records)

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """
        Delete one training record.

        The namespace is derived from the id suffix written by the ``add_*``
        methods.

        Returns:
            bool: True when the id carried a known suffix and a delete was
            issued, False otherwise.
        """
        for suffix, namespace in self._NAMESPACE_BY_SUFFIX.items():
            if id.endswith(suffix):
                self.index.delete(ids=[id], namespace=namespace)
                return True

        return False

    def generate_embedding(self, text, **kwargs):
        """Return the embedding of *text* computed by the configured embeddings function."""
        return self.embeddings.embed_query(text)

    def get_sql_prompt(
        self,
        initial_prompt: str,
        question: str,
        question_sql_list: list,
        ddl_list: list,
        doc_list: list,
        **kwargs,
    ):
        """
        Example:
        ```python
        vn.get_sql_prompt(
            question="What are the top 10 customers by sales?",
            question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
            ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
            doc_list=["The customers table contains information about customers and their sales."],
        )

        ```

        This method is used to generate a prompt for the LLM to generate SQL.

        Args:
            initial_prompt (str): Opening of the system prompt; a default is built when None.
            question (str): The question to generate SQL for.
            question_sql_list (list): A list of questions and their corresponding SQL statements,
                either as dicts with "question"/"sql" keys or as (question, sql) pairs.
            ddl_list (list): A list of DDL statements.
            doc_list (list): A list of documentation.

        Returns:
            any: The prompt for the LLM to generate SQL.
        """

        if initial_prompt is None:
            initial_prompt = f"You are a {self.dialect} expert. " + \
                "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

        initial_prompt = self.add_ddl_to_prompt(
            initial_prompt, ddl_list, max_tokens=self.max_tokens
        )

        if self.static_documentation != "":
            # Build a new list so the caller's doc_list is not mutated.
            doc_list = [*doc_list, self.static_documentation]

        initial_prompt = self.add_documentation_to_prompt(
            initial_prompt, doc_list, max_tokens=self.max_tokens
        )

        # The question/SQL examples are not added to the system prompt; they
        # are appended below as user/assistant message pairs.
        initial_prompt += (
            "===Response Guidelines \n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
            "3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
            "4. Please use the most relevant table(s). \n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
            f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
            "7. Add a description of the table in the result of the sql query, if relevant. \n"
            "8 Make sure to include the relevant KPI in the SQL query. The query should return impactfull data \n"
        )

        message_log = [self.system_message(initial_prompt)]

        for example in question_sql_list:
            if example is None:
                print("example is None")
                continue

            # get_similar_question_sql returns (question, sql) tuples; convert
            # them so they are not silently skipped by the key check below.
            if isinstance(example, (tuple, list)) and len(example) == 2:
                example = {"question": example[0], "sql": example[1]}

            if "question" in example and "sql" in example:
                message_log.append(self.user_message(example["question"]))
                message_log.append(self.assistant_message(example["sql"]))

        message_log.append(self.user_message(question))

        return message_log
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
# def get_sql_prompt(
|
| 256 |
-
# self,
|
| 257 |
-
# initial_prompt : str,
|
| 258 |
-
# question: str,
|
| 259 |
-
# question_sql_list: list,
|
| 260 |
-
# ddl_list: list,
|
| 261 |
-
# doc_list: list,
|
| 262 |
-
# **kwargs,
|
| 263 |
-
# ):
|
| 264 |
-
# """
|
| 265 |
-
# Example:
|
| 266 |
-
# ```python
|
| 267 |
-
# vn.get_sql_prompt(
|
| 268 |
-
# question="What are the top 10 customers by sales?",
|
| 269 |
-
# question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
|
| 270 |
-
# ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
|
| 271 |
-
# doc_list=["The customers table contains information about customers and their sales."],
|
| 272 |
-
# )
|
| 273 |
-
|
| 274 |
-
# ```
|
| 275 |
-
|
| 276 |
-
# This method is used to generate a prompt for the LLM to generate SQL.
|
| 277 |
-
|
| 278 |
-
# Args:
|
| 279 |
-
# question (str): The question to generate SQL for.
|
| 280 |
-
# question_sql_list (list): A list of questions and their corresponding SQL statements.
|
| 281 |
-
# ddl_list (list): A list of DDL statements.
|
| 282 |
-
# doc_list (list): A list of documentation.
|
| 283 |
-
|
| 284 |
-
# Returns:
|
| 285 |
-
# any: The prompt for the LLM to generate SQL.
|
| 286 |
-
# """
|
| 287 |
-
|
| 288 |
-
# if initial_prompt is None:
|
| 289 |
-
# initial_prompt = f"You are a {self.dialect} expert. " + \
|
| 290 |
-
# "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
| 291 |
-
|
| 292 |
-
# initial_prompt = self.add_ddl_to_prompt(
|
| 293 |
-
# initial_prompt, ddl_list, max_tokens=self.max_tokens
|
| 294 |
-
# )
|
| 295 |
-
|
| 296 |
-
# if self.static_documentation != "":
|
| 297 |
-
# doc_list.append(self.static_documentation)
|
| 298 |
-
|
| 299 |
-
# initial_prompt = self.add_documentation_to_prompt(
|
| 300 |
-
# initial_prompt, doc_list, max_tokens=self.max_tokens
|
| 301 |
-
# )
|
| 302 |
-
|
| 303 |
-
# initial_prompt += (
|
| 304 |
-
# "===Response Guidelines \n"
|
| 305 |
-
# "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
| 306 |
-
# "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
|
| 307 |
-
# "3. If the provided context is insufficient, please explain why it can't be generated. \n"
|
| 308 |
-
# "4. Please use the most relevant table(s). \n"
|
| 309 |
-
# "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
| 310 |
-
# f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
| 311 |
-
# )
|
| 312 |
-
|
| 313 |
-
# message_log = [self.system_message(initial_prompt)]
|
| 314 |
-
|
| 315 |
-
# for example in question_sql_list:
|
| 316 |
-
# if example is None:
|
| 317 |
-
# print("example is None")
|
| 318 |
-
# else:
|
| 319 |
-
# if example is not None and "question" in example and "sql" in example:
|
| 320 |
-
# message_log.append(self.user_message(example["question"]))
|
| 321 |
-
# message_log.append(self.assistant_message(example["sql"]))
|
| 322 |
-
|
| 323 |
-
# message_log.append(self.user_message(question))
|
| 324 |
-
|
| 325 |
-
# return message_log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
climateqa/engine/talk_to_data/workflow/drias.py
CHANGED
|
@@ -125,11 +125,16 @@ async def drias_workflow(user_input: str) -> State:
|
|
| 125 |
'plot': plot,
|
| 126 |
'status': 'OK'
|
| 127 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
-
# Gather all required parameters
|
| 130 |
params = {}
|
| 131 |
-
for
|
| 132 |
-
param = await find_param(state, param_name, mode='DRIAS')
|
| 133 |
if param:
|
| 134 |
params.update(param)
|
| 135 |
|
|
|
|
| 125 |
'plot': plot,
|
| 126 |
'status': 'OK'
|
| 127 |
}
|
| 128 |
+
|
| 129 |
+
# Gather all required parameters in parallel
|
| 130 |
+
param_tasks = [
|
| 131 |
+
find_param(state, param_name, mode='DRIAS')
|
| 132 |
+
for param_name in DRIAS_PLOT_PARAMETERS
|
| 133 |
+
]
|
| 134 |
+
param_results = await asyncio.gather(*param_tasks)
|
| 135 |
|
|
|
|
| 136 |
params = {}
|
| 137 |
+
for param in param_results:
|
|
|
|
| 138 |
if param:
|
| 139 |
params.update(param)
|
| 140 |
|
climateqa/engine/talk_to_data/workflow/ipcc.py
CHANGED
|
@@ -125,12 +125,17 @@ async def ipcc_workflow(user_input: str) -> State:
|
|
| 125 |
}
|
| 126 |
|
| 127 |
# Gather all required parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
params = {}
|
| 129 |
-
for
|
| 130 |
-
param = await find_param(state, param_name, mode='IPCC')
|
| 131 |
if param:
|
| 132 |
params.update(param)
|
| 133 |
-
|
| 134 |
# Process all outputs in parallel using process_output
|
| 135 |
tasks = [
|
| 136 |
process_output(output_title, output['table'], output['plot'], params.copy())
|
|
@@ -152,10 +157,18 @@ async def ipcc_workflow(user_input: str) -> State:
|
|
| 152 |
|
| 153 |
# Set error messages if needed
|
| 154 |
if not errors['have_relevant_table']:
|
| 155 |
-
state['error'] =
|
|
|
|
|
|
|
|
|
|
| 156 |
elif not errors['have_sql_query']:
|
| 157 |
-
state['error'] =
|
|
|
|
|
|
|
|
|
|
| 158 |
elif not errors['have_dataframe']:
|
| 159 |
-
state['error'] =
|
| 160 |
-
|
|
|
|
|
|
|
| 161 |
return state
|
|
|
|
| 125 |
}
|
| 126 |
|
| 127 |
# Gather all required parameters
|
| 128 |
+
param_tasks = [
|
| 129 |
+
find_param(state, param_name, mode='IPCC')
|
| 130 |
+
for param_name in IPCC_PLOT_PARAMETERS
|
| 131 |
+
]
|
| 132 |
+
param_results = await asyncio.gather(*param_tasks)
|
| 133 |
+
|
| 134 |
params = {}
|
| 135 |
+
for param in param_results:
|
|
|
|
| 136 |
if param:
|
| 137 |
params.update(param)
|
| 138 |
+
|
| 139 |
# Process all outputs in parallel using process_output
|
| 140 |
tasks = [
|
| 141 |
process_output(output_title, output['table'], output['plot'], params.copy())
|
|
|
|
| 157 |
|
| 158 |
# Set error messages if needed
|
| 159 |
if not errors['have_relevant_table']:
|
| 160 |
+
state['error'] = (
|
| 161 |
+
"Sorry, I couldn't find any relevant table in our database to answer your question.\n"
|
| 162 |
+
"Try asking about a different climate indicator like temperature or precipitation."
|
| 163 |
+
)
|
| 164 |
elif not errors['have_sql_query']:
|
| 165 |
+
state['error'] = (
|
| 166 |
+
"Sorry, I couldn't generate a relevant SQL query to answer your question.\n"
|
| 167 |
+
"Try rephrasing your question to focus on a specific location, a year, or a month."
|
| 168 |
+
)
|
| 169 |
elif not errors['have_dataframe']:
|
| 170 |
+
state['error'] = (
|
| 171 |
+
"Sorry, there is no data in our tables that can answer your question.\n"
|
| 172 |
+
"Try asking about a more common location, or a different year."
|
| 173 |
+
)
|
| 174 |
return state
|
climateqa/engine/vectorstore.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
-
#
|
| 2 |
-
# More info at https://docs.pinecone.io/docs/langchain
|
| 3 |
-
# And https://python.langchain.com/docs/integrations/vectorstores/pinecone
|
| 4 |
import os
|
| 5 |
-
from pinecone import Pinecone
|
| 6 |
-
from langchain_community.vectorstores import Pinecone as PineconeVectorstore
|
| 7 |
|
| 8 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
try:
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
load_dotenv()
|
|
@@ -13,44 +13,136 @@ except:
|
|
| 13 |
pass
|
| 14 |
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
# def get_pinecone_retriever(vectorstore,k = 10,namespace = "vectors",sources = ["IPBES","IPCC"]):
|
| 42 |
-
|
| 43 |
-
# assert isinstance(sources,list)
|
| 44 |
-
|
| 45 |
-
# # Check if all elements in the list are either IPCC or IPBES
|
| 46 |
-
# filter = {
|
| 47 |
-
# "source": { "$in":sources},
|
| 48 |
-
# }
|
| 49 |
-
|
| 50 |
-
# retriever = vectorstore.as_retriever(search_kwargs={
|
| 51 |
-
# "k": k,
|
| 52 |
-
# "namespace":"vectors",
|
| 53 |
-
# "filter":filter
|
| 54 |
-
# })
|
| 55 |
|
| 56 |
-
# return retriever
|
|
|
|
| 1 |
+
# Azure AI Search: https://python.langchain.com/docs/integrations/vectorstores/azuresearch
|
|
|
|
|
|
|
| 2 |
import os
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
# Azure AI Search imports
|
| 5 |
+
from langchain_community.vectorstores.azuresearch import AzureSearch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Load environment variables
|
| 9 |
try:
|
| 10 |
from dotenv import load_dotenv
|
| 11 |
load_dotenv()
|
|
|
|
| 13 |
pass
|
| 14 |
|
| 15 |
|
| 16 |
+
class AzureSearchWrapper:
    """
    Wrapper class for Azure AI Search vectorstore to handle filter conversion.

    This wrapper automatically converts dictionary-style filters to Azure Search OData filter format,
    ensuring seamless compatibility when switching from other providers.
    Any attribute not defined here is delegated to the wrapped vectorstore.
    """

    # Keys ending with this suffix express an exclusion on the field named by
    # the rest of the key (e.g. ``report_type_exclude`` -> ``report_type ne ...``).
    _EXCLUDE_SUFFIX = "_exclude"

    def __init__(self, azure_search_vectorstore):
        self.vectorstore = azure_search_vectorstore

    def __getattr__(self, name):
        """Delegate all other attributes to the wrapped vectorstore."""
        # __getattr__ only runs when normal lookup fails. If ``vectorstore``
        # itself is missing (instance created without __init__), delegating
        # would recurse without bound, so fail with a regular AttributeError.
        if name == "vectorstore":
            raise AttributeError(name)
        return getattr(self.vectorstore, name)

    @staticmethod
    def _odata_literal(value):
        """Return *value* as a single-quoted OData string literal."""
        # OData escapes an embedded single quote by doubling it.
        return "'" + str(value).replace("'", "''") + "'"

    def _convert_dict_filter_to_odata(self, filter_dict):
        """
        Convert dictionary-style filters to Azure Search OData filter format.

        Each key yields one condition; conditions are joined with ``and``.
        A list value becomes an ``or`` group of equalities, or an ``and``
        group of inequalities for ``*_exclude`` keys.

        Args:
            filter_dict (dict | str): Dictionary-style filter. A string is
                assumed to already be an OData expression and is returned as is.

        Returns:
            str: OData filter string, or None when the filter is empty
        """
        if not filter_dict:
            return None

        if isinstance(filter_dict, str):
            return filter_dict

        conditions = []

        for key, value in filter_dict.items():
            if key.endswith(self._EXCLUDE_SUFFIX):
                # Strip only the trailing suffix, not every occurrence.
                field = key[: -len(self._EXCLUDE_SUFFIX)]
                operator, joiner = "ne", " and "
            else:
                field, operator, joiner = key, "eq", " or "

            values = value if isinstance(value, list) else [value]
            clauses = [f"{field} {operator} {self._odata_literal(v)}" for v in values]

            if len(clauses) == 1:
                conditions.append(clauses[0])
            else:
                conditions.append(f"({joiner.join(clauses)})")

        return " and ".join(conditions) if conditions else None

    def similarity_search_with_score(self, query, k=4, filter=None, **kwargs):
        """Run a hybrid search with scores, converting the filter to OData first."""
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)

        return self.vectorstore.hybrid_search_with_score(
            query=query, k=k, filters=filter, **kwargs
        )

    def similarity_search(self, query, k=4, filter=None, **kwargs):
        """Override similarity_search to convert filters."""
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)

        # The OData expression is passed as ``filters``, the keyword also used
        # by similarity_search_with_score above.
        return self.vectorstore.similarity_search(
            query=query, k=k, filters=filter, **kwargs
        )

    def similarity_search_by_vector(self, embedding, k=4, filter=None, **kwargs):
        """Override similarity_search_by_vector to convert filters."""
        if filter is not None:
            filter = self._convert_dict_filter_to_odata(filter)

        # NOTE(review): confirm the wrapped store implements this method and
        # which keyword it expects for the filter.
        return self.vectorstore.similarity_search_by_vector(
            embedding=embedding, k=k, filter=filter, **kwargs
        )

    def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
        """Override as_retriever to handle filter conversion in search_kwargs."""
        if search_kwargs and "filter" in search_kwargs:
            # Convert the filter in search_kwargs
            search_kwargs = search_kwargs.copy()  # Don't modify the original
            if search_kwargs["filter"] is not None:
                search_kwargs["filter"] = self._convert_dict_filter_to_odata(search_kwargs["filter"])

        # NOTE(review): verify that the retriever built by the wrapped store
        # reads the ``filter`` key of search_kwargs.
        return self.vectorstore.as_retriever(
            search_type=search_type, search_kwargs=search_kwargs, **kwargs
        )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_azure_search_vectorstore(embeddings, text_key="content", index_name=None):
    """
    Build the Azure AI Search vectorstore and wrap it for filter conversion.

    The service endpoint and key are read from the ``AI_SEARCH_INDEX_ENDPOINT``
    and ``AI_SEARCH_KEY`` environment variables.

    Args:
        embeddings: Embeddings object; its ``embed_query`` method is used as the embedding function
        text_key: Passed to AzureSearch as ``content_key`` (default: "content")
        index_name: The name of the Azure Search index

    Returns:
        AzureSearchWrapper: The AzureSearch vectorstore wrapped in AzureSearchWrapper

    Raises:
        ValueError: If either environment variable is unset or empty, or if
            no index_name is given
    """
    endpoint = os.getenv("AI_SEARCH_INDEX_ENDPOINT")
    api_key = os.getenv("AI_SEARCH_KEY")

    # Checked in this order; the first missing setting is the one reported.
    required_settings = (
        (endpoint, "AI_SEARCH_INDEX_ENDPOINT environment variable is required"),
        (api_key, "AI_SEARCH_KEY environment variable is required"),
        (index_name, "index_name must be provided for Azure Search"),
    )
    for setting, message in required_settings:
        if not setting:
            raise ValueError(message)

    store = AzureSearch(
        azure_search_endpoint=endpoint,
        azure_search_key=api_key,
        index_name=index_name,
        embedding_function=embeddings.embed_query,
        content_key=text_key,
    )

    # The wrapper converts dictionary-style filters to OData.
    return AzureSearchWrapper(store)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
|
|
climateqa/utils.py
CHANGED
|
@@ -25,7 +25,7 @@ def remove_duplicates_keep_highest_score(documents):
|
|
| 25 |
unique_docs = {}
|
| 26 |
|
| 27 |
for doc in documents:
|
| 28 |
-
doc_id = doc.metadata.get('
|
| 29 |
if doc_id in unique_docs:
|
| 30 |
if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
|
| 31 |
unique_docs[doc_id] = doc
|
|
|
|
| 25 |
unique_docs = {}
|
| 26 |
|
| 27 |
for doc in documents:
|
| 28 |
+
doc_id = doc.metadata.get('id')
|
| 29 |
if doc_id in unique_docs:
|
| 30 |
if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
|
| 31 |
unique_docs[doc_id] = doc
|
front/tabs/tab_ipcc.py
CHANGED
|
@@ -68,6 +68,8 @@ def show_filter_by_scenario(table_names, index_state, dataframes):
|
|
| 68 |
return gr.update(visible=False)
|
| 69 |
|
| 70 |
def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
|
|
|
|
|
|
|
| 71 |
df = dataframes[index_state]
|
| 72 |
if not table_names[index_state].startswith("Map"):
|
| 73 |
return df, figures[index_state](df)
|
|
|
|
| 68 |
return gr.update(visible=False)
|
| 69 |
|
| 70 |
def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
|
| 71 |
+
if len(dataframes) == 0:
|
| 72 |
+
return None, None
|
| 73 |
df = dataframes[index_state]
|
| 74 |
if not table_names[index_state].startswith("Map"):
|
| 75 |
return df, figures[index_state](df)
|
requirements.txt
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
gradio==5.0.2
|
| 2 |
azure-storage-file-share==12.11.1
|
| 3 |
azure-storage-blob==12.23.0
|
|
|
|
|
|
|
|
|
|
| 4 |
python-dotenv==1.0.0
|
| 5 |
langchain==0.2.1
|
| 6 |
langchain_openai==0.1.7
|
|
|
|
| 1 |
gradio==5.0.2
|
| 2 |
azure-storage-file-share==12.11.1
|
| 3 |
azure-storage-blob==12.23.0
|
| 4 |
+
# Azure AI Search support
|
| 5 |
+
azure-search-documents>=11.4.0
|
| 6 |
+
azure-core>=1.29.0
|
| 7 |
python-dotenv==1.0.0
|
| 8 |
langchain==0.2.1
|
| 9 |
langchain_openai==0.1.7
|
sandbox/20241104 - CQA - StepByStep CQA.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
style.css
CHANGED
|
@@ -661,7 +661,6 @@ a {
|
|
| 661 |
|
| 662 |
#sql-query textarea{
|
| 663 |
min-height: 200px !important;
|
| 664 |
-
|
| 665 |
}
|
| 666 |
|
| 667 |
#sql-query span{
|
|
|
|
| 661 |
|
| 662 |
#sql-query textarea{
|
| 663 |
min-height: 200px !important;
|
|
|
|
| 664 |
}
|
| 665 |
|
| 666 |
#sql-query span{
|