zhang qiao
Upload folder using huggingface_hub
8cf4695
raw
history blame
16.1 kB
import json
import io
import os
import tempfile
import datetime
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sktime.utils.plotting import plot_series
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from src.forecaster import Forecaster
from src.forecaster.models import XGBoost
from src.analyser import Analyser
from src.idsc import IDSC
from src.forecaster.models import ProphetForecaster
class GradioApp():
    """Event handlers and component builders behind the forecasting Gradio UI.

    The instance keeps the UI state (the working ``DataFrame``, forecast
    settings, model parameters and model results). Handlers named
    ``<component>__<event>`` update that state and return refreshed Gradio
    components; the ``update__*`` methods build those components from the
    current state.
    """

    def __init__(
        self
    ) -> None:
        self.forecaster = Forecaster()
        self.analyser = Analyser()
        self.idsc = IDSC()
        self.historical_demo_data = 'data/multivariate/demo_historical.csv'
        self.future_demo_data = 'data/multivariate/demo_future.csv'
        # Working dataset: historical rows, plus the future (exogenous) rows
        # once a future CSV is loaded. None until something is loaded.
        self.data: pd.DataFrame = None
        self.n_predict = 3
        self.window_length = 7
        self.target_column = 'y'
        # Columns of the future CSV; passed to the forecaster as exog.
        self.exog_columns = []
        # Whether model results are rounded (forwarded to Prophet as
        # ``round_val``).
        self.round_results = True
        # Temp CSV files older than this many minutes are deleted.
        self.delete_file_old_than_n_minutes = 10
        self.plot_figsize_full_screen = (20, 4)
        # -------------------- #
        # Model Related Params #
        # -------------------- #
        # XGBoost #
        self.xgboost = XGBoost()
        self.xgboost_cv = False
        self.xgboost_params = self.xgboost.cv_params
        self.xgboost_strategy = 'recursive'
        self.xgboost_forecast = None
        self.xgboost_test = None
        print('Init Gradio app')
        # Prophet #
        self.prophet = ProphetForecaster()
        self.prophet__seasonality_mode = 'multiplicative'
        self.prophet__add_country_holidays = {'country_name': 'Singapore'}
        self.prophet__yearly_seasonality = True
        self.prophet__weekly_seasonality = False
        self.prophet__daily_seasonality = False

    def checkbox__round_results__change(self, val):
        """Store whether model results should be rounded."""
        self.round_results = val

    def textbox__target_column__change(self, val):
        """Store the name of the column to forecast."""
        print('Updating textbox__target_column:', val)
        self.target_column = val

    def btn__profiling__click(self):
        """Profile the loaded data.

        Returns:
            (profiling markdown, change-point plot)
        """
        self.analyser.fit(self.data)
        self.analyser.profiling()
        return (
            self.update__md__profiling(),
            self.update__plot__changepoints())

    def btn__plot_correlation__click(self):
        """Return the correlation heatmap of the loaded data."""
        return self.update__plot__correlation()

    def file__historical__upload(
        self,
        file
    ):
        """Replace the working dataset with the uploaded historical CSV.

        Returns:
            (table view, chart filter dropdown, decomposition dropdown,
            chart view plot)
        """
        self.data = self.__read_datetime_csv(self.__upload_path(file))
        print('[file__historical__upload]')
        return self.__refresh_data_views()

    def file__future__upload(
        self,
        file
    ):
        """Append the uploaded future (exogenous) CSV to the working dataset.

        Returns:
            (table view, chart filter dropdown, decomposition dropdown,
            chart view plot, n_predict number)
        """
        self.__handle_future_data_upload(self.__upload_path(file))
        return (*self.__refresh_data_views(), self.update__number__n_predict())

    def btn__load_future_demo__click(
        self
    ):
        """Append the bundled demo future CSV to the working dataset.

        Returns the same 5 components as ``file__future__upload``.
        """
        self.__handle_future_data_upload(self.future_demo_data)
        return (*self.__refresh_data_views(), self.update__number__n_predict())

    def __handle_future_data_upload(
        self,
        path
    ):
        """Load future rows from *path* and append them to ``self.data``.

        All columns of the future file become ``self.exog_columns`` and its
        row count becomes the forecast horizon ``self.n_predict``.
        """
        data = self.__read_datetime_csv(path)
        self.exog_columns = data.columns.tolist()
        self.n_predict = len(data)
        print(
            f"[file__future__upload] with {self.exog_columns} columns")
        self.data = pd.concat(
            [self.data, data],
            axis=0)

    def number__n_predict__change(
        self,
        val
    ):
        """Store the number of steps to forecast."""
        print(f'[number__n_predict__change], {val}')
        self.n_predict = val

    def number__window_length__change(
        self,
        val):
        """Store the window length used by the forecaster."""
        print(f'[number__window_length__change], {val}')
        self.window_length = val

    def btn__fit_data__click(
        self):
        """Fit the forecaster on the data and lock the data-input widgets.

        Rows with any missing value (after dropping the exogenous columns)
        are excluded from the target data; the exogenous columns are passed
        separately and unfiltered, or ``None`` when there are none.
        """
        data = self.data.drop(columns=self.exog_columns).dropna(how='any')
        exog = self.data[self.exog_columns] if self.exog_columns else None
        self.forecaster.fit(
            data,
            target_col=self.target_column,
            n_predict=self.n_predict,
            window_length=self.window_length,
            exog=exog)
        return (
            gr.Number(interactive=False),  # number__n_predict
            gr.Number(interactive=False),  # number__window_length
            gr.File(interactive=False),  # file__historical
            gr.File(interactive=False),  # file__future
            gr.Button(visible=False),  # btn__fit_data
            gr.Column(visible=True),  # column__models
            gr.Button(visible=False),  # btn__load_historical_demo
            gr.Button(visible=False),  # btn__load_future_demo
            self.update__md__forecast_data_info()
        )

    def btn__load_historical_demo__click(
        self
    ):
        """Replace the working dataset with the bundled demo historical CSV.

        Returns the same 4 components as ``file__historical__upload``.
        """
        self.data = self.__read_datetime_csv(self.historical_demo_data)
        return self.__refresh_data_views()

    def dropdown__chart_view_filter__change(self, options):
        """Redraw the chart view with only the selected columns."""
        return self.update__plot__chart_view(options)

    def dropdown__seasonality_decompose__change(self, col):
        """Return (seasonal decomposition plot, ACF/PACF plot) for *col*."""
        return (
            self.update__plot__seasonality_decompose(col),
            self.update__plot_acg_pacf(col))

    # ------------------------ #
    # XGboost Model Operations #
    # ------------------------ #
    def btn__train_xgboost__click(self):
        """Run ``XGBoost.fit_predict`` on the forecaster's prepared splits.

        Stores the test-period prediction and the future forecast, then
        returns (plot, downloadable CSV, table) components for them.
        """
        (test, forecast, best_params) = self.xgboost.fit_predict(
            y=self.forecaster.y,
            y_train=self.forecaster.y_train,
            window_length=self.forecaster.window_length,
            fh=self.forecaster.fh,
            fh_test=self.forecaster.fh_test,
            params=self.xgboost_params,
            X=self.forecaster.X,
            X_train=self.forecaster.X_train,
            X_test=self.forecaster.X_test,
            X_future=self.forecaster.X_future
        )
        print(test, forecast, best_params)
        self.xgboost_forecast = forecast
        self.xgboost_test = test
        return (
            self.update__plot__xgboost_result(test, forecast),
            self.update__file__xgboost_result(),
            self.update__df__xgboost_result())

    def btn__set_xgboost_params__click(self, text):
        """Parse *text* as the XGBoost CV parameters and store them.

        JSON is tried first, with single quotes swapped for double quotes
        (so ``{'max_depth': [3, 5]}`` keeps working). If that fails, the
        text is parsed as a Python literal, which additionally accepts
        ``True``/``False``/``None`` and strings containing apostrophes.

        Raises:
            json.JSONDecodeError: if *text* is neither JSON (after the quote
                swap) nor a valid Python literal.
        """
        import ast

        try:
            params = json.loads(text.replace("'", '"'))
        except json.JSONDecodeError as json_err:
            try:
                params = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError):
                # Keep the exception type the original parser raised.
                raise json_err from None
        self.xgboost_params = params
        return self.update__json_xgboost_params()

    def checkbox__xgboost_round__change(self, val):
        """Toggle rounding of XGBoost results."""
        self.xgboost.round_result = val

    # ----------------------------------- #
    # Prophet Model Operations & Updaters #
    # ----------------------------------- #
    def btn__forecast_with_prophet__click(self):
        """Run ``ProphetForecaster.fit_predict`` with the configured options.

        Returns (plot, downloadable CSV, table) components for the result.
        """
        self.prophet.fit_predict(
            self.forecaster.y_train,
            self.forecaster.y,
            self.forecaster.fh,
            self.forecaster.fh_test,
            self.forecaster.period,
            self.forecaster.freq,
            X=self.forecaster.exog,
            seasonality_mode=self.prophet__seasonality_mode,
            add_country_holidays=self.prophet__add_country_holidays,
            yearly_seasonality=self.prophet__yearly_seasonality,
            weekly_seasonality=self.prophet__weekly_seasonality,
            daily_seasonality=self.prophet__daily_seasonality,
            round_val=self.round_results)
        return (
            self.update__plot__prophet_result(),
            self.update__file__prophet_result(),
            self.update__df__prophet_result())

    def update__plot__prophet_result(self):
        """Plot the last two periods of training data, test data, Prophet's
        test-period prediction and forecast (with its interval)."""
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        plot_series(
            self.forecaster.y_train[-2 * self.forecaster.period:],
            self.forecaster.y_test,
            self.prophet.predict,
            self.prophet.forecast,
            pred_interval=self.prophet.forecast_interval,
            labels=['Train', 'Test', 'Predicted - Test', 'Forecast'],
            ax=ax)
        ax.set_title('Prophet Forecast Result')
        ax.legend(loc='upper left')
        fig.tight_layout()
        return self.__to_plot(fig)

    def update__file__prophet_result(self):
        """Write the Prophet forecast to a temp CSV and return it as a file."""
        prophet_forecast_df = pd.DataFrame(self.prophet.forecast)
        path = self.__create_temp_csv_file(prophet_forecast_df)
        return gr.File(path)

    def update__df__prophet_result(self):
        """Return the Prophet forecast as a table, index as a column."""
        prophet_forecast_df = self.prophet.forecast.reset_index()
        return gr.Dataframe(value=prophet_forecast_df)

    # =============================== #
    # || Gradio Component Updaters || #
    # =============================== #
    def update__plot__changepoints(self):
        """Plot quantity and intermittent change points with predictability
        labels, one subplot each."""
        fig, axs = plt.subplots(2, 1, figsize=(20, 8))
        self.__plot_change_points(
            axs[0],
            self.analyser.quantity_change_points,
            self.analyser.quantity_predictability)
        self.__plot_change_points(
            axs[1],
            self.analyser.intermittent_change_points,
            self.analyser.intermittent_predictability)
        axs[0].set_title('Quantity Change Points & Predictability')
        axs[1].set_title('Intermittent Change Points & Predictability')
        fig.tight_layout()
        return self.__to_plot(fig)

    def __plot_change_points(self, ax, change_points, predictability):
        """Draw the ``y`` series on *ax*, marking each change point.

        ``predictability[0]`` is written at the first timestamp and
        ``predictability[i + 1]`` at ``change_points[i]``, which also gets a
        vertical line. Labels sit at 90% of the current y-axis top.
        """
        # NOTE(review): column 'y' is hard-coded rather than
        # self.target_column; confirm which column the Analyser profiles.
        ax.plot(self.data[['y']])
        ax.text(self.data.index[0],
                ax.get_ylim()[1] * 0.9,
                predictability[0],
                fontsize=20)
        for idx, point in enumerate(change_points):
            ax.axvline(x=point)
            ax.text(point,
                    ax.get_ylim()[1] * 0.9,
                    predictability[idx + 1],
                    fontsize=20)

    def update__md__profiling(self):
        """Return a Markdown summary of the Analyser's profiling results."""
        analyser = self.analyser
        return (
            "\n"
            "\n### Data Characteristic:\n"
            f"\n # {analyser.characteristic}\n"
            "\n ---\n"
            "\n### Quantity Change Points: "
            f"{analyser.quantity_change_points.astype(str).tolist()}\n"
            "\n### Quantity Predictability: "
            f"{analyser.quantity_predictability}\n"
            "\n### Intermittent Change Points: "
            f"{analyser.intermittent_change_points.astype(str).tolist()}\n"
            "\n### Intermittent Predictability: "
            f"{analyser.intermittent_predictability}\n"
        )

    def update__md__forecast_data_info(self):
        """Return Markdown listing the forecast timestamps, period and
        frequency detected by the fitted forecaster."""
        timestamps = self.forecaster.fh.to_pandas().astype(str).tolist()
        return gr.Markdown(value=(
            f" **Forecasting for these timestamps**: {timestamps} "
            f"\n **Data Period**: {self.forecaster.period} "
            f"\n **Data Frequency**: {self.forecaster.freq} "
        ))

    def update__plot__correlation(self):
        """Plot the lower-triangle correlation heatmap of numeric columns."""
        fig, ax = plt.subplots(figsize=(20, 8))
        corr = self.data.corr(numeric_only=True)
        # Hide the upper triangle (incl. diagonal): it mirrors the lower one.
        mask = np.triu(np.ones_like(corr, dtype=bool))
        sns.heatmap(
            corr,
            mask=mask,
            square=True,
            annot=True,
            cmap='coolwarm',
            linewidths=.5,
            cbar_kws={"shrink": .5},
            ax=ax)
        fig.tight_layout()
        return self.__to_plot(fig)

    def update__df__table_view(
        self
    ):
        """Return the working dataset as a table, index as a column."""
        data = self.data.reset_index()
        return gr.Dataframe(value=data)

    def update__number__n_predict(
        self
    ):
        """Return a read-only number showing the forecast horizon."""
        return gr.Number(self.n_predict, interactive=False)

    def update__dropdown__chart_view_filter(self):
        """Return a dropdown of all columns, all selected."""
        options = self.data.columns.tolist()
        return gr.Dropdown(options, value=options)

    def update__dropdown__seasonality_decompose(self):
        """Return a dropdown of all columns, none selected."""
        options = self.data.columns.tolist()
        return gr.Dropdown(options)

    def update__plot__seasonality_decompose(self, col):
        """Plot the seasonal decomposition of *col* (NaNs dropped)."""
        fig = seasonal_decompose(self.data[col].dropna()).plot()
        return self.__to_plot(fig)

    def update__plot_acg_pacf(self, col):
        """Plot the ACF and PACF of *col* (NaNs dropped, lag 0 omitted)."""
        fig, axs = plt.subplots(2, 1, sharex=True, sharey=True)
        series = self.data[col].dropna()
        plot_acf(series, ax=axs[0], zero=False)
        plot_pacf(series, ax=axs[1], zero=False)
        axs[0].set_title('Auto Correlation')
        axs[1].set_title('Partial Auto Correlation')
        return self.__to_plot(fig)

    # ---------------------- #
    # Update XGboost Results #
    # ---------------------- #
    def update__json_xgboost_params(self):
        """Return the current XGBoost parameters as a JSON component."""
        return gr.JSON(value=self.xgboost_params)

    def update__plot__xgboost_result(self, test, predict):
        """Plot the last two periods of training data, test data, the
        test-period prediction *test* and the forecast *predict*."""
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        plot_series(
            self.forecaster.y_train[-2 * self.forecaster.period:],
            self.forecaster.y_test,
            test,
            predict,
            labels=["y_train (part)", "y_test", "y_pred", 'y_forecast'],
            x_label='Date',
            ax=ax)
        # Rotate the existing labels instead of replacing them with a
        # FixedFormatter (set_xticklabels), which would freeze their text
        # and drop the formatter plot_series installed.
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        return self.__to_plot(fig)

    def update__plot__chart_view(self, cols=None):
        """Plot the selected columns (all columns when *cols* is None)."""
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        columns = self.data.columns if cols is None else cols
        print('[update__plot__chart_view]')
        for col in columns:
            ax.plot(self.data[[col]], label=col)
        # len() rather than truthiness: `columns` may be a pandas Index.
        if len(columns) > 0:
            fig.legend()
        fig.tight_layout()
        return self.__to_plot(fig)

    def update__file__xgboost_result(self):
        """Write the XGBoost forecast to a temp CSV and return it as a file."""
        path = self.__create_temp_csv_file(self.xgboost_forecast)
        return gr.File(path)

    def update__df__xgboost_result(self):
        """Return the XGBoost forecast as a (datetime, y) table."""
        # xgboost_forecast is a Series rather than a DataFrame; rebuild a
        # proper DataFrame for Gradio to display.
        data = pd.DataFrame(
            {"datetime": self.xgboost_forecast.index,
             "y": self.xgboost_forecast.values})
        return gr.Dataframe(value=data)

    # ------------- #
    # Util Function #
    # ------------- #
    def __refresh_data_views(self):
        """Rebuild the components that show the working dataset.

        Returns:
            (table view, chart filter dropdown, decomposition dropdown,
            chart view plot)
        """
        return (
            self.update__df__table_view(),
            self.update__dropdown__chart_view_filter(),
            self.update__dropdown__seasonality_decompose(),
            self.update__plot__chart_view())

    @staticmethod
    def __read_datetime_csv(path):
        """Read a CSV indexed by its parsed ``datetime`` column."""
        return pd.read_csv(
            path,
            index_col='datetime',
            parse_dates=['datetime'])

    @staticmethod
    def __upload_path(file):
        """Return the path of an uploaded file.

        Accepts an object exposing ``.name`` (the path) or a plain path
        string.
        """
        return getattr(file, 'name', file)

    @staticmethod
    def __to_plot(fig):
        """Wrap *fig* in a ``gr.Plot`` and release it from pyplot.

        pyplot keeps every figure it creates registered until ``plt.close``,
        so a long-running app would otherwise accumulate one figure per
        event. Closing only unregisters the figure; the Figure object can
        still be rendered afterwards.
        """
        plot = gr.Plot(fig)
        plt.close(fig)
        return plot

    def __create_temp_csv_file(self, df) -> str:
        """Write *df* to ``temp/<YYYYmmddHHMMSS>.csv`` and return that path.

        Before writing, timestamp-named files in ``temp`` older than
        ``self.delete_file_old_than_n_minutes`` minutes are deleted. The
        directory is created if missing. Files whose names are not such
        timestamps (e.g. ``.gitkeep``) are left alone.
        """
        time_format = "%Y%m%d%H%M%S"
        directory = 'temp'
        now = datetime.datetime.now()
        max_age = datetime.timedelta(
            minutes=self.delete_file_old_than_n_minutes)
        os.makedirs(directory, exist_ok=True)
        # Check if there are old files, remove them #
        for filename in os.listdir(directory):
            try:
                file_time = datetime.datetime.strptime(
                    filename.split('.')[0], time_format)
            except ValueError:
                # Not one of our timestamped exports; leave it alone.
                continue
            if now > max_age + file_time:
                print('deleting olde file: ', filename)
                try:
                    os.remove(os.path.join(directory, filename))
                except FileNotFoundError:
                    # Already removed (e.g. by a concurrent request).
                    pass
        # NOTE(review): names have 1-second resolution, so two exports in
        # the same second share (and overwrite) one file.
        new_file_path = os.path.join(
            directory, now.strftime(time_format) + '.csv')
        df.to_csv(new_file_path)
        return new_file_path