| from typing import Any, Literal, Optional, cast |
| import ast |
| from langchain_core.prompts import ChatPromptTemplate |
| from geopy.geocoders import Nominatim |
| from climateqa.engine.llm import get_llm |
| import duckdb |
| import os |
| from climateqa.engine.talk_to_data.config import DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH, IPCC_COORDINATES_PATH |
| from climateqa.engine.talk_to_data.objects.llm_outputs import ArrayOutput |
| from climateqa.engine.talk_to_data.objects.location import Location |
| from climateqa.engine.talk_to_data.objects.plot import Plot |
| from climateqa.engine.talk_to_data.objects.states import State |
| import calendar |
|
|
async def detect_location_with_openai(sentence: str) -> str:
    """
    Detects the first location mentioned in a sentence using the configured LLM.

    The LLM is asked to answer with a Python list literal; the reply may be
    wrapped in a Markdown code fence, which is stripped before parsing.

    Args:
        sentence (str): The sentence to scan for locations.

    Returns:
        str: The first detected location, or "" when no location is found
        or the LLM reply cannot be parsed as a Python list.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    text = response.content.strip()
    # Remove an optional Markdown code fence. The previous
    # str.strip("```python\n") removed *characters* from both ends, not a
    # prefix, and broke on fences with another language tag (e.g. ```json).
    if text.startswith("```"):
        text = text[3:]
        newline = text.find("\n")
        # Drop an optional language tag on the opening fence line.
        if newline != -1 and text[:newline].strip().isalpha():
            text = text[newline + 1:]
        if text.endswith("```"):
            text = text[:-3]
    try:
        location_list = ast.literal_eval(text.strip())
    except (ValueError, SyntaxError):
        # Malformed LLM reply: treat as "no location found" instead of crashing.
        return ""
    if location_list:
        return location_list[0]
    return ""
|
|
def loc_to_coords(location: str) -> tuple[float, float]:
    """Resolve a place name to geographic coordinates.

    Uses the Nominatim geocoding service to turn a location name
    (e.g. a city) into a (latitude, longitude) pair.

    Args:
        location (str): Name of the place to look up.

    Returns:
        tuple[float, float]: (latitude, longitude) of the best match.

    Raises:
        AttributeError: If Nominatim returns no match (result is None).
    """
    geocoder = Nominatim(user_agent="city_to_latlong", timeout=5)
    match = geocoder.geocode(location)
    return (match.latitude, match.longitude)
|
|
def coords_to_country(coords: tuple[float, float]) -> tuple[str, str]:
    """Converts geographic coordinates to a country code and name.

    This function uses the Nominatim reverse geocoding service to convert
    latitude and longitude coordinates to the country they belong to.

    Args:
        coords (tuple[float, float]): A tuple containing (latitude, longitude)

    Returns:
        tuple[str, str]: (country_code, country_name), where country_code is
        uppercased (e.g. 'FR'). The previous docstring wrongly advertised a
        third 'admin1' element that was never returned.

    Raises:
        AttributeError: If the coordinates cannot be reverse-geocoded
            (Nominatim returns None)
        KeyError: If the resolved address lacks country information
    """
    # timeout=5 for consistency with loc_to_coords (geopy's default is shorter).
    geolocator = Nominatim(user_agent="latlong_to_country", timeout=5)
    location = geolocator.reverse(coords)
    address = location.raw['address']
    return address['country_code'].upper(), address['country']
|
|
def nearest_neighbour_sql(
    location: tuple, mode: Literal['DRIAS', 'IPCC']
) -> tuple[float, float, Optional[str]] | tuple[str, str, str]:
    """Find a dataset grid point close to the given coordinates.

    Queries the parquet file of the selected dataset for grid points whose
    latitude/longitude fall inside a small window around the requested
    coordinates and returns the first match.

    NOTE(review): despite the name, this returns the *first* point inside
    the window, not necessarily the nearest one — confirm this is intended.

    Args:
        location (tuple): (latitude, longitude) of the target point.
        mode: 'DRIAS' (±0.3° window) or 'IPCC' (±0.5° window, includes admin1).

    Returns:
        (latitude, longitude, admin1) of the matched grid point; admin1 is
        None in DRIAS mode. Returns ("", "", "") when no point is found.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)
    conn = duckdb.connect()
    try:  # ensure the connection is closed even on error (it was previously leaked)
        if mode == 'DRIAS':
            results = conn.execute(
                f"SELECT latitude, longitude FROM '{DRIAS_MEAN_ANNUAL_TEMPERATURE_PATH}' "
                "WHERE latitude BETWEEN ? AND ? AND longitude BETWEEN ? AND ?",
                [lat - 0.3, lat + 0.3, long - 0.3, long + 0.3],
            ).fetchdf()
        else:
            results = conn.execute(
                f"SELECT latitude, longitude, admin1 FROM '{IPCC_COORDINATES_PATH}' "
                "WHERE latitude BETWEEN ? AND ? AND longitude BETWEEN ? AND ?",
                [lat - 0.5, lat + 0.5, long - 0.5, long + 0.5],
            ).fetchdf()
    finally:
        conn.close()

    if len(results) == 0:
        return "", "", ""

    # admin1 is only present in the IPCC coordinates table.
    admin1 = results['admin1'].iloc[0] if 'admin1' in results.columns else None
    return results['latitude'].iloc[0], results['longitude'].iloc[0], admin1
|
|
async def detect_year_with_openai(sentence: str) -> str:
    """Return the first year mentioned in *sentence*, or "" when none is found.

    The LLM is prompted (with structured output) to answer with a Python list
    literal of years; the list is parsed and only its first element is kept.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    chain = ChatPromptTemplate.from_template(prompt) | llm.with_structured_output(ArrayOutput)
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    years = ast.literal_eval(response['array'])
    return years[0] if years else ""
|
|
|
|
async def detect_relevant_tables(user_question: str, plot: Plot, llm, table_names_list: list[str]) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    This function uses an LLM to analyze the user's question and the plot
    description to determine which tables in the DRIAS database would be
    most relevant for generating the requested visualization.

    Args:
        user_question (str): The user's question about climate data
        plot (Plot): The plot configuration object
        llm: The language model instance to use for analysis
        table_names_list (list[str]): Names of the candidate tables

    Returns:
        list[str]: A list of table names that are relevant for the plot

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm,
        ...     table_names_list,
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """

    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    def _parse_list(raw: str) -> list[str]:
        """Strip an optional Markdown code fence and literal-eval the list.

        The previous str.strip("```python\\n") removed *characters*, not a
        prefix, and broke on fences with another language tag (e.g. ```json).
        """
        text = raw.strip()
        if text.startswith("```"):
            text = text[3:]
            newline = text.find("\n")
            # Drop an optional language tag on the opening fence line.
            if newline != -1 and text[:newline].strip().isalpha():
                text = text[newline + 1:]
            if text.endswith("```"):
                text = text[:-3]
        return ast.literal_eval(text.strip())

    response = await llm.ainvoke(prompt)
    return _parse_list(response.content)
|
|
async def detect_relevant_plots(user_question: str, llm, plot_list: list[Plot]) -> list[str]:
    """Select, via the LLM, the plot names relevant to the user's question.

    Args:
        user_question (str): The user's question about climate data
        llm: The language model instance to use for analysis
        plot_list (list[Plot]): Candidate plots (each with 'name' and 'description')

    Returns:
        list[str]: Plot names sorted from most to least relevant.
    """
    plots_description = ""
    for plot in plot_list:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"

    prompt = (
        "You are helping to answer a question with insightful visualizations.\n"
        "You are given a user question and a list of plots with their name and description.\n"
        "Based on the descriptions of the plots, select ALL plots that could provide a useful answer to this question. "
        "Include any plot that could show relevant information, even if their perspectives (such as time series or spatial distribution) are different.\n"
        "For example, for a question like 'What will be the total rainfall in China in 2050?', both a time series plot and a spatial map plot could be relevant.\n"
        "Return only a Python list of plot names sorted from the most relevant one to the less relevant one.\n"
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}\n"
        f"### Names of the plots : "
    )

    def _parse_list(raw: str) -> list[str]:
        """Strip an optional Markdown code fence and literal-eval the list.

        The previous str.strip("```python\\n") removed *characters*, not a
        prefix, and broke on fences with another language tag (e.g. ```json).
        """
        text = raw.strip()
        if text.startswith("```"):
            text = text[3:]
            newline = text.find("\n")
            # Drop an optional language tag on the opening fence line.
            if newline != -1 and text[:newline].strip().isalpha():
                text = text[newline + 1:]
            if text.endswith("```"):
                text = text[:-3]
        return ast.literal_eval(text.strip())

    response = await llm.ainvoke(prompt)
    return _parse_list(response.content)
|
|
async def find_location(user_input: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> Location:
    """Build a Location object from the user's question.

    Detects a location name with the LLM, then (when one is found) geocodes
    it, resolves its country and snaps it to the closest grid point of the
    selected dataset.

    Args:
        user_input (str): The user's question.
        mode: Which dataset grid to use ('DRIAS' or 'IPCC').

    Returns:
        Location: dict with 'location' plus coordinate/country fields; those
        fields stay None when no location is detected or geocoding fails.
    """
    print(f"---- Find location in user input ----")
    location = await detect_location_with_openai(user_input)
    output: Location = {
        'location' : location,
        'longitude' : None,
        'latitude' : None,
        'country_code' : None,
        'country_name' : None,
        'admin1' : None
    }

    if location:
        try:
            coords = loc_to_coords(location)
            country_code, country_name = coords_to_country(coords)
            neighbour = nearest_neighbour_sql(coords, mode)
            output.update({
                "latitude": neighbour[0],
                "longitude": neighbour[1],
                "country_code": country_code,
                "country_name": country_name,
                "admin1": neighbour[2]
            })
        except AttributeError:
            # Nominatim found no match for the detected name (geocode/reverse
            # returned None): keep the bare location with None coordinates
            # instead of crashing the whole lookup.
            print(f"Could not geocode detected location '{location}'")
    # cast is annotation-only; apply it on every path (previously only when
    # a location was detected).
    return cast(Location, output)
|
|
async def find_year(user_input: str) -> str| None:
    """Extract a year from the user's query via the LLM.

    Thin wrapper around detect_year_with_openai that maps its empty-string
    "not found" sentinel to None.

    Args:
        user_input (str): The user's query text.

    Returns:
        str | None: The detected year, or None when the query mentions none.
    """
    print(f"---- Find year ---")
    detected = await detect_year_with_openai(user_input)
    return detected if detected != "" else None
|
|
async def find_month(user_input: str) -> dict[str, str|None]:
    """
    Extracts month information from user input using an LLM.

    This function analyzes the user's query to detect if a month is mentioned.
    It returns both the month number (as a string, e.g. '7' for July) and the full English month name (e.g. 'July').
    If no month is found, both values will be None.

    Args:
        user_input (str): The user's query text.

    Returns:
        dict[str, str|None]: A dictionary with keys:
            - "month_number": the month number as a string (e.g. '7'), or None if not found
            - "month_name": the full English month name (e.g. 'July'), or None if not found

    Example:
        >>> await find_month("Show me the temperature in Paris in July")
        {'month_number': '7', 'month_name': 'July'}
        >>> await find_month("Show me the temperature in Paris")
        {'month_number': None, 'month_name': None}
    """

    llm = get_llm()
    prompt = """
    Extract the month (as a number from 1 to 12) mentioned in the following sentence.
    Return the result as a Python list of integers. If no month is mentioned, return an empty list.

    Sentence: "{sentence}"
    """
    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": user_input})
    months_list = ast.literal_eval(response['array'])

    not_found: dict[str, str | None] = {
        "month_number": None,
        "month_name": None
    }
    if not months_list:
        return not_found

    month_number = int(months_list[0])
    # Guard against out-of-range LLM output: calendar.month_name is only
    # valid for 1..12 (index 0 is "" and 13+ raises IndexError).
    if not 1 <= month_number <= 12:
        return not_found

    return {
        "month_number": str(month_number),
        "month_name": calendar.month_name[month_number]
    }
|
|
|
|
async def find_relevant_plots(state: State, llm, plots: list[Plot]) -> list[str]:
    """Return the names of the plots relevant to the user's question in state."""
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state['user_input'], llm, plots)
|
|
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm, tables: list[str]) -> list[str]:
    """Return, for one plot, the table names relevant to the user's question in state."""
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state['user_input'], plot, llm, tables)
|
|
| async def find_param(state: State, param_name: str, mode: Literal['DRIAS', 'IPCC'] = 'DRIAS') -> dict[str, Optional[str]] | Location | None: |
| """ |
| Retrieves a specific parameter (location, year, month, etc.) from the user's input using the appropriate extraction method. |
| |
| Args: |
| state (State): The current state containing at least the user's input under 'user_input'. |
| param_name (str): The name of the parameter to extract. Supported: 'location', 'year', 'month'. |
| mode (Literal['DRIAS', 'IPCC']): The data mode to use for location extraction. |
| |
| Returns: |
| - For 'location': a Location object (dict with keys like 'location', 'latitude', etc.), or None if not found. |
| - For 'year': a dict {'year': year or None}. |
| - For 'month': a dict {'month_number': str or None, 'month_name': str or None}. |
| - None if the parameter is not recognized or not found. |
| |
| Example: |
| >>> await find_param(state, 'location') |
| {'location': 'Paris', 'latitude': ..., ...} |
| >>> await find_param(state, 'year') |
| {'year': '2050'} |
| >>> await find_param(state, 'month') |
| {'month_number': '7', 'month_name': 'July'} |
| """ |
| if param_name == 'location': |
| location = await find_location(state['user_input'], mode) |
| return location |
| if param_name == 'year': |
| year = await find_year(state['user_input']) |
| return {'year': year} |
| if param_name == 'month': |
| month = await find_month(state['user_input']) |
| return month |
| return None |