"""Streamlit demo: what-if sales forecasts with a Temporal Fusion Transformer.

The app loads a pickled dataset and a trained TFT checkpoint, lets the user
pick a forecast start date and tweak weather covariates, and plots the model's
predictions against the recorded sales for four article groups.
"""

## Imports
import datetime
import pickle
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from pytorch_forecasting import (
    TimeSeriesDataSet,
    TemporalFusionTransformer,
)
from torch.distributions import Normal


## Constants
# Days kept before the selected start date (exclusive lower bound of the window).
HISTORY_DAYS = 35
# Days kept from the selected start date onwards (inclusive upper bound).
HORIZON_DAYS = 30

# Reference dates used to bound the date picker.
# NOTE(review): presumably the first/last day covered by data/test_data.pkl - confirm.
_DATA_FIRST_DAY = datetime.date(2022, 6, 26)
_DATA_LAST_DAY = datetime.date(2023, 6, 26)

# (value of the "Group" column, panel title), in row-major panel order.
_PLOT_PANELS = (
    ("4", "Article Group 1"),
    ("7", "Article Group 2"),
    ("1", "Article Group 3"),
    ("6", "Article Group 4"),
)

# Value written to the precipitation column for each non-default rain choice.
RAIN_MAPPING = {
    "Yes": 1,
    "No": 0,
}

_INTRO_MARKDOWN = """
### Abstract
Multi-horizon forecasting often contains a complex mix of inputs – including
static (i.e. time-invariant) covariates, known future inputs, and other exogenous
time series that are only observed in the past – without any prior information
on how they interact with the target. Several deep learning methods have been
proposed, but they are typically ‘black-box’ models which do not shed light on
how they use the full range of inputs present in practical scenarios. In this pa-
per, we introduce the Temporal Fusion Transformer (TFT) – a novel attention-
based architecture which combines high-performance multi-horizon forecasting
with interpretable insights into temporal dynamics. To learn temporal rela-
tionships at different scales, TFT uses recurrent layers for local processing and
interpretable self-attention layers for long-term dependencies. TFT utilizes spe-
cialized components to select relevant features and a series of gating layers to
suppress unnecessary components, enabling high performance in a wide range of
scenarios. On a variety of real-world datasets, we demonstrate significant per-
formance improvements over existing benchmarks, and showcase three practical
interpretability use cases of TFT.

### Experiments
We implemented TFT for sales multi-horizon sales forecast during Coding.Waterkant.
Please try our implementation and adjust some of the training data.
Adjustments to the model and extention with Quantile forecast are coming soon ;)
"""

_SOURCES_MARKDOWN = """
### Sources
**Paper:** [Bryan Lim et al. in Temporal Fusion Transformers (TFT)](https://arxiv.org/abs/1912.09363).
**Demo created by:** [MalteLeuschner - leuschnm](https://github.com/MalteLeuschner)
"""


## Functions
def raw_preds_to_df(raw, quantiles=None):
    """Flatten raw model predictions into a long-format DataFrame.

    Args:
        raw: Output of ``model.predict`` as called in ``predict`` below
            (``return_x=True, return_index=True``): element 0 is the model
            output exposing ``.prediction`` and element 2 is the index frame.
        quantiles: Optional list such as ``[0.1, 0.5, 0.9]``. When given, the
            positional ``q`` column is mapped to these values so the output is
            interpretable.

    Returns:
        DataFrame with the index columns repeated for every prediction, plus
        ``h`` (1-based horizon step), ``q`` (quantile position or value),
        ``pred`` (predicted value) and ``pred_idx``. ``time_idx`` is the first
        prediction time index (one step after the knowledge cutoff) and
        ``pred_idx`` is the index of the predicted date, i.e.
        ``time_idx + h - 1``.
    """
    index = raw[2]
    predictions = raw[0].prediction
    # assumes predictions is (n_series, decoder_length, n_quantiles) - TODO confirm
    n_series = len(index)
    dec_len = predictions.shape[1]
    n_quantiles = predictions.shape[-1]
    rows_per_series = dec_len * n_quantiles

    # One output row per (series, horizon step, quantile); the layout of h and
    # q below mirrors the row-major order produced by ``flatten()``.
    preds_df = pd.DataFrame(
        index.values.repeat(rows_per_series, axis=0),
        columns=index.columns,
    )
    preds_df = preds_df.assign(
        h=np.tile(np.repeat(np.arange(1, 1 + dec_len), n_quantiles), n_series),
        q=np.tile(np.arange(n_quantiles), n_series * dec_len),
        pred=predictions.flatten().cpu().numpy(),
    )

    if quantiles is not None:
        preds_df["q"] = preds_df["q"].map(dict(enumerate(quantiles)))

    preds_df["pred_idx"] = preds_df["time_idx"] + preds_df["h"] - 1
    return preds_df


def prepare_dataset(_parameters, df, rain, temperature, datepicker, mapping):
    """Build the inference dataloader for one what-if scenario.

    Args:
        _parameters: Dataset parameters handed to
            ``TimeSeriesDataSet.from_parameters``.
        df: Source data with at least the columns ``Date``,
            ``MTXWTH_Day_precip``, ``MTXWTH_Temp_min`` and ``MTXWTH_Temp_max``.
            The frame is not modified.
        rain: ``"Default"`` to keep the recorded precipitation flag, otherwise
            a key of ``mapping``.
        temperature: Offset added to both temperature columns.
        datepicker: ``datetime.date`` marking the start of the forecast.
        mapping: Maps a non-default ``rain`` choice to the value written into
            ``MTXWTH_Day_precip``.

    Returns:
        A non-training dataloader (batch size 256, no worker processes).
    """
    lowerbound = datepicker - datetime.timedelta(days=HISTORY_DAYS)
    upperbound = datepicker + datetime.timedelta(days=HORIZON_DAYS)

    # Restrict to the window first and work on a private copy, so the caller's
    # frame is never mutated and the adjustments only touch the rows we keep.
    dates = df["Date"].dt.date
    window = df.loc[(dates > lowerbound) & (dates <= upperbound)].copy()

    if rain != "Default":
        window["MTXWTH_Day_precip"] = mapping[rain]
    window["MTXWTH_Temp_min"] = window["MTXWTH_Temp_min"] + temperature
    window["MTXWTH_Temp_max"] = window["MTXWTH_Temp_max"] + temperature

    dataset = TimeSeriesDataSet.from_parameters(_parameters, window)
    return dataset.to_dataloader(train=False, batch_size=256, num_workers=0)


def predict(model, dataloader):
    """Run the model on ``dataloader`` and return long-format predictions.

    Returns:
        DataFrame with the columns ``pred_idx``, ``Group`` and ``pred``.
    """
    out = model.predict(dataloader, mode="raw", return_x=True, return_index=True)
    preds = raw_preds_to_df(out, quantiles=None)
    return preds[["pred_idx", "Group", "pred"]]


def adjust_data_for_plot(df, preds):
    """Attach predictions to the observed data and keep only predicted rows.

    Args:
        df: Observed data with ``time_idx``, ``Group`` and ``sales`` columns.
        preds: Output of ``predict``.

    Returns:
        A new DataFrame holding the rows of ``df`` that received a prediction.
        Sales of exactly ``0.0`` are replaced by NaN, which leaves a gap in the
        plotted sales line instead of a drop to zero.
    """
    merged = pd.merge(
        df,
        preds,
        left_on=["time_idx", "Group"],
        right_on=["pred_idx", "Group"],
        how="left",
    )
    # Explicit copy: we assign into the filtered frame right below.
    merged = merged[merged["pred"].notna()].copy()
    merged["sales"] = merged["sales"].replace(0.0, np.nan)
    return merged


def generate_plot(df):
    """Draw a 2x2 grid comparing sales (grey) and predictions (red) per group.

    Args:
        df: Output of ``adjust_data_for_plot`` with ``Group``, ``Date``,
            ``sales`` and ``pred`` columns.

    Returns:
        The ``(figure, axes)`` pair created by ``plt.subplots``.
    """
    fig, axs = plt.subplots(2, 2, figsize=(8, 6))

    for ax, (group, title) in zip(axs.flat, _PLOT_PANELS):
        # Both lines of a panel must come from the same group; the previous
        # hand-unrolled version drew group '4' predictions in every panel.
        rows = df.loc[df["Group"] == group]
        ax.plot(rows["Date"], rows["sales"], color="grey")
        ax.plot(rows["Date"], rows["pred"], color="red")
        ax.set_title(title)
        ax.xaxis.set_tick_params(rotation=45)

    fig.tight_layout()
    return fig, axs


@st.cache_data
def load_data():
    """Load the dataset parameters and the test data for the demo.

    Returns:
        ``(parameters, df)`` where ``df`` is restricted to branch ``"15"`` and
        the article groups shown in the plot.
    """
    # NOTE: unpickling can execute arbitrary code - only ship trusted files
    # under data/.
    with open('data/parameters_q.pkl', 'rb') as f:
        parameters = pickle.load(f)
    df = pd.read_pickle('data/test_data.pkl')

    plotted_groups = [group for group, _ in _PLOT_PANELS]
    df = df.loc[(df["Branch"] == "15") & (df["Group"].isin(plotted_groups))]
    return parameters, df


@st.cache_resource
def init_model():
    """Load the TFT checkpoint onto the CPU (cached across Streamlit reruns)."""
    model = TemporalFusionTransformer.load_from_checkpoint(
        'model/tft_check_q.ckpt',
        map_location=torch.device('cpu'),
    )
    return model


def main():
    """Render the Streamlit page: intro, scenario controls, forecast plot."""
    # Start App
    st.title("Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting")

    # Close the image file as soon as Streamlit has consumed it.
    with Image.open('data/image.png') as image:
        st.image(image, caption='Coding.Waterkant Festival for AI')

    st.markdown(body=_INTRO_MARKDOWN)

    parameters, df = load_data()
    model = init_model()

    # The selectable range leaves room for the history window before and the
    # forecast horizon after the chosen start date.
    datepicker = st.date_input(
        "Start of Forecast",
        value=datetime.date(2022, 10, 24),
        min_value=_DATA_FIRST_DAY + datetime.timedelta(days=HISTORY_DAYS),
        max_value=_DATA_LAST_DAY - datetime.timedelta(days=HORIZON_DAYS),
        key="date",
    )
    temperature = st.slider(
        'Change in Temperature',
        min_value=-10.0,
        max_value=10.0,
        value=0.0,
        step=0.25,
        key="temperature",
    )
    rain = st.selectbox("Rain Indicator", ('Default', 'Yes', 'No'), key="rain")

    dataloader = prepare_dataset(parameters, df, rain, temperature, datepicker, RAIN_MAPPING)
    preds = predict(model, dataloader)
    data_plot = adjust_data_for_plot(df, preds)

    fig, _ = generate_plot(data_plot)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # rerun of the script would leak one figure.
    plt.close(fig)

    st.markdown(body=_SOURCES_MARKDOWN, unsafe_allow_html=True)


if __name__ == '__main__':
    main()