|
import streamlit as st |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import pytorch_lightning as pl |
|
from neuralforecast.core import NeuralForecast |
|
from neuralforecast.models import NHITS, TimesNet, LSTM, TFT |
|
from neuralforecast.losses.pytorch import HuberMQLoss |
|
from neuralforecast.utils import AirPassengersDF |
|
import time |
|
|
|
@st.cache_resource
def load_model(path, freq):
    """Load a saved NeuralForecast object from disk.

    Args:
        path: Directory handed to ``NeuralForecast.load``.
        freq: Frequency key supplied by the caller. It is accepted for the
            caller's convenience but is not used inside this function.

    Returns:
        Whatever ``NeuralForecast.load(path=path)`` returns.
    """
    return NeuralForecast.load(path=path)
|
|
|
@st.cache_resource
def load_all_models():
    """Load every saved model for each architecture and frequency.

    Returns:
        A 4-tuple ``(nhits, timesnet, lstm, tft)``. Each element is a dict
        keyed by 'D', 'M', 'H', 'W' and 'Y' whose values are the objects
        returned by ``load_model`` for the matching directory.
    """
    # One directory table per architecture, in the order of the returned tuple.
    checkpoint_dirs = (
        {
            'D': './M4/NHITS/daily',
            'M': './M4/NHITS/monthly',
            'H': './M4/NHITS/hourly',
            'W': './M4/NHITS/weekly',
            'Y': './M4/NHITS/yearly',
        },
        {
            'D': './M4/TimesNet/daily',
            'M': './M4/TimesNet/monthly',
            'H': './M4/TimesNet/hourly',
            'W': './M4/TimesNet/weekly',
            'Y': './M4/TimesNet/yearly',
        },
        {
            'D': './M4/LSTM/daily',
            'M': './M4/LSTM/monthly',
            'H': './M4/LSTM/hourly',
            'W': './M4/LSTM/weekly',
            'Y': './M4/LSTM/yearly',
        },
        {
            'D': './M4/TFT/daily',
            'M': './M4/TFT/monthly',
            'H': './M4/TFT/hourly',
            'W': './M4/TFT/weekly',
            'Y': './M4/TFT/yearly',
        },
    )

    loaded = []
    for family_dirs in checkpoint_dirs:
        family_models = {}
        for freq_key, directory in family_dirs.items():
            family_models[freq_key] = load_model(directory, freq_key)
        loaded.append(family_models)

    nhits_models, timesnet_models, lstm_models, tft_models = loaded
    return nhits_models, timesnet_models, lstm_models, tft_models
|
|
|
def generate_forecast(model, df):
    """Return the forecast that ``model.predict`` produces for ``df``.

    Args:
        model: Object exposing ``predict(df=...)``.
        df: Input frame forwarded unchanged to ``predict``.

    Returns:
        The value returned by ``model.predict(df=df)``.
    """
    return model.predict(df=df)
|
|
|
def determine_frequency(df, ds_col):
    """Infer the sampling frequency of the timestamp column ``ds_col``.

    The column is parsed with ``pd.to_datetime`` and the parsed values are
    written back into ``df``, so the caller's frame holds datetimes after
    this call.

    Args:
        df: Frame containing the timestamp column; modified in place.
        ds_col: Name of the timestamp column.

    Returns:
        The result of ``pd.infer_freq`` on the parsed timestamps.
    """
    parsed = pd.to_datetime(df[ds_col])
    df[ds_col] = parsed
    return pd.infer_freq(df.set_index(ds_col).index)
|
|
|
def plot_forecasts(forecast_df, train_df, title):
    """Plot history and forecast on one axis and render it with Streamlit.

    ``train_df`` and ``forecast_df`` are concatenated and indexed by 'ds'.
    The historical series is the 'y' column; the forecast series is the
    first column whose name contains 'median'. If columns containing
    'lo-90' and 'hi-90' are both present, the band between them is shaded.

    Args:
        forecast_df: Frame with a 'ds' column and the forecast columns.
        train_df: Frame with 'ds' and 'y' columns.
        title: Title drawn above the plot.

    Raises:
        KeyError: If no column name contains 'median'.
    """
    plot_df = pd.concat([train_df, forecast_df]).set_index('ds')
    historical_col = 'y'
    forecast_col = next((col for col in plot_df.columns if 'median' in col), None)
    lo_col = next((col for col in plot_df.columns if 'lo-90' in col), None)
    hi_col = next((col for col in plot_df.columns if 'hi-90' in col), None)

    # Validate before creating the figure so a rejected input does not leave
    # an orphaned figure registered with pyplot.
    if forecast_col is None:
        raise KeyError("No forecast column found in the data.")

    fig, ax = plt.subplots(1, 1, figsize=(20, 7))
    try:
        plot_df[[historical_col, forecast_col]].plot(
            ax=ax, linewidth=2, label=['Historical', 'Forecast']
        )
        if lo_col and hi_col:
            ax.fill_between(
                plot_df.index,
                plot_df[lo_col],
                plot_df[hi_col],
                color='blue',
                alpha=0.3,
                label='90% Confidence Interval',
            )
        ax.set_title(title, fontsize=22)
        ax.set_ylabel('Value', fontsize=20)
        ax.set_xlabel('Timestamp [t]', fontsize=20)
        ax.legend(prop={'size': 15})
        ax.grid()
        st.pyplot(fig)
    finally:
        # pyplot keeps a reference to every figure it creates until the
        # figure is closed; release it once rendered (or if plotting fails).
        plt.close(fig)
|
|
|
# Frequency strings as reported by pandas, mapped to the keys of the
# per-frequency model dictionaries. pandas renamed several aliases over time
# (e.g. 'M' -> 'ME', 'H' -> 'h', 'A-DEC' -> 'YE-DEC'), so every spelling of
# the same frequency is listed.
_FREQ_KEY_BY_ALIAS = {
    'D': 'D',
    'M': 'M',
    'ME': 'M',
    'H': 'H',
    'h': 'H',
    'W': 'W',
    'W-SUN': 'W',
    'Y': 'Y',
    'Y-DEC': 'Y',
    'A': 'Y',
    'A-DEC': 'Y',
    'YE': 'Y',
    'YE-DEC': 'Y',
}


def select_model_based_on_frequency(freq, nhits_models, timesnet_models, lstm_models, tft_models):
    """Pick the model of each architecture that matches a data frequency.

    Args:
        freq: Frequency string (for example the result of ``pd.infer_freq``),
            or None when no frequency could be inferred.
        nhits_models: Dict keyed by 'D', 'M', 'H', 'W', 'Y'.
        timesnet_models: Dict with the same keys.
        lstm_models: Dict with the same keys.
        tft_models: Dict with the same keys.

    Returns:
        A 4-tuple ``(nhits, timesnet, lstm, tft)`` for the matching key.

    Raises:
        ValueError: If ``freq`` is None or not a supported frequency.
    """
    key = _FREQ_KEY_BY_ALIAS.get(freq)
    if key is None:
        raise ValueError(f"Unsupported frequency: {freq}")
    return nhits_models[key], timesnet_models[key], lstm_models[key], tft_models[key]
|
|
|
def select_model(horizon, model_type, max_steps=200):
    """Build an untrained forecasting model of the requested architecture.

    Every model is configured with ``HuberMQLoss(level=[90])`` and
    ``scaler_type='standard'``. NHITS, TimesNet and LSTM use an input window
    of five horizons; TFT uses an input window of one horizon.

    Args:
        horizon: Number of steps to forecast (passed as ``h``).
        model_type: One of 'NHITS', 'TimesNet', 'LSTM' or 'TFT'.
        max_steps: Value forwarded as the model's ``max_steps``.

    Returns:
        A new NHITS, TimesNet, LSTM or TFT instance.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == 'NHITS':
        nhits_settings = {
            'input_size': 5 * horizon,
            'h': horizon,
            'max_steps': max_steps,
            'stack_types': ['identity', 'identity', 'identity'],
            'n_blocks': [1, 1, 1],
            'mlp_units': [[256, 256], [256, 256], [256, 256]],
            'n_pool_kernel_size': [1, 1, 1],
            'batch_size': 32,
            'scaler_type': 'standard',
            'n_freq_downsample': [12, 4, 1],
            'loss': HuberMQLoss(level=[90]),
        }
        return NHITS(**nhits_settings)

    if model_type == 'TimesNet':
        timesnet_settings = {
            'h': horizon,
            'input_size': horizon * 5,
            'hidden_size': 16,
            'conv_hidden_size': 32,
            'loss': HuberMQLoss(level=[90]),
            'scaler_type': 'standard',
            'learning_rate': 1e-3,
            'max_steps': max_steps,
            'val_check_steps': 200,
            'valid_batch_size': 64,
            'windows_batch_size': 128,
            'inference_windows_batch_size': 512,
        }
        return TimesNet(**timesnet_settings)

    if model_type == 'LSTM':
        lstm_settings = {
            'h': horizon,
            'input_size': horizon * 5,
            'loss': HuberMQLoss(level=[90]),
            'scaler_type': 'standard',
            'encoder_n_layers': 2,
            'encoder_hidden_size': 64,
            'context_size': 10,
            'decoder_hidden_size': 64,
            'decoder_layers': 2,
            'max_steps': max_steps,
        }
        return LSTM(**lstm_settings)

    if model_type == 'TFT':
        tft_settings = {
            'h': horizon,
            'input_size': horizon,
            'hidden_size': 16,
            'loss': HuberMQLoss(level=[90]),
            'learning_rate': 0.005,
            'scaler_type': 'standard',
            'windows_batch_size': 128,
            'max_steps': max_steps,
            'val_check_steps': 200,
            'valid_batch_size': 64,
            'enable_progress_bar': True,
        }
        return TFT(**tft_settings)

    raise ValueError(f"Unsupported model type: {model_type}")
|
|
|
def model_train(df, model, ds_col):
    """Fit ``model`` on ``df`` after parsing the timestamp column.

    Args:
        df: Training frame; its ``ds_col`` column is replaced in place by the
            result of ``pd.to_datetime``.
        model: Object exposing ``fit(df)``.
        ds_col: Name of the timestamp column.

    Returns:
        The same ``model`` object, after ``fit`` has been called on it.
    """
    timestamps = pd.to_datetime(df[ds_col])
    df[ds_col] = timestamps
    model.fit(df)
    return model
|
|
|
def forecast_time_series(df, model_type, horizon, max_steps=200, ds_col='ds'):
    """Train a fresh model on ``df``, forecast ``horizon`` steps and plot it.

    The detected frequency, progress messages and the elapsed time are
    written to the Streamlit sidebar; the plot is drawn by ``plot_forecasts``.

    Args:
        df: Frame with the timestamp column ``ds_col`` and the series to fit.
            ``determine_frequency`` and ``model_train`` convert ``ds_col`` to
            datetimes in place.
        model_type: Architecture name understood by ``select_model``.
        horizon: Number of steps to forecast.
        max_steps: Training steps forwarded to ``select_model``.
        ds_col: Name of the timestamp column.

    Raises:
        ValueError: If no frequency can be inferred from ``ds_col``, or if
            ``select_model`` rejects ``model_type``.
    """
    start_time = time.time()
    freq = determine_frequency(df, ds_col)
    st.sidebar.write(f"Data frequency: {freq}")
    if freq is None:
        raise ValueError(
            "Could not infer the data frequency; the timestamps must be regularly spaced."
        )

    selected_model = select_model(horizon, model_type, max_steps)
    # The bare architecture cannot be fitted on a DataFrame; the NeuralForecast
    # wrapper provides the DataFrame-based fit/predict(df=...) interface and
    # needs the frequency to generate the future timestamps.
    forecaster = NeuralForecast(models=[selected_model], freq=freq)
    model = model_train(df, forecaster, ds_col)

    st.sidebar.write(f"Generating forecast using {model_type} model...")
    forecast_df = generate_forecast(model, df)
    plot_forecasts(forecast_df, df, f'{model_type} Forecast Comparison')

    time_taken = time.time() - start_time
    st.sidebar.success(f"Time taken for {model_type} forecast: {time_taken:.2f} seconds")
|
|
|
@st.cache_data
def load_default():
    """Return a copy of ``AirPassengersDF`` to use as the default dataset."""
    return AirPassengersDF.copy()
|
|
|
def transfer_learning_forecasting():
    """Render the "Transfer Learning Forecasting" page.

    Loads the saved models, lets the user upload a CSV (falling back to the
    default dataset), picks the saved model matching the detected frequency
    and the chosen architecture, then plots its forecast and reports the
    elapsed time in the sidebar.

    Raises:
        ValueError: Propagated from ``select_model_based_on_frequency`` when
            the detected frequency is not supported.
    """
    st.title("Transfer Learning Forecasting")

    # Keep the per-frequency dictionaries under their own (plural) names:
    # they are the inputs of select_model_based_on_frequency further down.
    nhits_models, timesnet_models, lstm_models, tft_models = load_all_models()

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        if uploaded_file:
            df = pd.read_csv(uploaded_file)
        else:
            df = load_default()
        st.session_state.df = df

        columns = df.columns.tolist()
        ds_col = st.selectbox(
            "Select Date/Time column",
            options=columns,
            index=columns.index('ds') if 'ds' in columns else 0,
        )
        y_col = st.selectbox(
            "Select Target column",
            options=columns,
            index=columns.index('y') if 'y' in columns else 1,
        )

        st.session_state.ds_col = ds_col
        st.session_state.y_col = y_col

    st.sidebar.subheader("Model Selection and Forecasting")
    model_choice = st.sidebar.selectbox("Select model", ["NHITS", "TimesNet", "LSTM", "TFT"])
    # The widget is kept for the UI; its value is not passed to the saved
    # models anywhere in this function.
    horizon = st.sidebar.number_input("Forecast horizon", value=18)

    # Rename the chosen columns to 'ds' / 'y' and tag the single series.
    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    df['unique_id'] = 1
    st.session_state.df = df

    # Also converts df['ds'] to datetimes in place.
    frequency = determine_frequency(df, 'ds')
    st.sidebar.write(f"Detected frequency: {frequency}")

    nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(
        frequency, nhits_models, timesnet_models, lstm_models, tft_models
    )
    model_by_name = {
        "NHITS": nhits_model,
        "TimesNet": timesnet_model,
        "LSTM": lstm_model,
        "TFT": tft_model,
    }

    start_time = time.time()
    forecast_df = generate_forecast(model_by_name[model_choice], df)
    plot_forecasts(forecast_df, df, f'{model_choice} Forecast')
    time_taken = time.time() - start_time
    st.sidebar.success(f"Time taken for {model_choice} forecast: {time_taken:.2f} seconds")
|
|
|
def dynamic_forecasting():
    """Render the "Dynamic Forecasting" page.

    Lets the user upload a CSV (falling back to the default dataset), choose
    the timestamp/target columns, an architecture, a horizon and a number of
    training steps, then delegates training, forecasting and plotting to
    ``forecast_time_series``.
    """
    st.title("Dynamic Forecasting")

    with st.sidebar.expander("Upload and Configure Dataset", expanded=True):
        uploaded_file = st.file_uploader("Upload your time series data (CSV)", type=["csv"])
        df = pd.read_csv(uploaded_file) if uploaded_file else load_default()
        st.session_state.df = df

        columns = df.columns.tolist()
        ds_default = columns.index('ds') if 'ds' in columns else 0
        y_default = columns.index('y') if 'y' in columns else 1
        ds_col = st.selectbox("Select Date/Time column", options=columns, index=ds_default)
        y_col = st.selectbox("Select Target column", options=columns, index=y_default)

        # Tag the single series with a constant identifier.
        df['unique_id'] = 1
        st.session_state.ds_col = ds_col
        st.session_state.y_col = y_col

    sidebar = st.sidebar
    sidebar.subheader("Dynamic Model Selection and Forecasting")
    chosen_model = sidebar.selectbox(
        "Select model for dynamic forecasting",
        ["NHITS", "TimesNet", "LSTM", "TFT"],
        key="dynamic_model_choice",
    )
    steps_ahead = sidebar.number_input("Forecast horizon", value=18)
    training_steps = sidebar.number_input('Max steps', value=200)

    # Rename the chosen columns to 'ds' / 'y' before training.
    df = df.rename(columns={ds_col: 'ds', y_col: 'y'})
    st.session_state.df = df

    forecast_time_series(df, chosen_model, steps_ahead, training_steps, ds_col='ds')
|
|
|
# Register both pages under a single "Overview" section; the transfer
# learning page is the default one.
_overview_pages = [
    st.Page(
        transfer_learning_forecasting,
        title="Transfer Learning Forecasting",
        default=True,
        icon=":material/query_stats:",
    ),
    st.Page(
        dynamic_forecasting,
        title="Dynamic Forecasting",
        icon=":material/monitoring:",
    ),
]
pg = st.navigation({"Overview": _overview_pages})

# Top-level boundary: report any page error in the sidebar.
try:
    pg.run()
except Exception as exc:
    st.sidebar.error(f"Something went wrong: {exc}", icon=":material/error:")
|
|
|
|