timeki commited on
Commit
0110684
·
2 Parent(s): d37790b ecc6c98

Merged in dev (pull request #27)

Browse files
app.py CHANGED
@@ -7,7 +7,7 @@ from azure.storage.fileshare import ShareServiceClient
7
  # Import custom modules
8
  from climateqa.engine.embeddings import get_embeddings_function
9
  from climateqa.engine.llm import get_llm
10
- from climateqa.engine.vectorstore import get_pinecone_vectorstore
11
  from climateqa.engine.reranker import get_reranker
12
  from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
@@ -66,17 +66,11 @@ user_id = create_user_id()
66
 
67
  # Create vectorstore and retriever
68
  embeddings_function = get_embeddings_function()
69
- vectorstore = get_pinecone_vectorstore(
70
- embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
71
- )
72
- vectorstore_graphs = get_pinecone_vectorstore(
73
- embeddings_function,
74
- index_name=os.getenv("PINECONE_API_INDEX_OWID"),
75
- text_key="description",
76
- )
77
- vectorstore_region = get_pinecone_vectorstore(
78
- embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
79
- )
80
 
81
  llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
82
  if os.environ["GRADIO_ENV"] == "local":
 
7
  # Import custom modules
8
  from climateqa.engine.embeddings import get_embeddings_function
9
  from climateqa.engine.llm import get_llm
10
+ from climateqa.engine.vectorstore import get_vectorstore
11
  from climateqa.engine.reranker import get_reranker
12
  from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
13
  from climateqa.engine.chains.retrieve_papers import find_papers
 
66
 
67
  # Create vectorstore and retriever
68
  embeddings_function = get_embeddings_function()
69
+
70
+ vectorstore = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-ipx")
71
+ vectorstore_graphs = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-owid", text_key="description")
72
+ vectorstore_region = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-v2")
73
+
 
 
 
 
 
 
74
 
75
  llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
76
  if os.environ["GRADIO_ENV"] == "local":
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -19,7 +19,7 @@ from ..llm import get_llm
19
  from .prompts import retrieve_chapter_prompt_template
20
  from langchain_core.prompts import ChatPromptTemplate
21
  from langchain_core.output_parsers import StrOutputParser
22
- from ..vectorstore import get_pinecone_vectorstore
23
  from ..embeddings import get_embeddings_function
24
  import ast
25
 
@@ -134,7 +134,7 @@ def get_ToCs(version: str) :
134
  "version": version
135
  }
136
  embeddings_function = get_embeddings_function()
137
- vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2")
138
  tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text)
139
 
140
  # remove duplicates or almost duplicates
@@ -236,7 +236,7 @@ async def get_POC_documents_by_ToC_relevant_documents(
236
  filters_text_toc = {
237
  **filters,
238
  "chunk_type":"text",
239
- "toc_level0": {"$in": toc_filters},
240
  "version": version
241
  # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
242
  }
@@ -273,6 +273,22 @@ async def get_POC_documents_by_ToC_relevant_documents(
273
  "docs_images" : docs_images
274
  }
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  async def get_IPCC_relevant_documents(
278
  query: str,
@@ -299,9 +315,9 @@ async def get_IPCC_relevant_documents(
299
  filters = {}
300
 
301
  if len(reports) > 0:
302
- filters["short_name"] = {"$in":reports}
303
  else:
304
- filters["source"] = { "$in": sources}
305
 
306
  # INIT
307
  docs_summaries = []
@@ -323,18 +339,16 @@ async def get_IPCC_relevant_documents(
323
  filters_summaries = {
324
  **filters,
325
  "chunk_type":"text",
326
- "report_type": { "$in":["SPM"]},
327
  }
328
 
329
  docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
330
  docs_summaries = [x for x in docs_summaries if x[1] > threshold]
331
 
332
  # Search for k_total - k_summary documents in the full reports dataset
333
- filters_full = {
334
- **filters,
335
- "chunk_type":"text",
336
- "report_type": { "$nin":["SPM"]},
337
- }
338
  docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_total)
339
 
340
  if search_figures:
 
19
  from .prompts import retrieve_chapter_prompt_template
20
  from langchain_core.prompts import ChatPromptTemplate
21
  from langchain_core.output_parsers import StrOutputParser
22
+ from ..vectorstore import get_vectorstore
23
  from ..embeddings import get_embeddings_function
24
  import ast
25
 
 
134
  "version": version
135
  }
136
  embeddings_function = get_embeddings_function()
137
+ vectorstore = get_vectorstore(provider="qdrant", embeddings=embeddings_function, index_name="climateqa")
138
  tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text)
139
 
140
  # remove duplicates or almost duplicates
 
236
  filters_text_toc = {
237
  **filters,
238
  "chunk_type":"text",
239
+ "toc_level0": toc_filters, # Changed from {"$in": toc_filters} to direct list
240
  "version": version
241
  # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
242
  }
 
273
  "docs_images" : docs_images
274
  }
275
 
276
+ def filter_for_full_report_documents(filters: dict) -> dict:
277
+ """
278
+ Filter for full report documents.
279
+ Returns a dictionary format compatible with all vectorstore providers.
280
+ """
281
+ # Start with the base filters
282
+ full_filters = filters.copy()
283
+
284
+ # Add chunk_type filter
285
+ full_filters["chunk_type"] = "text"
286
+
287
+ # Add report_type exclusion using the new _exclude suffix format
288
+ # This will be converted to appropriate OData filter by Azure Search wrapper
289
+ full_filters["report_type_exclude"] = ["SPM"]
290
+
291
+ return full_filters
292
 
293
  async def get_IPCC_relevant_documents(
294
  query: str,
 
315
  filters = {}
316
 
317
  if len(reports) > 0:
318
+ filters["short_name"] = reports # Changed from {"$in":reports} to direct list
319
  else:
320
+ filters["source"] = sources # Changed from {"$in": sources} to direct list
321
 
322
  # INIT
323
  docs_summaries = []
 
339
  filters_summaries = {
340
  **filters,
341
  "chunk_type":"text",
342
+ "report_type": ["SPM"], # Changed from {"$in":["SPM"]} to direct list
343
  }
344
 
345
  docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
346
  docs_summaries = [x for x in docs_summaries if x[1] > threshold]
347
 
348
  # Search for k_total - k_summary documents in the full reports dataset
349
+ filters_full = filter_for_full_report_documents(filters)
350
+
351
+
 
 
352
  docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_total)
353
 
354
  if search_figures:
climateqa/engine/graph_retriever.py CHANGED
@@ -60,10 +60,9 @@ async def retrieve_graphs(
60
  assert sources
61
  assert any([x in ["OWID"] for x in sources])
62
 
63
- # Prepare base search kwargs
64
- filters = {}
65
-
66
- filters["source"] = {"$in": sources}
67
 
68
  docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
69
 
 
60
  assert sources
61
  assert any([x in ["OWID"] for x in sources])
62
 
63
+ # Prepare base search kwargs for Azure AI Search
64
+ # Azure expects a filter string, e.g. "source eq 'OWID' or source eq 'IEA'"
65
+ filters = {"source":"OWID"}
 
66
 
67
  docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
68
 
climateqa/engine/llm/openai.py CHANGED
@@ -8,7 +8,6 @@ except Exception:
8
  pass
9
 
10
  def get_llm(model="gpt-4o-mini",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
11
-
12
  llm = ChatOpenAI(
13
  model=model,
14
  api_key=os.environ.get("THEO_API_KEY", None),
 
8
  pass
9
 
10
  def get_llm(model="gpt-4o-mini",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
 
11
  llm = ChatOpenAI(
12
  model=model,
13
  api_key=os.environ.get("THEO_API_KEY", None),
climateqa/engine/talk_to_data/input_processing.py CHANGED
@@ -10,6 +10,7 @@ from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput
10
  from climateqa.engine.talk_to_data.objects.location import Location
11
  from climateqa.engine.talk_to_data.objects.plot import Plot
12
  from climateqa.engine.talk_to_data.objects.states import State
 
13
 
14
  async def detect_location_with_openai(sentence: str) -> str:
15
  """
@@ -118,7 +119,7 @@ async def detect_year_with_openai(sentence: str) -> str:
118
  return years_list[0]
119
  else:
120
  return ""
121
-
122
 
123
  async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
124
  """Identifies relevant tables for a plot based on user input.
@@ -227,6 +228,55 @@ async def find_year(user_input: str) -> str| None:
227
  return None
228
  return year
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
231
  print("---- Find relevant plots ----")
232
  relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
@@ -237,16 +287,28 @@ async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: l
237
  relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
238
  return relevant_tables
239
 
240
- async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
241
- """Perform the good method to retrieve the desired parameter
 
242
 
243
  Args:
244
- state (State): state of the workflow
245
- param_name (str): name of the desired parameter
246
- table (str): name of the table
247
 
248
  Returns:
249
- dict[str, Any] | None:
 
 
 
 
 
 
 
 
 
 
 
250
  """
251
  if param_name == 'location':
252
  location = await find_location(state['user_input'], mode)
@@ -254,4 +316,7 @@ async def find_param(state: State, param_name:str, mode: Literal['DRIAS', 'IPCC'
254
  if param_name == 'year':
255
  year = await find_year(state['user_input'])
256
  return {'year': year}
257
- return None
 
 
 
 
10
  from climateqa.engine.talk_to_data.objects.location import Location
11
  from climateqa.engine.talk_to_data.objects.plot import Plot
12
  from climateqa.engine.talk_to_data.objects.states import State
13
+ import calendar
14
 
15
  async def detect_location_with_openai(sentence: str) -> str:
16
  """
 
119
  return years_list[0]
120
  else:
121
  return ""
122
+
123
 
124
  async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
125
  """Identifies relevant tables for a plot based on user input.
 
228
  return None
229
  return year
230
 
231
+ async def find_month(user_input: str) -> dict[str, str|None]:
232
+ """
233
+ Extracts month information from user input using an LLM.
234
+
235
+ This function analyzes the user's query to detect if a month is mentioned.
236
+ It returns both the month number (as a string, e.g. '7' for July) and the full English month name (e.g. 'July').
237
+ If no month is found, both values will be None.
238
+
239
+ Args:
240
+ user_input (str): The user's query text.
241
+
242
+ Returns:
243
+ dict[str, str|None]: A dictionary with keys:
244
+ - "month_number": the month number as a string (e.g. '7'), or None if not found
245
+ - "month_name": the full English month name (e.g. 'July'), or None if not found
246
+
247
+ Example:
248
+ >>> await find_month("Show me the temperature in Paris in July")
249
+ {'month_number': '7', 'month_name': 'July'}
250
+ >>> await find_month("Show me the temperature in Paris")
251
+ {'month_number': None, 'month_name': None}
252
+ """
253
+
254
+ llm = get_llm()
255
+ prompt = """
256
+ Extract the month (as a number from 1 to 12) mentioned in the following sentence.
257
+ Return the result as a Python list of integers. If no month is mentioned, return an empty list.
258
+
259
+ Sentence: "{sentence}"
260
+ """
261
+ prompt = ChatPromptTemplate.from_template(prompt)
262
+ structured_llm = llm.with_structured_output(ArrayOutput)
263
+ chain = prompt | structured_llm
264
+ response: ArrayOutput = await chain.ainvoke({"sentence": user_input})
265
+ months_list = ast.literal_eval(response['array'])
266
+ if len(months_list) > 0:
267
+ month_number = int(months_list[0])
268
+ month_name = calendar.month_name[month_number]
269
+ return {
270
+ "month_number": str(month_number),
271
+ "month_name": month_name
272
+ }
273
+ else:
274
+ return {
275
+ "month_number" : None,
276
+ "month_name" : None
277
+ }
278
+
279
+
280
  async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
281
  print("---- Find relevant plots ----")
282
  relevant_plots = await detect_relevant_plots(state['user_input'], llm, plots)
 
287
  relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm, tables)
288
  return relevant_tables
289
 
290
+ async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None:
291
+ """
292
+ Retrieves a specific parameter (location, year, month, etc.) from the user's input using the appropriate extraction method.
293
 
294
  Args:
295
+ state (State): The current state containing at least the user's input under 'user_input'.
296
+ param_name (str): The name of the parameter to extract. Supported: 'location', 'year', 'month'.
297
+ mode (Literal['DRIAS', 'IPCC']): The data mode to use for location extraction.
298
 
299
  Returns:
300
+ - For 'location': a Location object (dict with keys like 'location', 'latitude', etc.), or None if not found.
301
+ - For 'year': a dict {'year': year or None}.
302
+ - For 'month': a dict {'month_number': str or None, 'month_name': str or None}.
303
+ - None if the parameter is not recognized or not found.
304
+
305
+ Example:
306
+ >>> await find_param(state, 'location')
307
+ {'location': 'Paris', 'latitude': ..., ...}
308
+ >>> await find_param(state, 'year')
309
+ {'year': '2050'}
310
+ >>> await find_param(state, 'month')
311
+ {'month_number': '7', 'month_name': 'July'}
312
  """
313
  if param_name == 'location':
314
  location = await find_location(state['user_input'], mode)
 
316
  if param_name == 'year':
317
  year = await find_year(state['user_input'])
318
  return {'year': year}
319
+ if param_name == 'month':
320
+ month = await find_month(state['user_input'])
321
+ return month
322
+ return None
climateqa/engine/talk_to_data/ipcc/config.py CHANGED
@@ -6,16 +6,22 @@ from climateqa.engine.talk_to_data.config import IPCC_DATASET_URL
6
  IPCC_TABLES = [
7
  "mean_temperature",
8
  "total_precipitation",
 
 
9
  ]
10
 
11
  IPCC_INDICATOR_COLUMNS_PER_TABLE = {
12
  "mean_temperature": "mean_temperature",
13
- "total_precipitation": "total_precipitation"
 
 
14
  }
15
 
16
  IPCC_INDICATOR_TO_UNIT = {
17
  "mean_temperature": "°C",
18
- "total_precipitation": "mm/day"
 
 
19
  }
20
 
21
  IPCC_SCENARIO = [
@@ -30,7 +36,8 @@ IPCC_MODELS = []
30
 
31
  IPCC_PLOT_PARAMETERS = [
32
  'year',
33
- 'location'
 
34
  ]
35
 
36
  MACRO_COUNTRIES = ['JP',
@@ -63,7 +70,9 @@ HUGE_MACRO_COUNTRIES = ['CL',
63
 
64
  IPCC_INDICATOR_TO_COLORSCALE = {
65
  "mean_temperature": TEMPERATURE_COLORSCALE,
66
- "total_precipitation": PRECIPITATION_COLORSCALE
 
 
67
  }
68
 
69
  IPCC_UI_TEXT = """
@@ -77,9 +86,12 @@ By default, we take the **mediane of each climate model**.
77
  Current available charts :
78
  - Yearly evolution of an indicator at a specific location (historical + SSP Projections)
79
  - Yearly spatial distribution of an indicator in a specific country
 
80
 
81
  Current available indicators :
82
  - Mean temperature
 
 
83
  - Total precipitation
84
 
85
  For example, you can ask:
 
6
  IPCC_TABLES = [
7
  "mean_temperature",
8
  "total_precipitation",
9
+ "minimum_temperature",
10
+ "maximum_temperature"
11
  ]
12
 
13
  IPCC_INDICATOR_COLUMNS_PER_TABLE = {
14
  "mean_temperature": "mean_temperature",
15
+ "total_precipitation": "total_precipitation",
16
+ "minimum_temperature": "minimum_temperature",
17
+ "maximum_temperature": "maximum_temperature"
18
  }
19
 
20
  IPCC_INDICATOR_TO_UNIT = {
21
  "mean_temperature": "°C",
22
+ "total_precipitation": "mm/day",
23
+ "minimum_temperature": "°C",
24
+ "maximum_temperature": "°C"
25
  }
26
 
27
  IPCC_SCENARIO = [
 
36
 
37
  IPCC_PLOT_PARAMETERS = [
38
  'year',
39
+ 'location',
40
+ 'month'
41
  ]
42
 
43
  MACRO_COUNTRIES = ['JP',
 
70
 
71
  IPCC_INDICATOR_TO_COLORSCALE = {
72
  "mean_temperature": TEMPERATURE_COLORSCALE,
73
+ "total_precipitation": PRECIPITATION_COLORSCALE,
74
+ "minimum_temperature": TEMPERATURE_COLORSCALE,
75
+ "maximum_temperature": TEMPERATURE_COLORSCALE,
76
  }
77
 
78
  IPCC_UI_TEXT = """
 
86
  Current available charts :
87
  - Yearly evolution of an indicator at a specific location (historical + SSP Projections)
88
  - Yearly spatial distribution of an indicator in a specific country
89
+ - Yearly evolution of an indicator in a specific month at a specific location (historical + SSP Projections)
90
 
91
  Current available indicators :
92
  - Mean temperature
93
+ - Minimum temperature
94
+ - Maximum temperature
95
  - Total precipitation
96
 
97
  For example, you can ask:
climateqa/engine/talk_to_data/ipcc/plot_informations.py CHANGED
@@ -47,4 +47,27 @@ Each grid point is colored according to the value of the indicator ({unit}), all
47
  - For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
48
  - The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
49
  - The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """
 
47
  - For each grid point of {location} country ({country_name}), the value of {indicator} in {year} and for the selected scenario is extracted and mapped to its geographic coordinates.
48
  - The grid points correspond to 1-degree squares centered on the grid points of the IPCC dataset. Each grid point has been mapped to a country using [**reverse_geocoder**](https://github.com/thampiman/reverse-geocoder).
49
  - The coordinates used for each region are those of the closest available grid point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
50
+ """
51
+
52
+ def indicator_specific_month_evolution_informations(
53
+ indicator: str,
54
+ params: dict[str, str]
55
+ ) -> str:
56
+ if "location" not in params:
57
+ raise ValueError('"location" must be provided in params')
58
+ location = params["location"]
59
+ if "month_name" not in params:
60
+ raise ValueError('"month_name" must be provided in params')
61
+ month = params["month_name"]
62
+ unit = IPCC_INDICATOR_TO_UNIT[indicator]
63
+ return f"""
64
+ This plot shows how the climate indicator **{indicator}** evolves over time in **{location}** for the month of **{month}**.
65
+ It combines both historical (from 1950 to 2015) observations and future (from 2016 to 2100) projections for the different SSP climate scenarios (SSP126, SSP245, SSP370 and SSP585).
66
+ The x-axis represents the years (from 1950 to 2100), and the y-axis shows the value of the {indicator} ({unit}) for the selected month.
67
+ Each line corresponds to a different scenario, allowing you to compare how {indicator} for month {month} might change under various future conditions.
68
+
69
+ **Data source:**
70
+ - The data comes from the IPCC climate datasets (Parquet files) for the relevant indicator, location, and month.
71
+ - For each year and scenario, the value of {indicator} for month {month} is extracted for the selected location.
72
+ - The coordinates used for {location} correspond to the closest available point in the IPCC database, which uses a regular grid with a spatial resolution of 1 degree.
73
  """
climateqa/engine/talk_to_data/ipcc/plots.py CHANGED
@@ -5,8 +5,8 @@ import pandas as pd
5
  import geojson
6
 
7
  from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
8
- from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations
9
- from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_at_location_query
10
  from climateqa.engine.talk_to_data.objects.plot import Plot
11
 
12
  def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
@@ -102,6 +102,82 @@ indicator_evolution_at_location_historical_and_projections: Plot = {
102
  "short_name": "Evolution"
103
  }
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  def plot_choropleth_map_of_country_indicator_for_specific_year(
106
  params: dict,
107
  ) -> Callable[[pd.DataFrame], Figure]:
@@ -167,6 +243,7 @@ def plot_choropleth_map_of_country_indicator_for_specific_year(
167
 
168
  return plot_data
169
 
 
170
  choropleth_map_of_country_indicator_for_specific_year: Plot = {
171
  "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
172
  "description": (
@@ -185,5 +262,6 @@ choropleth_map_of_country_indicator_for_specific_year: Plot = {
185
 
186
  IPCC_PLOTS = [
187
  indicator_evolution_at_location_historical_and_projections,
188
- choropleth_map_of_country_indicator_for_specific_year
 
189
  ]
 
5
  import geojson
6
 
7
  from climateqa.engine.talk_to_data.ipcc.config import IPCC_INDICATOR_TO_COLORSCALE, IPCC_INDICATOR_TO_UNIT, IPCC_SCENARIO
8
+ from climateqa.engine.talk_to_data.ipcc.plot_informations import choropleth_map_informations, indicator_evolution_informations, indicator_specific_month_evolution_informations
9
+ from climateqa.engine.talk_to_data.ipcc.queries import indicator_for_given_year_query, indicator_per_year_and_specific_month_at_location_query, indicator_per_year_at_location_query
10
  from climateqa.engine.talk_to_data.objects.plot import Plot
11
 
12
  def generate_geojson_polygons(latitudes: list[float], longitudes: list[float], indicators: list[float]) -> geojson.FeatureCollection:
 
102
  "short_name": "Evolution"
103
  }
104
 
105
+ def plot_indicator_monthly_evolution_at_location(
106
+ params: dict,
107
+ ) -> Callable[[pd.DataFrame], Figure]:
108
+ """
109
+ Returns a function that generates a line plot showing the evolution of a climate indicator
110
+ for a specific month over time at a specific location, including both historical data
111
+ and future projections for different climate scenarios.
112
+
113
+ Args:
114
+ params (dict): Dictionary with:
115
+ - indicator_column (str): Name of the climate indicator column to plot.
116
+ - location (str): Location (e.g., country, city) for which to plot the indicator.
117
+ - month (str): Month name to plot.
118
+
119
+ Returns:
120
+ Callable[[pd.DataFrame], Figure]: Function that takes a DataFrame and returns a Plotly Figure.
121
+ """
122
+ indicator = params["indicator_column"]
123
+ location = params["location"]
124
+ month = params["month_name"]
125
+ indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
126
+ unit = IPCC_INDICATOR_TO_UNIT.get(indicator, "")
127
+
128
+ def plot_data(df: pd.DataFrame) -> Figure:
129
+ df = df.sort_values(by='year')
130
+ years = df['year'].astype(int).tolist()
131
+ indicators = df[indicator].astype(float).tolist()
132
+ scenarios = df['scenario'].astype(str).tolist()
133
+
134
+ # Find last historical value for continuity
135
+ last_historical = [(y, v) for y, v, s in zip(years, indicators, scenarios) if s == 'historical']
136
+ last_historical_year, last_historical_indicator = last_historical[-1] if last_historical else (None, None)
137
+
138
+ fig = go.Figure()
139
+ for scenario in IPCC_SCENARIO:
140
+ x = [y for y, s in zip(years, scenarios) if s == scenario]
141
+ y = [v for v, s in zip(indicators, scenarios) if s == scenario]
142
+ # Connect historical to scenario
143
+ if scenario != 'historical' and last_historical_indicator is not None:
144
+ x = [last_historical_year] + x
145
+ y = [last_historical_indicator] + y
146
+ fig.add_trace(go.Scatter(
147
+ x=x,
148
+ y=y,
149
+ mode='lines',
150
+ name=scenario
151
+ ))
152
+
153
+ fig.update_layout(
154
+ title=f'Evolution of {indicator_label} in {month} in {location} (Historical + SSP Scenarios)',
155
+ xaxis_title='Year',
156
+ yaxis_title=f'{indicator_label} ({unit})',
157
+ legend_title='Scenario',
158
+ height=800,
159
+ )
160
+ return fig
161
+
162
+ return plot_data
163
+
164
+
165
+ indicator_specific_month_evolution_at_location: Plot = {
166
+ "name": "Indicator specific month Evolution at Location (Historical + Projections)",
167
+ "description": (
168
+ "Shows how a climate indicator (e.g., rainfall, temperature) for a specific month changes over time at a specific location, "
169
+ "including historical data and future projections. "
170
+ "Useful for questions about the value or trend of an indicator for a given month at a location, "
171
+ "such as 'How does July temperature evolve in Paris over time?'. "
172
+ "Parameters: indicator_column (the climate variable), location (e.g., country, city), month (1-12)."
173
+ ),
174
+ "params": ["indicator_column", "location", "month"],
175
+ "plot_function": plot_indicator_monthly_evolution_at_location,
176
+ "sql_query": indicator_per_year_and_specific_month_at_location_query,
177
+ "plot_information": indicator_specific_month_evolution_informations,
178
+ "short_name": "Evolution for a specific month"
179
+ }
180
+
181
  def plot_choropleth_map_of_country_indicator_for_specific_year(
182
  params: dict,
183
  ) -> Callable[[pd.DataFrame], Figure]:
 
243
 
244
  return plot_data
245
 
246
+
247
  choropleth_map_of_country_indicator_for_specific_year: Plot = {
248
  "name": "Choropleth Map of a Country's Indicator Distribution for a Specific Year",
249
  "description": (
 
262
 
263
  IPCC_PLOTS = [
264
  indicator_evolution_at_location_historical_and_projections,
265
+ choropleth_map_of_country_indicator_for_specific_year,
266
+ indicator_specific_month_evolution_at_location
267
  ]
climateqa/engine/talk_to_data/ipcc/queries.py CHANGED
@@ -43,7 +43,7 @@ def indicator_per_year_at_location_query(
43
  return ""
44
 
45
  if country_code in MACRO_COUNTRIES:
46
- table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
47
  sql_query = f"""
48
  SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
49
  FROM {table_path}
@@ -52,7 +52,7 @@ def indicator_per_year_at_location_query(
52
  ORDER BY year, scenario
53
  """
54
  elif country_code in HUGE_MACRO_COUNTRIES:
55
- table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
56
  sql_query = f"""
57
  SELECT year, scenario, {indicator_column}
58
  FROM {table_path}
@@ -75,6 +75,66 @@ def indicator_per_year_at_location_query(
75
  """
76
  return sql_query.strip()
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  class IndicatorForGivenYearQueryParams(TypedDict, total=False):
79
  """
80
  Parameters for querying an indicator's values across locations for a specific year.
@@ -110,7 +170,7 @@ def indicator_for_given_year_query(
110
  return ""
111
 
112
  if country_code in MACRO_COUNTRIES:
113
- table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
114
  sql_query = f"""
115
  SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
116
  FROM {table_path}
@@ -119,7 +179,7 @@ def indicator_for_given_year_query(
119
  ORDER BY latitude, longitude, scenario
120
  """
121
  elif country_code in HUGE_MACRO_COUNTRIES:
122
- table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_macro.parquet'"
123
  sql_query = f"""
124
  SELECT latitude, longitude, scenario, {indicator_column}
125
  FROM {table_path}
@@ -141,4 +201,4 @@ def indicator_for_given_year_query(
141
  ORDER BY latitude, longitude, scenario
142
  """
143
 
144
- return sql_query.strip()
 
43
  return ""
44
 
45
  if country_code in MACRO_COUNTRIES:
46
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_monthly_macro.parquet'"
47
  sql_query = f"""
48
  SELECT year, scenario, AVG({indicator_column}) as {indicator_column}
49
  FROM {table_path}
 
52
  ORDER BY year, scenario
53
  """
54
  elif country_code in HUGE_MACRO_COUNTRIES:
55
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_annualy_macro.parquet'"
56
  sql_query = f"""
57
  SELECT year, scenario, {indicator_column}
58
  FROM {table_path}
 
75
  """
76
  return sql_query.strip()
77
 
78
+ class IndicatorPerYearAndSpecificMonthAtLocationQueryParams(TypedDict, total=False):
79
+ """
80
+ Parameters for querying the evolution of an indicator per year for a specific month at a specific location.
81
+
82
+ Attributes:
83
+ indicator_column (str): Name of the climate indicator column.
84
+ latitude (str): Latitude of the location.
85
+ longitude (str): Longitude of the location.
86
+ country_code (str): Country code.
87
+ month (str): Month targeted
88
+ """
89
+ indicator_column: str
90
+ latitude: str
91
+ longitude: str
92
+ country_code: str
93
+ month: str
94
+
95
+ def indicator_per_year_and_specific_month_at_location_query(
96
+ table: str, params: IndicatorPerYearAndSpecificMonthAtLocationQueryParams
97
+ ) -> str:
98
+ """
99
+ Builds an SQL query to get the evolution of an indicator per year for a specific month at a specific location.
100
+
101
+ Args:
102
+ table (str): SQL table of the indicator.
103
+ params (dict): Dictionary with required params:
104
+ - indicator_column (str)
105
+ - latitude (str or float)
106
+ - longitude (str or float)
107
+ - month (int)
108
+
109
+ Returns:
110
+ str: The SQL query string.
111
+ """
112
+ indicator_column = params.get("indicator_column")
113
+ latitude = params.get("latitude")
114
+ longitude = params.get("longitude")
115
+ country_code = params.get("country_code")
116
+ month = params.get('month_number')
117
+
118
+ if not all([indicator_column, latitude, longitude, country_code, month]):
119
+ return ""
120
+
121
+ if country_code in (MACRO_COUNTRIES+HUGE_MACRO_COUNTRIES):
122
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_monthly_macro.parquet'"
123
+ sql_query = f"""
124
+ SELECT year, scenario, {indicator_column}
125
+ FROM {table_path}
126
+ WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950 AND month={month}
127
+ ORDER BY year, scenario
128
+ """
129
+ else:
130
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}.parquet'"
131
+ sql_query = f"""
132
+ SELECT year, scenario, MEDIAN({indicator_column}) AS {indicator_column}
133
+ FROM {table_path}
134
+ WHERE latitude = {latitude} AND longitude = {longitude} AND year >= 1950 AND month={month}
135
+ GROUP BY scenario, year
136
+ """
137
+ return sql_query.strip()
138
  class IndicatorForGivenYearQueryParams(TypedDict, total=False):
139
  """
140
  Parameters for querying an indicator's values across locations for a specific year.
 
170
  return ""
171
 
172
  if country_code in MACRO_COUNTRIES:
173
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_monthly_macro.parquet'"
174
  sql_query = f"""
175
  SELECT latitude, longitude, scenario, AVG({indicator_column}) as {indicator_column}
176
  FROM {table_path}
 
179
  ORDER BY latitude, longitude, scenario
180
  """
181
  elif country_code in HUGE_MACRO_COUNTRIES:
182
+ table_path = f"'{IPCC_DATASET_URL}/{table.lower()}/{country_code}_annualy_macro.parquet'"
183
  sql_query = f"""
184
  SELECT latitude, longitude, scenario, {indicator_column}
185
  FROM {table_path}
 
201
  ORDER BY latitude, longitude, scenario
202
  """
203
 
204
+ return sql_query.strip()
climateqa/engine/talk_to_data/main.py CHANGED
@@ -50,7 +50,7 @@ async def ask_drias(query: str, index_state: int = 0, user_id: str | None = None
50
 
51
  if "error" in final_state and final_state["error"] != "":
52
  # No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
53
- return None, None, None, None, [], [], [], 0, [], final_state["error"]
54
 
55
  sql_query = sql_queries[index_state]
56
  dataframe = result_dataframes[index_state]
@@ -112,7 +112,7 @@ async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None)
112
 
113
  if "error" in final_state and final_state["error"] != "":
114
  # No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
115
- return None, None, None, None, [], [], [], 0, [], final_state["error"]
116
 
117
  sql_query = sql_queries[index_state]
118
  dataframe = result_dataframes[index_state]
 
50
 
51
  if "error" in final_state and final_state["error"] != "":
52
  # No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
53
+ return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
54
 
55
  sql_query = sql_queries[index_state]
56
  dataframe = result_dataframes[index_state]
 
112
 
113
  if "error" in final_state and final_state["error"] != "":
114
  # No Sql query, no dataframe, no figure, no plot information, empty sql queries list, empty result dataframes list, empty figures list, empty plot information list, index state = 0, empty table list, error message
115
+ return None, None, None, None, [], [], [], [], 0, [], final_state["error"]
116
 
117
  sql_query = sql_queries[index_state]
118
  dataframe = result_dataframes[index_state]
climateqa/engine/talk_to_data/myVanna.py DELETED
@@ -1,13 +0,0 @@
1
- from dotenv import load_dotenv
2
- from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
3
- from vanna.openai import OpenAI_Chat
4
- import os
5
-
6
- load_dotenv()
7
-
8
- OPENAI_API_KEY = os.getenv('THEO_API_KEY')
9
-
10
- class MyVanna(MyCustomVectorDB, OpenAI_Chat):
11
- def __init__(self, config=None):
12
- MyCustomVectorDB.__init__(self, config=config)
13
- OpenAI_Chat.__init__(self, config=config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/plot.py DELETED
@@ -1,418 +0,0 @@
1
- from typing import Callable, TypedDict
2
- from matplotlib.figure import figaspect
3
- import pandas as pd
4
- from plotly.graph_objects import Figure
5
- import plotly.graph_objects as go
6
- import plotly.express as px
7
-
8
- from climateqa.engine.talk_to_data.sql_query import (
9
- indicator_for_given_year_query,
10
- indicator_per_year_at_location_query,
11
- )
12
- from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
13
-
14
-
15
-
16
-
17
- class Plot(TypedDict):
18
- """Represents a plot configuration in the DRIAS system.
19
-
20
- This class defines the structure for configuring different types of plots
21
- that can be generated from climate data.
22
-
23
- Attributes:
24
- name (str): The name of the plot type
25
- description (str): A description of what the plot shows
26
- params (list[str]): List of required parameters for the plot
27
- plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
28
- sql_query (Callable[..., str]): Function to generate the SQL query for the plot
29
- """
30
- name: str
31
- description: str
32
- params: list[str]
33
- plot_function: Callable[..., Callable[..., Figure]]
34
- sql_query: Callable[..., str]
35
-
36
-
37
- def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
38
- """Generates a function to plot indicator evolution over time at a location.
39
-
40
- This function creates a line plot showing how a climate indicator changes
41
- over time at a specific location. It handles temperature, precipitation,
42
- and other climate indicators.
43
-
44
- Args:
45
- params (dict): Dictionary containing:
46
- - indicator_column (str): The column name for the indicator
47
- - location (str): The location to plot
48
- - model (str): The climate model to use
49
-
50
- Returns:
51
- Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
52
-
53
- Example:
54
- >>> plot_func = plot_indicator_evolution_at_location({
55
- ... 'indicator_column': 'mean_temperature',
56
- ... 'location': 'Paris',
57
- ... 'model': 'ALL'
58
- ... })
59
- >>> fig = plot_func(df)
60
- """
61
- indicator = params["indicator_column"]
62
- location = params["location"]
63
- indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
64
- unit = INDICATOR_TO_UNIT.get(indicator, "")
65
-
66
- def plot_data(df: pd.DataFrame) -> Figure:
67
- """Generates the actual plot from the data.
68
-
69
- Args:
70
- df (pd.DataFrame): DataFrame containing the data to plot
71
-
72
- Returns:
73
- Figure: A plotly Figure object showing the indicator evolution
74
- """
75
- fig = go.Figure()
76
- if df['model'].nunique() != 1:
77
- df_avg = df.groupby("year", as_index=False)[indicator].mean()
78
-
79
- # Transform to list to avoid pandas encoding
80
- indicators = df_avg[indicator].astype(float).tolist()
81
- years = df_avg["year"].astype(int).tolist()
82
-
83
- # Compute the 10-year rolling average
84
- rolling_window = 10
85
- sliding_averages = (
86
- df_avg[indicator]
87
- .rolling(window=rolling_window, min_periods=rolling_window)
88
- .mean()
89
- .astype(float)
90
- .tolist()
91
- )
92
- model_label = "Model Average"
93
-
94
- # Only add rolling average if we have enough data points
95
- if len([x for x in sliding_averages if pd.notna(x)]) > 0:
96
- # Sliding average dashed line
97
- fig.add_scatter(
98
- x=years,
99
- y=sliding_averages,
100
- mode="lines",
101
- name="10 years rolling average",
102
- line=dict(dash="dash"),
103
- marker=dict(color="#d62728"),
104
- hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
105
- )
106
-
107
- else:
108
- df_model = df
109
-
110
- # Transform to list to avoid pandas encoding
111
- indicators = df_model[indicator].astype(float).tolist()
112
- years = df_model["year"].astype(int).tolist()
113
-
114
- # Compute the 10-year rolling average
115
- rolling_window = 10
116
- sliding_averages = (
117
- df_model[indicator]
118
- .rolling(window=rolling_window, min_periods=rolling_window)
119
- .mean()
120
- .astype(float)
121
- .tolist()
122
- )
123
- model_label = f"Model : {df['model'].unique()[0]}"
124
-
125
- # Only add rolling average if we have enough data points
126
- if len([x for x in sliding_averages if pd.notna(x)]) > 0:
127
- # Sliding average dashed line
128
- fig.add_scatter(
129
- x=years,
130
- y=sliding_averages,
131
- mode="lines",
132
- name="10 years rolling average",
133
- line=dict(dash="dash"),
134
- marker=dict(color="#d62728"),
135
- hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
136
- )
137
-
138
- # Indicator per year plot
139
- fig.add_scatter(
140
- x=years,
141
- y=indicators,
142
- name=f"Yearly {indicator_label}",
143
- mode="lines",
144
- marker=dict(color="#1f77b4"),
145
- hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
146
- )
147
- fig.update_layout(
148
- title=f"Plot of {indicator_label} in {location} ({model_label})",
149
- xaxis_title="Year",
150
- yaxis_title=f"{indicator_label} ({unit})",
151
- template="plotly_white",
152
- )
153
- return fig
154
-
155
- return plot_data
156
-
157
-
158
- indicator_evolution_at_location: Plot = {
159
- "name": "Indicator evolution at location",
160
- "description": "Plot an evolution of the indicator at a certain location",
161
- "params": ["indicator_column", "location", "model"],
162
- "plot_function": plot_indicator_evolution_at_location,
163
- "sql_query": indicator_per_year_at_location_query,
164
- }
165
-
166
-
167
- def plot_indicator_number_of_days_per_year_at_location(
168
- params: dict,
169
- ) -> Callable[..., Figure]:
170
- """Generates a function to plot the number of days per year for an indicator.
171
-
172
- This function creates a bar chart showing the frequency of certain climate
173
- events (like days above a temperature threshold) per year at a specific location.
174
-
175
- Args:
176
- params (dict): Dictionary containing:
177
- - indicator_column (str): The column name for the indicator
178
- - location (str): The location to plot
179
- - model (str): The climate model to use
180
-
181
- Returns:
182
- Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
183
- """
184
- indicator = params["indicator_column"]
185
- location = params["location"]
186
- indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
187
- unit = INDICATOR_TO_UNIT.get(indicator, "")
188
-
189
- def plot_data(df: pd.DataFrame) -> Figure:
190
- """Generate the figure thanks to the dataframe
191
-
192
- Args:
193
- df (pd.DataFrame): pandas dataframe with the required data
194
-
195
- Returns:
196
- Figure: Plotly figure
197
- """
198
- fig = go.Figure()
199
- if df['model'].nunique() != 1:
200
- df_avg = df.groupby("year", as_index=False)[indicator].mean()
201
-
202
- # Transform to list to avoid pandas encoding
203
- indicators = df_avg[indicator].astype(float).tolist()
204
- years = df_avg["year"].astype(int).tolist()
205
- model_label = "Model Average"
206
-
207
- else:
208
- df_model = df
209
- # Transform to list to avoid pandas encoding
210
- indicators = df_model[indicator].astype(float).tolist()
211
- years = df_model["year"].astype(int).tolist()
212
- model_label = f"Model : {df['model'].unique()[0]}"
213
-
214
-
215
- # Bar plot
216
- fig.add_trace(
217
- go.Bar(
218
- x=years,
219
- y=indicators,
220
- width=0.5,
221
- marker=dict(color="#1f77b4"),
222
- hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
223
- )
224
- )
225
-
226
- fig.update_layout(
227
- title=f"{indicator_label} in {location} ({model_label})",
228
- xaxis_title="Year",
229
- yaxis_title=f"{indicator_label} ({unit})",
230
- yaxis=dict(range=[0, max(indicators)]),
231
- bargap=0.5,
232
- template="plotly_white",
233
- )
234
-
235
- return fig
236
-
237
- return plot_data
238
-
239
-
240
- indicator_number_of_days_per_year_at_location: Plot = {
241
- "name": "Indicator number of days per year at location",
242
- "description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
243
- "params": ["indicator_column", "location", "model"],
244
- "plot_function": plot_indicator_number_of_days_per_year_at_location,
245
- "sql_query": indicator_per_year_at_location_query,
246
- }
247
-
248
-
249
- def plot_distribution_of_indicator_for_given_year(
250
- params: dict,
251
- ) -> Callable[..., Figure]:
252
- """Generates a function to plot the distribution of an indicator for a year.
253
-
254
- This function creates a histogram showing the distribution of a climate
255
- indicator across different locations for a specific year.
256
-
257
- Args:
258
- params (dict): Dictionary containing:
259
- - indicator_column (str): The column name for the indicator
260
- - year (str): The year to plot
261
- - model (str): The climate model to use
262
-
263
- Returns:
264
- Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
265
- """
266
- indicator = params["indicator_column"]
267
- year = params["year"]
268
- indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
269
- unit = INDICATOR_TO_UNIT.get(indicator, "")
270
-
271
- def plot_data(df: pd.DataFrame) -> Figure:
272
- """Generate the figure thanks to the dataframe
273
-
274
- Args:
275
- df (pd.DataFrame): pandas dataframe with the required data
276
-
277
- Returns:
278
- Figure: Plotly figure
279
- """
280
- fig = go.Figure()
281
- if df['model'].nunique() != 1:
282
- df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
283
- indicator
284
- ].mean()
285
-
286
- # Transform to list to avoid pandas encoding
287
- indicators = df_avg[indicator].astype(float).tolist()
288
- model_label = "Model Average"
289
-
290
- else:
291
- df_model = df
292
-
293
- # Transform to list to avoid pandas encoding
294
- indicators = df_model[indicator].astype(float).tolist()
295
- model_label = f"Model : {df['model'].unique()[0]}"
296
-
297
-
298
- fig.add_trace(
299
- go.Histogram(
300
- x=indicators,
301
- opacity=0.8,
302
- histnorm="percent",
303
- marker=dict(color="#1f77b4"),
304
- hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
305
- )
306
- )
307
-
308
- fig.update_layout(
309
- title=f"Distribution of {indicator_label} in {year} ({model_label})",
310
- xaxis_title=f"{indicator_label} ({unit})",
311
- yaxis_title="Frequency (%)",
312
- plot_bgcolor="rgba(0, 0, 0, 0)",
313
- showlegend=False,
314
- )
315
-
316
- return fig
317
-
318
- return plot_data
319
-
320
-
321
- distribution_of_indicator_for_given_year: Plot = {
322
- "name": "Distribution of an indicator for a given year",
323
- "description": "Plot an histogram of the distribution for a given year of the values of an indicator",
324
- "params": ["indicator_column", "model", "year"],
325
- "plot_function": plot_distribution_of_indicator_for_given_year,
326
- "sql_query": indicator_for_given_year_query,
327
- }
328
-
329
-
330
- def plot_map_of_france_of_indicator_for_given_year(
331
- params: dict,
332
- ) -> Callable[..., Figure]:
333
- """Generates a function to plot a map of France for an indicator.
334
-
335
- This function creates a choropleth map of France showing the spatial
336
- distribution of a climate indicator for a specific year.
337
-
338
- Args:
339
- params (dict): Dictionary containing:
340
- - indicator_column (str): The column name for the indicator
341
- - year (str): The year to plot
342
- - model (str): The climate model to use
343
-
344
- Returns:
345
- Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
346
- """
347
- indicator = params["indicator_column"]
348
- year = params["year"]
349
- indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
350
- unit = INDICATOR_TO_UNIT.get(indicator, "")
351
-
352
- def plot_data(df: pd.DataFrame) -> Figure:
353
- fig = go.Figure()
354
- if df['model'].nunique() != 1:
355
- df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
356
- indicator
357
- ].mean()
358
-
359
- indicators = df_avg[indicator].astype(float).tolist()
360
- latitudes = df_avg["latitude"].astype(float).tolist()
361
- longitudes = df_avg["longitude"].astype(float).tolist()
362
- model_label = "Model Average"
363
-
364
- else:
365
- df_model = df
366
-
367
- # Transform to list to avoid pandas encoding
368
- indicators = df_model[indicator].astype(float).tolist()
369
- latitudes = df_model["latitude"].astype(float).tolist()
370
- longitudes = df_model["longitude"].astype(float).tolist()
371
- model_label = f"Model : {df['model'].unique()[0]}"
372
-
373
-
374
- fig.add_trace(
375
- go.Scattermapbox(
376
- lat=latitudes,
377
- lon=longitudes,
378
- mode="markers",
379
- marker=dict(
380
- size=10,
381
- color=indicators, # Color mapped to values
382
- colorscale="Turbo", # Color scale (can be 'Plasma', 'Jet', etc.)
383
- cmin=min(indicators), # Minimum color range
384
- cmax=max(indicators), # Maximum color range
385
- showscale=True, # Show colorbar
386
- ),
387
- text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
388
- hoverinfo="text" # Only show the custom text on hover
389
- )
390
- )
391
-
392
- fig.update_layout(
393
- mapbox_style="open-street-map", # Use OpenStreetMap
394
- mapbox_zoom=3,
395
- mapbox_center={"lat": 46.6, "lon": 2.0},
396
- coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
397
- title=f"{indicator_label} in {year} in France ({model_label}) " # Title
398
- )
399
- return fig
400
-
401
- return plot_data
402
-
403
-
404
- map_of_france_of_indicator_for_given_year: Plot = {
405
- "name": "Map of France of an indicator for a given year",
406
- "description": "Heatmap on the map of France of the values of an in indicator for a given year",
407
- "params": ["indicator_column", "year", "model"],
408
- "plot_function": plot_map_of_france_of_indicator_for_given_year,
409
- "sql_query": indicator_for_given_year_query,
410
- }
411
-
412
-
413
- PLOTS = [
414
- indicator_evolution_at_location,
415
- indicator_number_of_days_per_year_at_location,
416
- distribution_of_indicator_for_given_year,
417
- map_of_france_of_indicator_for_given_year,
418
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/sql_query.py DELETED
@@ -1,114 +0,0 @@
1
- import asyncio
2
- from concurrent.futures import ThreadPoolExecutor
3
- from typing import TypedDict
4
- import duckdb
5
- import pandas as pd
6
-
7
- async def execute_sql_query(sql_query: str) -> pd.DataFrame:
8
- """Executes a SQL query on the DRIAS database and returns the results.
9
-
10
- This function connects to the DuckDB database containing DRIAS climate data
11
- and executes the provided SQL query. It handles the database connection and
12
- returns the results as a pandas DataFrame.
13
-
14
- Args:
15
- sql_query (str): The SQL query to execute
16
-
17
- Returns:
18
- pd.DataFrame: A DataFrame containing the query results
19
-
20
- Raises:
21
- duckdb.Error: If there is an error executing the SQL query
22
- """
23
- def _execute_query():
24
- # Execute the query
25
- con = duckdb.connect()
26
- results = con.sql(sql_query).fetchdf()
27
- # return fetched data
28
- return results
29
-
30
- # Run the query in a thread pool to avoid blocking
31
- loop = asyncio.get_event_loop()
32
- with ThreadPoolExecutor() as executor:
33
- return await loop.run_in_executor(executor, _execute_query)
34
-
35
-
36
- class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
37
- """Parameters for querying an indicator's values over time at a location.
38
-
39
- This class defines the parameters needed to query climate indicator data
40
- for a specific location over multiple years.
41
-
42
- Attributes:
43
- indicator_column (str): The column name for the climate indicator
44
- latitude (str): The latitude coordinate of the location
45
- longitude (str): The longitude coordinate of the location
46
- model (str): The climate model to use (optional)
47
- """
48
- indicator_column: str
49
- latitude: str
50
- longitude: str
51
- model: str
52
-
53
-
54
- def indicator_per_year_at_location_query(
55
- table: str, params: IndicatorPerYearAtLocationQueryParams
56
- ) -> str:
57
- """SQL Query to get the evolution of an indicator per year at a certain location
58
-
59
- Args:
60
- table (str): sql table of the indicator
61
- params (IndicatorPerYearAtLocationQueryParams) : dictionary with the required params for the query
62
-
63
- Returns:
64
- str: the sql query
65
- """
66
- indicator_column = params.get("indicator_column")
67
- latitude = params.get("latitude")
68
- longitude = params.get("longitude")
69
-
70
- if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
71
- return ""
72
-
73
- table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
74
-
75
- sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
76
-
77
- return sql_query
78
-
79
- class IndicatorForGivenYearQueryParams(TypedDict, total=False):
80
- """Parameters for querying an indicator's values across locations for a year.
81
-
82
- This class defines the parameters needed to query climate indicator data
83
- across different locations for a specific year.
84
-
85
- Attributes:
86
- indicator_column (str): The column name for the climate indicator
87
- year (str): The year to query
88
- model (str): The climate model to use (optional)
89
- """
90
- indicator_column: str
91
- year: str
92
- model: str
93
-
94
- def indicator_for_given_year_query(
95
- table:str, params: IndicatorForGivenYearQueryParams
96
- ) -> str:
97
- """SQL Query to get the values of an indicator with their latitudes, longitudes and models for a given year
98
-
99
- Args:
100
- table (str): sql table of the indicator
101
- params (IndicatorForGivenYearQueryParams): dictionarry with the required params for the query
102
-
103
- Returns:
104
- str: the sql query
105
- """
106
- indicator_column = params.get("indicator_column")
107
- year = params.get('year')
108
- if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
109
- return ""
110
-
111
- table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
112
-
113
- sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
114
- return sql_query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/talk_to_drias.py DELETED
@@ -1,317 +0,0 @@
1
- import os
2
-
3
- from typing import Any, Callable, TypedDict, Optional
4
- from numpy import sort
5
- import pandas as pd
6
- import asyncio
7
- from plotly.graph_objects import Figure
8
- from climateqa.engine.llm import get_llm
9
- from climateqa.engine.talk_to_data import sql_query
10
- from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
11
- from climateqa.engine.talk_to_data.plot import PLOTS, Plot
12
- from climateqa.engine.talk_to_data.sql_query import execute_sql_query
13
- from climateqa.engine.talk_to_data.utils import (
14
- detect_relevant_plots,
15
- detect_year_with_openai,
16
- loc2coords,
17
- detect_location_with_openai,
18
- nearestNeighbourSQL,
19
- detect_relevant_tables,
20
- )
21
-
22
-
23
- ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
24
-
25
- class TableState(TypedDict):
26
- """Represents the state of a table in the DRIAS workflow.
27
-
28
- This class defines the structure for tracking the state of a table during the
29
- data processing workflow, including its name, parameters, SQL query, and results.
30
-
31
- Attributes:
32
- table_name (str): The name of the table in the database
33
- params (dict[str, Any]): Parameters used for querying the table
34
- sql_query (str, optional): The SQL query used to fetch data
35
- dataframe (pd.DataFrame | None, optional): The resulting data
36
- figure (Callable[..., Figure], optional): Function to generate visualization
37
- status (str): The current status of the table processing ('OK' or 'ERROR')
38
- """
39
- table_name: str
40
- params: dict[str, Any]
41
- sql_query: Optional[str]
42
- dataframe: Optional[pd.DataFrame | None]
43
- figure: Optional[Callable[..., Figure]]
44
- status: str
45
-
46
- class PlotState(TypedDict):
47
- """Represents the state of a plot in the DRIAS workflow.
48
-
49
- This class defines the structure for tracking the state of a plot during the
50
- data processing workflow, including its name and associated tables.
51
-
52
- Attributes:
53
- plot_name (str): The name of the plot
54
- tables (list[str]): List of tables used in the plot
55
- table_states (dict[str, TableState]): States of the tables used in the plot
56
- """
57
- plot_name: str
58
- tables: list[str]
59
- table_states: dict[str, TableState]
60
-
61
- class State(TypedDict):
62
- user_input: str
63
- plots: list[str]
64
- plot_states: dict[str, PlotState]
65
- error: Optional[str]
66
-
67
- async def find_relevant_plots(state: State, llm) -> list[str]:
68
- print("---- Find relevant plots ----")
69
- relevant_plots = await detect_relevant_plots(state['user_input'], llm)
70
- return relevant_plots
71
-
72
- async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
73
- print(f"---- Find relevant tables for {plot['name']} ----")
74
- relevant_tables = await detect_relevant_tables(state['user_input'], plot, llm)
75
- return relevant_tables
76
-
77
- async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
78
- """Perform the good method to retrieve the desired parameter
79
-
80
- Args:
81
- state (State): state of the workflow
82
- param_name (str): name of the desired parameter
83
- table (str): name of the table
84
-
85
- Returns:
86
- dict[str, Any] | None:
87
- """
88
- if param_name == 'location':
89
- location = await find_location(state['user_input'], table)
90
- return location
91
- if param_name == 'year':
92
- year = await find_year(state['user_input'])
93
- return {'year': year}
94
- return None
95
-
96
- class Location(TypedDict):
97
- location: str
98
- latitude: Optional[str]
99
- longitude: Optional[str]
100
-
101
- async def find_location(user_input: str, table: str) -> Location:
102
- print(f"---- Find location in table {table} ----")
103
- location = await detect_location_with_openai(user_input)
104
- output: Location = {'location' : location}
105
- if location:
106
- coords = loc2coords(location)
107
- neighbour = nearestNeighbourSQL(coords, table)
108
- output.update({
109
- "latitude": neighbour[0],
110
- "longitude": neighbour[1],
111
- })
112
- return output
113
-
114
- async def find_year(user_input: str) -> str:
115
- """Extracts year information from user input using LLM.
116
-
117
- This function uses an LLM to identify and extract year information from the
118
- user's query, which is used to filter data in subsequent queries.
119
-
120
- Args:
121
- user_input (str): The user's query text
122
-
123
- Returns:
124
- str: The extracted year, or empty string if no year found
125
- """
126
- print(f"---- Find year ---")
127
- year = await detect_year_with_openai(user_input)
128
- return year
129
-
130
- def find_indicator_column(table: str) -> str:
131
- """Retrieves the name of the indicator column within a table.
132
-
133
- This function maps table names to their corresponding indicator columns
134
- using the predefined mapping in INDICATOR_COLUMNS_PER_TABLE.
135
-
136
- Args:
137
- table (str): Name of the table in the database
138
-
139
- Returns:
140
- str: Name of the indicator column for the specified table
141
-
142
- Raises:
143
- KeyError: If the table name is not found in the mapping
144
- """
145
- print(f"---- Find indicator column in table {table} ----")
146
- return INDICATOR_COLUMNS_PER_TABLE[table]
147
-
148
-
149
- async def process_table(
150
- table: str,
151
- params: dict[str, Any],
152
- plot: Plot,
153
- ) -> TableState:
154
- """Processes a table to extract relevant data and generate visualizations.
155
-
156
- This function retrieves the SQL query for the specified table, executes it,
157
- and generates a visualization based on the results.
158
-
159
- Args:
160
- table (str): The name of the table to process
161
- params (dict[str, Any]): Parameters used for querying the table
162
- plot (Plot): The plot object containing SQL query and visualization function
163
-
164
- Returns:
165
- TableState: The state of the processed table
166
- """
167
- table_state: TableState = {
168
- 'table_name': table,
169
- 'params': params.copy(),
170
- 'status': 'OK',
171
- 'dataframe': None,
172
- 'sql_query': None,
173
- 'figure': None
174
- }
175
-
176
- table_state['params']['indicator_column'] = find_indicator_column(table)
177
- sql_query = plot['sql_query'](table, table_state['params'])
178
-
179
- if sql_query == "":
180
- table_state['status'] = 'ERROR'
181
- return table_state
182
- table_state['sql_query'] = sql_query
183
- df = await execute_sql_query(sql_query)
184
-
185
- table_state['dataframe'] = df
186
- table_state['figure'] = plot['plot_function'](table_state['params'])
187
-
188
- return table_state
189
-
190
- async def drias_workflow(user_input: str) -> State:
191
- """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated
192
-
193
- Args:
194
- user_input (str): initial user input
195
-
196
- Returns:
197
- State: Final state with all the results
198
- """
199
- state: State = {
200
- 'user_input': user_input,
201
- 'plots': [],
202
- 'plot_states': {},
203
- 'error': ''
204
- }
205
-
206
- llm = get_llm(provider="openai")
207
-
208
- plots = await find_relevant_plots(state, llm)
209
-
210
- state['plots'] = plots
211
-
212
- if len(state['plots']) < 1:
213
- state['error'] = 'There is no plot to answer to the question'
214
- return state
215
-
216
- have_relevant_table = False
217
- have_sql_query = False
218
- have_dataframe = False
219
-
220
- for plot_name in state['plots']:
221
-
222
- plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
223
- if plot is None:
224
- continue
225
-
226
- plot_state: PlotState = {
227
- 'plot_name': plot_name,
228
- 'tables': [],
229
- 'table_states': {}
230
- }
231
-
232
- plot_state['plot_name'] = plot_name
233
-
234
- relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)
235
-
236
- if len(relevant_tables) > 0 :
237
- have_relevant_table = True
238
-
239
- plot_state['tables'] = relevant_tables
240
-
241
- params = {}
242
- for param_name in plot['params']:
243
- param = await find_param(state, param_name, relevant_tables[0])
244
- if param:
245
- params.update(param)
246
-
247
- tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
248
- results = await asyncio.gather(*tasks)
249
-
250
- # Store results back in plot_state
251
- have_dataframe = False
252
- have_sql_query = False
253
- for table_state in results:
254
- if table_state['sql_query']:
255
- have_sql_query = True
256
- if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
257
- have_dataframe = True
258
- plot_state['table_states'][table_state['table_name']] = table_state
259
-
260
- state['plot_states'][plot_name] = plot_state
261
-
262
- if not have_relevant_table:
263
- state['error'] = "There is no relevant table in our database to answer your question"
264
- elif not have_sql_query:
265
- state['error'] = "There is no relevant sql query on our database that can help to answer your question"
266
- elif not have_dataframe:
267
- state['error'] = "There is no data in our table that can answer to your question"
268
-
269
- return state
270
-
271
- # def make_write_query_node():
272
-
273
- # def write_query(state):
274
- # print("---- Write query ----")
275
- # for table in state["tables"]:
276
- # sql_query = QUERIES[state[table]['query_type']](
277
- # table=table,
278
- # indicator_column=state[table]["columns"],
279
- # longitude=state[table]["longitude"],
280
- # latitude=state[table]["latitude"],
281
- # )
282
- # state[table].update({"sql_query": sql_query})
283
-
284
- # return state
285
-
286
- # return write_query
287
-
288
- # def make_fetch_data_node(db_path):
289
-
290
- # def fetch_data(state):
291
- # print("---- Fetch data ----")
292
- # for table in state["tables"]:
293
- # results = execute_sql_query(db_path, state[table]['sql_query'])
294
- # state[table].update(results)
295
-
296
- # return state
297
-
298
- # return fetch_data
299
-
300
-
301
-
302
- ## V2
303
-
304
-
305
- # def make_fetch_data_node(db_path: str, llm):
306
- # def fetch_data(state):
307
- # print("---- Fetch data ----")
308
- # db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
309
- # output = {}
310
- # sql_query = write_sql_query(state["query"], db, state["tables"], llm)
311
- # # TO DO : Add query checker
312
- # print(f"SQL query : {sql_query}")
313
- # output["sql_query"] = sql_query
314
- # output.update(fetch_data_from_sql_query(db_path, sql_query))
315
- # return output
316
-
317
- # return fetch_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/utils.py DELETED
@@ -1,281 +0,0 @@
1
- import re
2
- from typing import Annotated, TypedDict
3
- import duckdb
4
- from geopy.geocoders import Nominatim
5
- import ast
6
- from climateqa.engine.llm import get_llm
7
- from climateqa.engine.talk_to_data.config import DRIAS_TABLES
8
- from climateqa.engine.talk_to_data.plot import PLOTS, Plot
9
- from langchain_core.prompts import ChatPromptTemplate
10
-
11
-
12
- async def detect_location_with_openai(sentence):
13
- """
14
- Detects locations in a sentence using OpenAI's API via LangChain.
15
- """
16
- llm = get_llm()
17
-
18
- prompt = f"""
19
- Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
20
- Return the result as a Python list. If no locations are mentioned, return an empty list.
21
-
22
- Sentence: "{sentence}"
23
- """
24
-
25
- response = await llm.ainvoke(prompt)
26
- location_list = ast.literal_eval(response.content.strip("```python\n").strip())
27
- if location_list:
28
- return location_list[0]
29
- else:
30
- return ""
31
-
32
- class ArrayOutput(TypedDict):
33
- """Represents the output of a function that returns an array.
34
-
35
- This class is used to type-hint functions that return arrays,
36
- ensuring consistent return types across the codebase.
37
-
38
- Attributes:
39
- array (str): A syntactically valid Python array string
40
- """
41
- array: Annotated[str, "Syntactically valid python array."]
42
-
43
- async def detect_year_with_openai(sentence: str) -> str:
44
- """
45
- Detects years in a sentence using OpenAI's API via LangChain.
46
- """
47
- llm = get_llm()
48
-
49
- prompt = """
50
- Extract all years mentioned in the following sentence.
51
- Return the result as a Python list. If no year are mentioned, return an empty list.
52
-
53
- Sentence: "{sentence}"
54
- """
55
-
56
- prompt = ChatPromptTemplate.from_template(prompt)
57
- structured_llm = llm.with_structured_output(ArrayOutput)
58
- chain = prompt | structured_llm
59
- response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
60
- years_list = eval(response['array'])
61
- if len(years_list) > 0:
62
- return years_list[0]
63
- else:
64
- return ""
65
-
66
-
67
- def detectTable(sql_query: str) -> list[str]:
68
- """Extracts table names from a SQL query.
69
-
70
- This function uses regular expressions to find all table names
71
- referenced in a SQL query's FROM clause.
72
-
73
- Args:
74
- sql_query (str): The SQL query to analyze
75
-
76
- Returns:
77
- list[str]: A list of table names found in the query
78
-
79
- Example:
80
- >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
81
- ['temperature_data']
82
- """
83
- pattern = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
84
- matches = re.findall(pattern, sql_query)
85
- return matches
86
-
87
-
88
- def loc2coords(location: str) -> tuple[float, float]:
89
- """Converts a location name to geographic coordinates.
90
-
91
- This function uses the Nominatim geocoding service to convert
92
- a location name (e.g., city name) to its latitude and longitude.
93
-
94
- Args:
95
- location (str): The name of the location to geocode
96
-
97
- Returns:
98
- tuple[float, float]: A tuple containing (latitude, longitude)
99
-
100
- Raises:
101
- AttributeError: If the location cannot be found
102
- """
103
- geolocator = Nominatim(user_agent="city_to_latlong")
104
- coords = geolocator.geocode(location)
105
- return (coords.latitude, coords.longitude)
106
-
107
-
108
- def coords2loc(coords: tuple[float, float]) -> str:
109
- """Converts geographic coordinates to a location name.
110
-
111
- This function uses the Nominatim reverse geocoding service to convert
112
- latitude and longitude coordinates to a human-readable location name.
113
-
114
- Args:
115
- coords (tuple[float, float]): A tuple containing (latitude, longitude)
116
-
117
- Returns:
118
- str: The address of the location, or "Unknown Location" if not found
119
-
120
- Example:
121
- >>> coords2loc((48.8566, 2.3522))
122
- 'Paris, France'
123
- """
124
- geolocator = Nominatim(user_agent="coords_to_city")
125
- try:
126
- location = geolocator.reverse(coords)
127
- return location.address
128
- except Exception as e:
129
- print(f"Error: {e}")
130
- return "Unknown Location"
131
-
132
-
133
- def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
134
- long = round(location[1], 3)
135
- lat = round(location[0], 3)
136
-
137
- table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
138
-
139
- results = duckdb.sql(
140
- f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}"
141
- ).fetchdf()
142
-
143
- if len(results) == 0:
144
- return "", ""
145
- # cursor.execute(f"SELECT latitude, longitude FROM {table} WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} AND longitude BETWEEN {long - 0.3} AND {long + 0.3}")
146
- return results['latitude'].iloc[0], results['longitude'].iloc[0]
147
-
148
-
149
- async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
150
- """Identifies relevant tables for a plot based on user input.
151
-
152
- This function uses an LLM to analyze the user's question and the plot
153
- description to determine which tables in the DRIAS database would be
154
- most relevant for generating the requested visualization.
155
-
156
- Args:
157
- user_question (str): The user's question about climate data
158
- plot (Plot): The plot configuration object
159
- llm: The language model instance to use for analysis
160
-
161
- Returns:
162
- list[str]: A list of table names that are relevant for the plot
163
-
164
- Example:
165
- >>> detect_relevant_tables(
166
- ... "What will the temperature be like in Paris?",
167
- ... indicator_evolution_at_location,
168
- ... llm
169
- ... )
170
- ['mean_annual_temperature', 'mean_summer_temperature']
171
- """
172
- # Get all table names
173
- table_names_list = DRIAS_TABLES
174
-
175
- prompt = (
176
- f"You are helping to build a plot following this description : {plot['description']}."
177
- f"You are given a list of tables and a user question."
178
- f"Based on the description of the plot, which table are appropriate for that kind of plot."
179
- f"Write the 3 most relevant tables to use. Answer only a python list of table name."
180
- f"### List of tables : {table_names_list}"
181
- f"### User question : {user_question}"
182
- f"### List of table name : "
183
- )
184
-
185
- table_names = ast.literal_eval(
186
- (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
187
- )
188
- return table_names
189
-
190
-
191
- def replace_coordonates(coords, query, coords_tables):
192
- n = query.count(str(coords[0]))
193
-
194
- for i in range(n):
195
- query = query.replace(str(coords[0]), str(coords_tables[i][0]), 1)
196
- query = query.replace(str(coords[1]), str(coords_tables[i][1]), 1)
197
- return query
198
-
199
-
200
- async def detect_relevant_plots(user_question: str, llm):
201
- plots_description = ""
202
- for plot in PLOTS:
203
- plots_description += "Name: " + plot["name"]
204
- plots_description += " - Description: " + plot["description"] + "\n"
205
-
206
- prompt = (
207
- f"You are helping to answer a quesiton with insightful visualizations."
208
- f"You are given an user question and a list of plots with their name and description."
209
- f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
210
- f"Write the most relevant tables to use. Answer only a python list of plot name."
211
- f"### Descriptions of the plots : {plots_description}"
212
- f"### User question : {user_question}"
213
- f"### Name of the plot : "
214
- )
215
- # prompt = (
216
- # f"You are helping to answer a question with insightful visualizations. "
217
- # f"Given a list of plots with their name and description: "
218
- # f"{plots_description} "
219
- # f"The user question is: {user_question}. "
220
- # f"Choose the most relevant plots to answer the question. "
221
- # f"The answer must be a Python list with the names of the relevant plots, and nothing else. "
222
- # f"Ensure the response is in the exact format: ['PlotName1', 'PlotName2']."
223
- # )
224
-
225
- plot_names = ast.literal_eval(
226
- (await llm.ainvoke(prompt)).content.strip("```python\n").strip()
227
- )
228
- return plot_names
229
-
230
-
231
- # Next Version
232
- # class QueryOutput(TypedDict):
233
- # """Generated SQL query."""
234
-
235
- # query: Annotated[str, ..., "Syntactically valid SQL query."]
236
-
237
-
238
- # class PlotlyCodeOutput(TypedDict):
239
- # """Generated Plotly code"""
240
-
241
- # code: Annotated[str, ..., "Synatically valid Plotly python code."]
242
- # def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
243
- # """Generate SQL query to fetch information."""
244
- # prompt_params = {
245
- # "dialect": db.dialect,
246
- # "table_info": db.get_table_info(),
247
- # "input": user_input,
248
- # "relevant_tables": relevant_tables,
249
- # "model": "ALADIN63_CNRM-CM5",
250
- # }
251
-
252
- # prompt = ChatPromptTemplate.from_template(query_prompt_template)
253
- # structured_llm = llm.with_structured_output(QueryOutput)
254
- # chain = prompt | structured_llm
255
- # result = chain.invoke(prompt_params)
256
-
257
- # return result["query"]
258
-
259
-
260
- # def fetch_data_from_sql_query(db: str, sql_query: str):
261
- # conn = sqlite3.connect(db)
262
- # cursor = conn.cursor()
263
- # cursor.execute(sql_query)
264
- # column_names = [desc[0] for desc in cursor.description]
265
- # values = cursor.fetchall()
266
- # return {"column_names": column_names, "data": values}
267
-
268
-
269
- # def generate_chart_code(user_input: str, sql_query: list[str], llm):
270
- # """ "Generate plotly python code for the chart based on the sql query and the user question"""
271
-
272
- # class PlotlyCodeOutput(TypedDict):
273
- # """Generated Plotly code"""
274
-
275
- # code: Annotated[str, ..., "Synatically valid Plotly python code."]
276
-
277
- # prompt = ChatPromptTemplate.from_template(plot_prompt_template)
278
- # structured_llm = llm.with_structured_output(PlotlyCodeOutput)
279
- # chain = prompt | structured_llm
280
- # result = chain.invoke({"input": user_input, "sql_query": sql_query})
281
- # return result["code"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/vanna_class.py DELETED
@@ -1,325 +0,0 @@
1
- from vanna.base import VannaBase
2
- from pinecone import Pinecone
3
- from climateqa.engine.embeddings import get_embeddings_function
4
- import pandas as pd
5
- import hashlib
6
-
7
- class MyCustomVectorDB(VannaBase):
8
-
9
- """
10
- VectorDB class for storing and retrieving vectors from Pinecone.
11
-
12
- args :
13
- config (dict) : Configuration dictionary containing the Pinecone API key and the index name :
14
- - pc_api_key (str) : Pinecone API key
15
- - index_name (str) : Pinecone index name
16
- - top_k (int) : Number of top results to return (default = 2)
17
-
18
- """
19
-
20
- def __init__(self,config):
21
- super().__init__(config = config)
22
- try :
23
- self.api_key = config.get('pc_api_key')
24
- self.index_name = config.get('index_name')
25
- except :
26
- raise Exception("Please provide the Pinecone API key and the index name")
27
-
28
- self.pc = Pinecone(api_key = self.api_key)
29
- self.index = self.pc.Index(self.index_name)
30
- self.top_k = config.get('top_k', 2)
31
- self.embeddings = get_embeddings_function()
32
-
33
-
34
- def check_embedding(self, id, namespace):
35
- fetched = self.index.fetch(ids = [id], namespace = namespace)
36
- if fetched['vectors'] == {}:
37
- return False
38
- return True
39
-
40
- def generate_hash_id(self, data: str) -> str:
41
- """
42
- Generate a unique hash ID for the given data.
43
-
44
- Args:
45
- data (str): The input data to hash (e.g., a concatenated string of user attributes).
46
-
47
- Returns:
48
- str: A unique hash ID as a hexadecimal string.
49
- """
50
-
51
- data_bytes = data.encode('utf-8')
52
- hash_object = hashlib.sha256(data_bytes)
53
- hash_id = hash_object.hexdigest()
54
-
55
- return hash_id
56
-
57
- def add_ddl(self, ddl: str, **kwargs) -> str:
58
- id = self.generate_hash_id(ddl) + '_ddl'
59
-
60
- if self.check_embedding(id, 'ddl'):
61
- print(f"DDL having id {id} already exists")
62
- return id
63
-
64
- self.index.upsert(
65
- vectors = [(id, self.embeddings.embed_query(ddl), {'ddl': ddl})],
66
- namespace = 'ddl'
67
- )
68
-
69
- return id
70
-
71
- def add_documentation(self, doc: str, **kwargs) -> str:
72
- id = self.generate_hash_id(doc) + '_doc'
73
-
74
- if self.check_embedding(id, 'documentation'):
75
- print(f"Documentation having id {id} already exists")
76
- return id
77
-
78
- self.index.upsert(
79
- vectors = [(id, self.embeddings.embed_query(doc), {'doc': doc})],
80
- namespace = 'documentation'
81
- )
82
-
83
- return id
84
-
85
- def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
86
- id = self.generate_hash_id(question) + '_sql'
87
-
88
- if self.check_embedding(id, 'question_sql'):
89
- print(f"Question-SQL pair having id {id} already exists")
90
- return id
91
-
92
- self.index.upsert(
93
- vectors = [(id, self.embeddings.embed_query(question + sql), {'question': question, 'sql': sql})],
94
- namespace = 'question_sql'
95
- )
96
-
97
- return id
98
-
99
- def get_related_ddl(self, question: str, **kwargs) -> list:
100
- res = self.index.query(
101
- vector=self.embeddings.embed_query(question),
102
- top_k=self.top_k,
103
- namespace='ddl',
104
- include_metadata=True
105
- )
106
-
107
- return [match['metadata']['ddl'] for match in res['matches']]
108
-
109
- def get_related_documentation(self, question: str, **kwargs) -> list:
110
- res = self.index.query(
111
- vector=self.embeddings.embed_query(question),
112
- top_k=self.top_k,
113
- namespace='documentation',
114
- include_metadata=True
115
- )
116
-
117
- return [match['metadata']['doc'] for match in res['matches']]
118
-
119
- def get_similar_question_sql(self, question: str, **kwargs) -> list:
120
- res = self.index.query(
121
- vector=self.embeddings.embed_query(question),
122
- top_k=self.top_k,
123
- namespace='question_sql',
124
- include_metadata=True
125
- )
126
-
127
- return [(match['metadata']['question'], match['metadata']['sql']) for match in res['matches']]
128
-
129
- def get_training_data(self, **kwargs) -> pd.DataFrame:
130
-
131
- list_of_data = []
132
-
133
- namespaces = ['ddl', 'documentation', 'question_sql']
134
-
135
- for namespace in namespaces:
136
-
137
- data = self.index.query(
138
- top_k=10000,
139
- namespace=namespace,
140
- include_metadata=True,
141
- include_values=False
142
- )
143
-
144
- for match in data['matches']:
145
- list_of_data.append(match['metadata'])
146
-
147
- return pd.DataFrame(list_of_data)
148
-
149
-
150
-
151
- def remove_training_data(self, id: str, **kwargs) -> bool:
152
- if id.endswith("_ddl"):
153
- self.Index.delete(ids=[id], namespace="_ddl")
154
- return True
155
- if id.endswith("_sql"):
156
- self.index.delete(ids=[id], namespace="_sql")
157
- return True
158
-
159
- if id.endswith("_doc"):
160
- self.Index.delete(ids=[id], namespace="_doc")
161
- return True
162
-
163
- return False
164
-
165
- def generate_embedding(self, text, **kwargs):
166
- # Implement the method here
167
- pass
168
-
169
-
170
- def get_sql_prompt(
171
- self,
172
- initial_prompt : str,
173
- question: str,
174
- question_sql_list: list,
175
- ddl_list: list,
176
- doc_list: list,
177
- **kwargs,
178
- ):
179
- """
180
- Example:
181
- ```python
182
- vn.get_sql_prompt(
183
- question="What are the top 10 customers by sales?",
184
- question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
185
- ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
186
- doc_list=["The customers table contains information about customers and their sales."],
187
- )
188
-
189
- ```
190
-
191
- This method is used to generate a prompt for the LLM to generate SQL.
192
-
193
- Args:
194
- question (str): The question to generate SQL for.
195
- question_sql_list (list): A list of questions and their corresponding SQL statements.
196
- ddl_list (list): A list of DDL statements.
197
- doc_list (list): A list of documentation.
198
-
199
- Returns:
200
- any: The prompt for the LLM to generate SQL.
201
- """
202
-
203
- if initial_prompt is None:
204
- initial_prompt = f"You are a {self.dialect} expert. " + \
205
- "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
206
-
207
- initial_prompt = self.add_ddl_to_prompt(
208
- initial_prompt, ddl_list, max_tokens=self.max_tokens
209
- )
210
-
211
- if self.static_documentation != "":
212
- doc_list.append(self.static_documentation)
213
-
214
- initial_prompt = self.add_documentation_to_prompt(
215
- initial_prompt, doc_list, max_tokens=self.max_tokens
216
- )
217
-
218
- # initial_prompt = self.add_sql_to_prompt(
219
- # initial_prompt, question_sql_list, max_tokens=self.max_tokens
220
- # )
221
-
222
-
223
- initial_prompt += (
224
- "===Response Guidelines \n"
225
- "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
226
- "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
227
- "3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
228
- "4. Please use the most relevant table(s). \n"
229
- "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
230
- f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
231
- f"7. Add a description of the table in the result of the sql query, if relevant. \n"
232
- "8 Make sure to include the relevant KPI in the SQL query. The query should return impactfull data \n"
233
- # f"8. If a set of latitude,longitude is provided, make a intermediate query to find the nearest value in the table and replace the coordinates in the sql query. \n"
234
- # "7. Add a description of the table in the result of the sql query."
235
- # "7. If the question is about a specific latitude, longitude, query an interval of 0.3 and keep only the first set of coordinate. \n"
236
- # "7. Table names should be included in the result of the sql query. Use for example Mean_winter_temperature AS table_name in the query \n"
237
- )
238
-
239
-
240
- message_log = [self.system_message(initial_prompt)]
241
-
242
- for example in question_sql_list:
243
- if example is None:
244
- print("example is None")
245
- else:
246
- if example is not None and "question" in example and "sql" in example:
247
- message_log.append(self.user_message(example["question"]))
248
- message_log.append(self.assistant_message(example["sql"]))
249
-
250
- message_log.append(self.user_message(question))
251
-
252
- return message_log
253
-
254
-
255
- # def get_sql_prompt(
256
- # self,
257
- # initial_prompt : str,
258
- # question: str,
259
- # question_sql_list: list,
260
- # ddl_list: list,
261
- # doc_list: list,
262
- # **kwargs,
263
- # ):
264
- # """
265
- # Example:
266
- # ```python
267
- # vn.get_sql_prompt(
268
- # question="What are the top 10 customers by sales?",
269
- # question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
270
- # ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
271
- # doc_list=["The customers table contains information about customers and their sales."],
272
- # )
273
-
274
- # ```
275
-
276
- # This method is used to generate a prompt for the LLM to generate SQL.
277
-
278
- # Args:
279
- # question (str): The question to generate SQL for.
280
- # question_sql_list (list): A list of questions and their corresponding SQL statements.
281
- # ddl_list (list): A list of DDL statements.
282
- # doc_list (list): A list of documentation.
283
-
284
- # Returns:
285
- # any: The prompt for the LLM to generate SQL.
286
- # """
287
-
288
- # if initial_prompt is None:
289
- # initial_prompt = f"You are a {self.dialect} expert. " + \
290
- # "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
291
-
292
- # initial_prompt = self.add_ddl_to_prompt(
293
- # initial_prompt, ddl_list, max_tokens=self.max_tokens
294
- # )
295
-
296
- # if self.static_documentation != "":
297
- # doc_list.append(self.static_documentation)
298
-
299
- # initial_prompt = self.add_documentation_to_prompt(
300
- # initial_prompt, doc_list, max_tokens=self.max_tokens
301
- # )
302
-
303
- # initial_prompt += (
304
- # "===Response Guidelines \n"
305
- # "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
306
- # "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
307
- # "3. If the provided context is insufficient, please explain why it can't be generated. \n"
308
- # "4. Please use the most relevant table(s). \n"
309
- # "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
310
- # f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
311
- # )
312
-
313
- # message_log = [self.system_message(initial_prompt)]
314
-
315
- # for example in question_sql_list:
316
- # if example is None:
317
- # print("example is None")
318
- # else:
319
- # if example is not None and "question" in example and "sql" in example:
320
- # message_log.append(self.user_message(example["question"]))
321
- # message_log.append(self.assistant_message(example["sql"]))
322
-
323
- # message_log.append(self.user_message(question))
324
-
325
- # return message_log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
climateqa/engine/talk_to_data/workflow/drias.py CHANGED
@@ -125,11 +125,16 @@ async def drias_workflow(user_input: str) -> State:
125
  'plot': plot,
126
  'status': 'OK'
127
  }
 
 
 
 
 
 
 
128
 
129
- # Gather all required parameters
130
  params = {}
131
- for param_name in DRIAS_PLOT_PARAMETERS:
132
- param = await find_param(state, param_name, mode='DRIAS')
133
  if param:
134
  params.update(param)
135
 
 
125
  'plot': plot,
126
  'status': 'OK'
127
  }
128
+
129
+ # Gather all required parameters in parallel
130
+ param_tasks = [
131
+ find_param(state, param_name, mode='DRIAS')
132
+ for param_name in DRIAS_PLOT_PARAMETERS
133
+ ]
134
+ param_results = await asyncio.gather(*param_tasks)
135
 
 
136
  params = {}
137
+ for param in param_results:
 
138
  if param:
139
  params.update(param)
140
 
climateqa/engine/talk_to_data/workflow/ipcc.py CHANGED
@@ -125,12 +125,17 @@ async def ipcc_workflow(user_input: str) -> State:
125
  }
126
 
127
  # Gather all required parameters
 
 
 
 
 
 
128
  params = {}
129
- for param_name in IPCC_PLOT_PARAMETERS:
130
- param = await find_param(state, param_name, mode='IPCC')
131
  if param:
132
  params.update(param)
133
-
134
  # Process all outputs in parallel using process_output
135
  tasks = [
136
  process_output(output_title, output['table'], output['plot'], params.copy())
@@ -152,10 +157,18 @@ async def ipcc_workflow(user_input: str) -> State:
152
 
153
  # Set error messages if needed
154
  if not errors['have_relevant_table']:
155
- state['error'] = "There is no relevant table in our database to answer your question"
 
 
 
156
  elif not errors['have_sql_query']:
157
- state['error'] = "There is no relevant sql query on our database that can help to answer your question"
 
 
 
158
  elif not errors['have_dataframe']:
159
- state['error'] = "There is no data in our table that can answer to your question"
160
-
 
 
161
  return state
 
125
  }
126
 
127
  # Gather all required parameters
128
+ param_tasks = [
129
+ find_param(state, param_name, mode='IPCC')
130
+ for param_name in IPCC_PLOT_PARAMETERS
131
+ ]
132
+ param_results = await asyncio.gather(*param_tasks)
133
+
134
  params = {}
135
+ for param in param_results:
 
136
  if param:
137
  params.update(param)
138
+
139
  # Process all outputs in parallel using process_output
140
  tasks = [
141
  process_output(output_title, output['table'], output['plot'], params.copy())
 
157
 
158
  # Set error messages if needed
159
  if not errors['have_relevant_table']:
160
+ state['error'] = (
161
+ "Sorry, I couldn't find any relevant table in our database to answer your question.\n"
162
+ "Try asking about a different climate indicator like temperature or precipitation."
163
+ )
164
  elif not errors['have_sql_query']:
165
+ state['error'] = (
166
+ "Sorry, I couldn't generate a relevant SQL query to answer your question.\n"
167
+ "Try rephrasing your question to focus on a specific location, a year, or a month."
168
+ )
169
  elif not errors['have_dataframe']:
170
+ state['error'] = (
171
+ "Sorry, there is no data in our tables that can answer your question.\n"
172
+ "Try asking about a more common location, or a different year."
173
+ )
174
  return state
climateqa/engine/vectorstore.py CHANGED
@@ -1,11 +1,11 @@
1
- # Pinecone
2
- # More info at https://docs.pinecone.io/docs/langchain
3
- # And https://python.langchain.com/docs/integrations/vectorstores/pinecone
4
  import os
5
- from pinecone import Pinecone
6
- from langchain_community.vectorstores import Pinecone as PineconeVectorstore
7
 
8
- # LOAD ENVIRONMENT VARIABLES
 
 
 
 
9
  try:
10
  from dotenv import load_dotenv
11
  load_dotenv()
@@ -13,44 +13,136 @@ except:
13
  pass
14
 
15
 
16
-
17
-
18
- def get_pinecone_vectorstore(embeddings,text_key = "content", index_name = os.getenv("PINECONE_API_INDEX")):
19
-
20
- # # initialize pinecone
21
- # pinecone.init(
22
- # api_key=os.getenv("PINECONE_API_KEY"), # find at app.pinecone.io
23
- # environment=os.getenv("PINECONE_API_ENVIRONMENT"), # next to api key in console
24
- # )
25
-
26
- # index_name = os.getenv("PINECONE_API_INDEX")
27
- # vectorstore = Pinecone.from_existing_index(index_name, embeddings,text_key = text_key)
28
-
29
- # return vectorstore
30
-
31
- pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
32
- index = pc.Index(index_name)
33
-
34
- vectorstore = PineconeVectorstore(
35
- index, embeddings, text_key,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
- return vectorstore
38
-
39
-
40
-
41
- # def get_pinecone_retriever(vectorstore,k = 10,namespace = "vectors",sources = ["IPBES","IPCC"]):
42
-
43
- # assert isinstance(sources,list)
44
-
45
- # # Check if all elements in the list are either IPCC or IPBES
46
- # filter = {
47
- # "source": { "$in":sources},
48
- # }
49
-
50
- # retriever = vectorstore.as_retriever(search_kwargs={
51
- # "k": k,
52
- # "namespace":"vectors",
53
- # "filter":filter
54
- # })
55
 
56
- # return retriever
 
1
+ # Azure AI Search: https://python.langchain.com/docs/integrations/vectorstores/azuresearch
 
 
2
  import os
 
 
3
 
4
+ # Azure AI Search imports
5
+ from langchain_community.vectorstores.azuresearch import AzureSearch
6
+
7
+
8
+ # Load environment variables
9
  try:
10
  from dotenv import load_dotenv
11
  load_dotenv()
 
13
  pass
14
 
15
 
16
+ class AzureSearchWrapper:
17
+ """
18
+ Wrapper class for Azure AI Search vectorstore to handle filter conversion.
19
+
20
+ This wrapper automatically converts dictionary-style filters to Azure Search OData filter format,
21
+ ensuring seamless compatibility when switching from other providers.
22
+ """
23
+
24
+ def __init__(self, azure_search_vectorstore):
25
+ self.vectorstore = azure_search_vectorstore
26
+
27
+ def __getattr__(self, name):
28
+ """Delegate all other attributes to the wrapped vectorstore."""
29
+ return getattr(self.vectorstore, name)
30
+
31
+ def _convert_dict_filter_to_odata(self, filter_dict):
32
+ """
33
+ Convert dictionary-style filters to Azure Search OData filter format.
34
+
35
+ Args:
36
+ filter_dict (dict): Dictionary-style filter
37
+
38
+ Returns:
39
+ str: OData filter string
40
+ """
41
+ if not filter_dict:
42
+ return None
43
+
44
+ conditions = []
45
+
46
+ for key, value in filter_dict.items():
47
+ if key.endswith('_exclude'):
48
+ # Handle exclusion filters (e.g., report_type_exclude)
49
+ base_key = key.replace('_exclude', '')
50
+ if isinstance(value, list):
51
+ if len(value) == 1:
52
+ conditions.append(f"{base_key} ne '{value[0]}'")
53
+ else:
54
+ exclude_conditions = [f"{base_key} ne '{v}'" for v in value]
55
+ conditions.append(f"({' and '.join(exclude_conditions)})")
56
+ else:
57
+ conditions.append(f"{base_key} ne '{value}'")
58
+ elif isinstance(value, list):
59
+ # Handle list values (equivalent to $in operator)
60
+ if len(value) == 1:
61
+ conditions.append(f"{key} eq '{value[0]}'")
62
+ else:
63
+ list_conditions = [f"{key} eq '{v}'" for v in value]
64
+ conditions.append(f"({' or '.join(list_conditions)})")
65
+ else:
66
+ # Handle single values
67
+ conditions.append(f"{key} eq '{value}'")
68
+
69
+ return " and ".join(conditions) if conditions else None
70
+
71
+ def similarity_search_with_score(self, query, k=4, filter=None, **kwargs):
72
+ """Override similarity_search_with_score to convert filters."""
73
+ if filter is not None:
74
+ filter = self._convert_dict_filter_to_odata(filter)
75
+
76
+ return self.vectorstore.hybrid_search_with_score(
77
+ query=query, k=k, filters=filter, **kwargs
78
+ )
79
+
80
+
81
+ def similarity_search(self, query, k=4, filter=None, **kwargs):
82
+ """Override similarity_search to convert filters."""
83
+ if filter is not None:
84
+ filter = self._convert_dict_filter_to_odata(filter)
85
+
86
+ return self.vectorstore.similarity_search(
87
+ query=query, k=k, filter=filter, **kwargs
88
+ )
89
+
90
+ def similarity_search_by_vector(self, embedding, k=4, filter=None, **kwargs):
91
+ """Override similarity_search_by_vector to convert filters."""
92
+ if filter is not None:
93
+ filter = self._convert_dict_filter_to_odata(filter)
94
+
95
+ return self.vectorstore.similarity_search_by_vector(
96
+ embedding=embedding, k=k, filter=filter, **kwargs
97
+ )
98
+
99
+ def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
100
+ """Override as_retriever to handle filter conversion in search_kwargs."""
101
+ if search_kwargs and "filter" in search_kwargs:
102
+ # Convert the filter in search_kwargs
103
+ search_kwargs = search_kwargs.copy() # Don't modify the original
104
+ if search_kwargs["filter"] is not None:
105
+ search_kwargs["filter"] = self._convert_dict_filter_to_odata(search_kwargs["filter"])
106
+
107
+ return self.vectorstore.as_retriever(
108
+ search_type=search_type, search_kwargs=search_kwargs, **kwargs
109
+ )
110
+
111
+
112
+ def get_azure_search_vectorstore(embeddings, text_key="content", index_name=None):
113
+ """
114
+ Create an Azure AI Search vectorstore instance.
115
+
116
+ Args:
117
+ embeddings: The embeddings function to use
118
+ text_key: The key for text content in the payload (default: "content")
119
+ index_name: The name of the Azure Search index
120
+
121
+ Returns:
122
+ AzureSearchWrapper: A wrapped Azure AI Search vectorstore instance with filter compatibility
123
+ """
124
+ # Get Azure AI Search configuration from environment variables
125
+ azure_search_endpoint = os.getenv("AI_SEARCH_INDEX_ENDPOINT")
126
+ azure_search_key = os.getenv("AI_SEARCH_KEY")
127
+
128
+ if not azure_search_endpoint:
129
+ raise ValueError("AI_SEARCH_INDEX_ENDPOINT environment variable is required")
130
+
131
+ if not azure_search_key:
132
+ raise ValueError("AI_SEARCH_KEY environment variable is required")
133
+
134
+ if not index_name:
135
+ raise ValueError("index_name must be provided for Azure Search")
136
+
137
+ # Create Azure Search vectorstore
138
+ vectorstore = AzureSearch(
139
+ azure_search_endpoint=azure_search_endpoint,
140
+ azure_search_key=azure_search_key,
141
+ index_name=index_name,
142
+ embedding_function=embeddings.embed_query,
143
+ content_key=text_key,
144
  )
145
+
146
+ # Wrap the vectorstore to handle filter conversion
147
+ return AzureSearchWrapper(vectorstore)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
 
climateqa/utils.py CHANGED
@@ -25,7 +25,7 @@ def remove_duplicates_keep_highest_score(documents):
25
  unique_docs = {}
26
 
27
  for doc in documents:
28
- doc_id = doc.metadata.get('doc_id')
29
  if doc_id in unique_docs:
30
  if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
31
  unique_docs[doc_id] = doc
 
25
  unique_docs = {}
26
 
27
  for doc in documents:
28
+ doc_id = doc.metadata.get('id')
29
  if doc_id in unique_docs:
30
  if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
31
  unique_docs[doc_id] = doc
front/tabs/tab_ipcc.py CHANGED
@@ -68,6 +68,8 @@ def show_filter_by_scenario(table_names, index_state, dataframes):
68
  return gr.update(visible=False)
69
 
70
  def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
 
 
71
  df = dataframes[index_state]
72
  if not table_names[index_state].startswith("Map"):
73
  return df, figures[index_state](df)
 
68
  return gr.update(visible=False)
69
 
70
  def filter_by_scenario(dataframes, figures, table_names, index_state, scenario):
71
+ if len(dataframes) == 0:
72
+ return None, None
73
  df = dataframes[index_state]
74
  if not table_names[index_state].startswith("Map"):
75
  return df, figures[index_state](df)
requirements.txt CHANGED
@@ -1,6 +1,9 @@
1
  gradio==5.0.2
2
  azure-storage-file-share==12.11.1
3
  azure-storage-blob==12.23.0
 
 
 
4
  python-dotenv==1.0.0
5
  langchain==0.2.1
6
  langchain_openai==0.1.7
 
1
  gradio==5.0.2
2
  azure-storage-file-share==12.11.1
3
  azure-storage-blob==12.23.0
4
+ # Azure AI Search support
5
+ azure-search-documents>=11.4.0
6
+ azure-core>=1.29.0
7
  python-dotenv==1.0.0
8
  langchain==0.2.1
9
  langchain_openai==0.1.7
sandbox/20241104 - CQA - StepByStep CQA.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
style.css CHANGED
@@ -661,7 +661,6 @@ a {
661
 
662
  #sql-query textarea{
663
  min-height: 200px !important;
664
-
665
  }
666
 
667
  #sql-query span{
 
661
 
662
  #sql-query textarea{
663
  min-height: 200px !important;
 
664
  }
665
 
666
  #sql-query span{