|
|
from typing import Callable, TypedDict |
|
|
import pandas as pd |
|
|
from plotly.graph_objects import Figure |
|
|
import plotly.graph_objects as go |
|
|
|
|
|
from climateqa.engine.talk_to_data.sql_query import indicator_per_year_at_location_query |
|
|
|
|
|
|
|
|
class Plot(TypedDict):
    """Declarative description of one plot type exposed by this module.

    Each entry bundles the human-readable metadata of a plot with the two
    callables needed to produce it: one building the SQL query and one
    building the plotting function.
    """

    # Short display name of the plot.
    name: str
    # Sentence describing what the plot shows.
    description: str
    # Names of the keys expected in the ``params`` dict of the callables below.
    params: list[str]
    # Factory: called with the params, returns the function that turns a
    # dataframe into a plotly figure.
    plot_function: Callable[..., Callable[..., Figure]]
    # Builds the SQL query string used to fetch the data to plot.
    sql_query: Callable[..., str]
|
|
|
|
|
|
|
|
def plot_indicator_per_year_at_location(params: dict) -> Callable[..., Figure]:
    """Generate the function plotting an indicator per year at a location.

    The generated function draws two lines: the yearly values of the
    indicator and their 10-year rolling average.

    Args:
        params (dict): Dictionary with the required params: ``model``,
            ``indicator_column`` and ``location``. When ``model`` is
            ``"ALL"`` the indicator is averaged over every model for each
            year; otherwise only the rows of that model are plotted.

    Returns:
        Callable[..., Figure]: Function which can be called with the
            associated dataframe to create the figure.

    Raises:
        KeyError: If ``indicator_column`` or ``model`` is missing from
            ``params`` (``location`` is only read when the figure is built).
    """
    indicator = params["indicator_column"]
    model = params["model"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the figure from the dataframe.

        Args:
            df (pd.DataFrame): Dataframe with a ``year`` column, the indicator
                column and, when a single model is selected, a ``model``
                column.

        Returns:
            Figure: Plotly figure with the yearly values and their rolling
                average.
        """
        if model == "ALL":
            # groupby sorts its keys, so the averaged frame is ordered by year.
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            # The rolling window below is positional: rows must be in
            # chronological order for it to be a real 10-year average (and
            # for the line not to zigzag).
            df_plot = df[df["model"] == model].sort_values("year", kind="stable")
            model_label = f"Model: {model}"

        # Cast once so the rolling mean also works when the column comes back
        # from the database with a non-float dtype.
        values = df_plot[indicator].astype(float)
        years = df_plot["year"].astype(int).tolist()
        indicators = values.tolist()
        # The first entries are NaN until 5 values are available; plotly
        # renders them as a gap.
        sliding_averages = values.rolling(window=10, min_periods=5).mean().tolist()

        fig = go.Figure()
        fig.add_scatter(
            x=years,
            y=indicators,
            name=f"Yearly {indicator_label}",
            mode="lines",
        )
        fig.add_scatter(
            x=years,
            y=sliding_averages,
            mode="lines",
            name="10 years rolling average",
            line=dict(dash="dash"),
            marker=dict(color="#1f77b4"),
        )
        fig.update_layout(
            # The suffix reflects what is actually plotted instead of always
            # claiming a model average.
            title=f"Plot of {indicator_label} in {params['location']} ({model_label})",
            xaxis_title="Year",
            yaxis_title=indicator_label,
            template="plotly_white",
        )
        return fig

    return plot_data
|
|
|
|
|
|
|
|
# Line plot of the yearly evolution of an indicator at one location.
indicator_per_year_at_location: Plot = Plot(
    name="Indicator per year at location",
    description="Plot an evolution of the indicator at a certain location over the years",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)
|
|
|
|
|
|
|
|
def plot_indicator_number_of_days_per_year_at_location(params: dict) -> Callable[..., Figure]:
    """Generate the function plotting a bar chart of a number of days per year.

    The generated function draws one bar per year for the indicator at the
    given location, on a y-axis fixed to the 0-366 range.

    Args:
        params (dict): Dictionary with the required params: ``model``,
            ``indicator_column`` and ``location``. When ``model`` is
            ``"ALL"`` the indicator is averaged over every model for each
            year; otherwise only the rows of that model are plotted.

    Returns:
        Callable[..., Figure]: Function which can be called with the
            associated dataframe to create the figure.

    Raises:
        KeyError: If ``indicator_column`` or ``model`` is missing from
            ``params`` (``location`` is only read when the figure is built).
    """
    indicator = params["indicator_column"]
    model = params["model"]
    indicator_label = " ".join(word.capitalize() for word in indicator.split("_"))

    def plot_data(df: pd.DataFrame) -> Figure:
        """Generate the bar chart from the dataframe.

        Args:
            df (pd.DataFrame): Dataframe with a ``year`` column, the indicator
                column and, when a single model is selected, a ``model``
                column.

        Returns:
            Figure: Plotly figure with one bar per year.
        """
        if model == "ALL":
            df_plot = df.groupby("year", as_index=False)[indicator].mean()
            model_label = "Model Average"
        else:
            df_plot = df[df["model"] == model]
            model_label = f"Model: {model}"

        indicators = df_plot[indicator].astype(float).tolist()
        years = df_plot["year"].astype(int).tolist()

        fig = go.Figure()
        fig.add_trace(
            go.Bar(
                x=years,
                y=indicators,
                width=0.5,
                marker=dict(color="#1f77b4"),
            )
        )
        fig.update_layout(
            # The suffix reflects what is actually plotted instead of always
            # claiming a model average.
            title=f"{indicator_label} in {params['location']} ({model_label})",
            xaxis_title="Year",
            # Use the humanised label, like the title and the sibling line
            # plot, rather than the raw column name.
            yaxis_title=indicator_label,
            # A year has at most 366 days; a fixed range keeps plots comparable.
            yaxis=dict(range=[0, 366]),
            bargap=0.5,
            template="plotly_white",
        )
        return fig

    return plot_data
|
|
|
|
|
|
|
|
# Bar chart of a yearly number of days, meant for frequency indicators.
indicator_number_of_days_per_year_at_location: Plot = Plot(
    name="Indicator number of days per year at location",
    description="Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
    params=["indicator_column", "location", "model"],
    plot_function=plot_indicator_number_of_days_per_year_at_location,
    sql_query=indicator_per_year_at_location_query,
)
|
|
|
|
|
|
|
|
# Every plot definition made available by this module, in declaration order.
PLOTS = [
    indicator_per_year_at_location,
    indicator_number_of_days_per_year_at_location,
]
|
|
|