armanddemasson's picture
feat: implemented talk to drias v1
4df74e4
raw
history blame
5.68 kB
from typing import Callable, TypedDict
import pandas as pd
from plotly.graph_objects import Figure
import plotly.graph_objects as go
from climateqa.engine.talk_to_data.sql_query import indicator_per_year_at_location_query
class Plot(TypedDict):
    """Declarative description of one kind of plot.

    Keys:
        name: human-readable name of the plot.
        description: sentence explaining what the plot shows and when it fits.
        params: names of the parameters the plot needs
            (e.g. ``indicator_column``, ``location``, ``model``).
        plot_function: factory receiving the params dict and returning a
            function that turns a dataframe into a Plotly ``Figure``.
        sql_query: callable producing the SQL query string whose result
            feeds the plotting function.
    """

    name: str
    description: str
    params: list[str]
    plot_function: Callable[..., Callable[..., Figure]]
    sql_query: Callable[..., str]
def plot_indicator_per_year_at_location(params: dict) -> Callable[..., Figure]:
    """Generate the function to plot a line plot of an indicator per year at a certain location.

    The figure shows the yearly indicator values plus a dashed 10-year
    rolling average.

    Args:
        params (dict): dictionary with the required params: ``model``,
            ``indicator_column`` and ``location``. When ``model`` is ``"ALL"``
            the yearly values are averaged across every model in the data.

    Returns:
        Callable[..., Figure]: Function which can be called with the associated
        dataframe to create the figure.
    """
    indicator = params["indicator_column"]
    model = params["model"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    # The title must reflect whether values are a cross-model mean or one model.
    model_label = "Model Average" if model == "ALL" else f"Model {model}"

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the figure from the dataframe.

        Args:
            df (pd.DataFrame): dataframe with a ``year`` column and the
                indicator column, plus a ``model`` column when a single model
                is selected.

        Returns:
            Figure: Plotly figure
        """
        if model == "ALL":
            # One row per year holding the mean over all models (groupby sorts by year).
            yearly = df.groupby("year", as_index=False)[indicator].mean()
        else:
            # The rolling window below is positional, so rows must be in
            # chronological order for it to average consecutive years.
            yearly = df[df["model"] == model].sort_values("year", kind="stable")

        # Transform to lists to avoid pandas encoding
        years = yearly["year"].astype(int).tolist()
        indicators = yearly[indicator].astype(float).tolist()
        # 10-year rolling average, emitted once at least 5 values are available
        sliding_averages = (
            yearly[indicator]
            .rolling(window=10, min_periods=5)
            .mean()
            .astype(float)
            .tolist()
        )

        fig = go.Figure()
        # Indicator per year plot
        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
        )
        # Sliding average dashed line
        fig.add_scatter(
            x=years,
            y=sliding_averages,
            mode="lines",
            name="10 years rolling average",
            line=dict(dash="dash"),
            marker=dict(color="#1f77b4"),
        )
        fig.update_layout(
            title=f"Plot of {indicator_label} in {params['location']} ({model_label})",
            xaxis_title="Year",
            yaxis_title=indicator_label,
            template="plotly_white",
        )
        return fig

    return plot_data
# Line plot of the yearly indicator values (with a rolling average) at one location.
indicator_per_year_at_location: Plot = Plot(
    name="Indicator per year at location",
    description="Plot an evolution of the indicator at a certain location over the years",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)
def plot_indicator_number_of_days_per_year_at_location(params: dict) -> Callable[..., Figure]:
    """Generate the function to plot a bar chart of an indicator's number of days per year at a certain location.

    Args:
        params (dict): dictionary with the required params: ``model``,
            ``indicator_column`` and ``location``. When ``model`` is ``"ALL"``
            the yearly values are averaged across every model in the data.

    Returns:
        Callable[..., Figure]: Function which can be called with the associated
        dataframe to create the figure.
    """
    indicator = params["indicator_column"]
    model = params["model"]
    # Computed once here instead of on every figure build.
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))
    # The title must reflect whether values are a cross-model mean or one model.
    model_label = "Model Average" if model == "ALL" else f"Model {model}"

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the bar chart from the dataframe.

        Args:
            df (pd.DataFrame): dataframe with a ``year`` column and the
                indicator column, plus a ``model`` column when a single model
                is selected.

        Returns:
            Figure: Plotly figure
        """
        if model == "ALL":
            # One row per year holding the mean over all models.
            yearly = df.groupby("year", as_index=False)[indicator].mean()
        else:
            yearly = df[df["model"] == model]

        # Transform to lists to avoid pandas encoding
        years = yearly["year"].astype(int).tolist()
        indicators = yearly[indicator].astype(float).tolist()

        fig = go.Figure()
        # Bar plot
        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
            )
        )
        fig.update_layout(
            title=f"{indicator_label} in {params['location']} ({model_label})",
            xaxis_title="Year",
            yaxis_title=indicator_label,
            # Fixed 0..366 axis: the plotted values are day counts within a year.
            yaxis=dict(range=[0, 366]),
            bargap=0.5,
            template="plotly_white",
        )
        return fig

    return plot_data
# Bar chart of a frequency indicator per year at one location; it reuses the
# same SQL query as the per-year line plot.
indicator_number_of_days_per_year_at_location: Plot = Plot(
    name="Indicator number of days per year at location",
    description="Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_number_of_days_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)
# Every plot definition exposed by this module.
PLOTS: list[Plot] = [
    indicator_per_year_at_location,
    indicator_number_of_days_per_year_at_location,
]