# Spaces: Runtime error
import json | |
import io | |
import os | |
import tempfile | |
import datetime | |
import gradio as gr | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
from sktime.utils.plotting import plot_series | |
from statsmodels.tsa.seasonal import seasonal_decompose | |
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf | |
from src.forecaster import Forecaster | |
from src.forecaster.models import XGBoost | |
from src.analyser import Analyser | |
from src.idsc import IDSC | |
from src.forecaster.models import ProphetForecaster | |
class GradioApp():
    """Hold the state and the event handlers of the forecasting Gradio UI.

    Method naming convention:

    * ``<component>__<name>__<event>`` -- event handlers (``click``,
      ``change``, ``upload``) that mutate the state and return component
      updates.
    * ``update__<component>__<name>`` -- build the Gradio component (or
      markdown text) that is sent back to the UI.
    """

    def __init__(
        self
    ) -> None:
        """Create the helper objects and set the default settings."""
        self.forecaster = Forecaster()
        self.analyser = Analyser()
        self.idsc = IDSC()

        self.historical_demo_data = 'data/multivariate/demo_historical.csv'
        self.future_demo_data = 'data/multivariate/demo_future.csv'

        # None until a CSV (uploaded or demo) has been loaded
        self.data: pd.DataFrame = None
        self.n_predict = 3
        self.window_length = 7
        self.target_column = 'y'
        self.exog_columns = []

        # Define if the model's result is going to be rounded
        self.round_results = True

        # Delete old temp files older than n minutes
        self.delete_file_old_than_n_minutes = 10

        self.plot_figsize_full_screen = (20, 4)

        # -------------------- #
        # Model Related Params #
        # -------------------- #

        # XGBoost #
        self.xgboost = XGBoost()
        self.xgboost_cv = False
        self.xgboost_params = self.xgboost.cv_params
        self.xgboost_strategy = 'recursive'
        self.xgboost_forecast = None
        self.xgboost_test = None

        print('Init Gradio app')

        # Prophet #
        self.prophet = ProphetForecaster()
        self.prophet__seasonality_mode = 'multiplicative'
        self.prophet__add_country_holidays = {'country_name': 'Singapore'}
        self.prophet__yearly_seasonality = True
        self.prophet__weekly_seasonality = False
        self.prophet__daily_seasonality = False

    # ------------------------ #
    # Shared private helpers   #
    # ------------------------ #

    @staticmethod
    def __read_timeseries_csv(path):
        """Read a CSV whose ``datetime`` column is parsed and used as index."""
        return pd.read_csv(
            path,
            index_col='datetime',
            parse_dates=['datetime'])

    def __data_view_updates(self):
        """Return the four component updates that depend on ``self.data``.

        Order: table view, chart-view filter dropdown, seasonality dropdown,
        chart-view plot.
        """
        return (
            self.update__df__table_view(),
            self.update__dropdown__chart_view_filter(),
            self.update__dropdown__seasonality_decompose(),
            self.update__plot__chart_view())

    # -------------- #
    # Event handlers #
    # -------------- #

    def checkbox__round_results__change(self, val):
        """Store whether forecast results should be rounded."""
        self.round_results = val

    def textbox__target_column__change(self, val):
        """Store the name of the target column."""
        print('Updating textbox__target_column:', val)
        self.target_column = val

    def btn__profiling__click(self):
        """Fit and run the analyser on the loaded data.

        Returns:
            Tuple of (profiling markdown text, change-points plot).
        """
        self.analyser.fit(self.data)
        self.analyser.profiling()
        return (
            self.update__md__profiling(),
            self.update__plot__changepoints())

    def btn__plot_correlation__click(self):
        """Return the correlation heat-map plot for the loaded data."""
        return self.update__plot__correlation()

    def file__historical__upload(
        self,
        file
    ):
        """Replace ``self.data`` with the uploaded historical CSV.

        Args:
            file: Uploaded file object; only its ``name`` (path) is read.

        Returns:
            Tuple of the four data-view component updates.
        """
        self.data = self.__read_timeseries_csv(file.name)
        print('[file__historical__upload]')
        return self.__data_view_updates()

    def file__future__upload(
        self,
        file
    ):
        """Append the uploaded future (exogenous) CSV to ``self.data``.

        Args:
            file: Uploaded file object; only its ``name`` (path) is read.

        Returns:
            Tuple of the four data-view component updates followed by the
            ``n_predict`` number update.
        """
        self.__handle_future_data_upload(file.name)
        return (
            *self.__data_view_updates(),
            self.update__number__n_predict())

    def btn__load_future_demo__click(
        self
    ):
        """Append the bundled future demo CSV to ``self.data``.

        Returns:
            Tuple of the four data-view component updates followed by the
            ``n_predict`` number update.
        """
        self.__handle_future_data_upload(self.future_demo_data)
        return (
            *self.__data_view_updates(),
            self.update__number__n_predict())

    def __handle_future_data_upload(
        self,
        path
    ):
        """Load future data from ``path`` and append it below ``self.data``.

        Every column of the future file is recorded as an exogenous column
        and the number of rows becomes the forecast length.
        """
        future = self.__read_timeseries_csv(path)
        self.exog_columns = future.columns.tolist()
        self.n_predict = len(future)
        print(
            f"[file__future__upload] with {self.exog_columns} columns")
        self.data = pd.concat(
            [self.data, future],
            axis=0)

    def number__n_predict__change(
        self,
        val
    ):
        """Store the number of steps to forecast."""
        print(f'[number__n_predict__change], {val}')
        self.n_predict = val

    def number__window_length__change(
        self,
        val
    ):
        """Store the window length handed to the forecaster."""
        print(f'[number__window_length__change], {val}')
        self.window_length = val

    def btn__fit_data__click(
        self
    ):
        """Fit the forecaster on the loaded data and lock the input widgets.

        The exogenous columns are removed and rows containing any NaN are
        dropped before the remaining frame is handed to the forecaster; the
        exogenous columns (if any) are passed separately via ``exog``.

        Returns:
            Tuple of nine component updates, in the order noted inline.
        """
        history = self.data.drop(columns=self.exog_columns).dropna(how='any')
        exog = self.data[self.exog_columns] if self.exog_columns else None
        self.forecaster.fit(
            history,
            target_col=self.target_column,
            n_predict=self.n_predict,
            window_length=self.window_length,
            exog=exog)
        return (
            gr.Number(interactive=False),   # number__n_predict
            gr.Number(interactive=False),   # number__window_length
            gr.File(interactive=False),     # file__historical
            gr.File(interactive=False),     # file__future
            gr.Button(visible=False),       # btn__fit_data
            gr.Column(visible=True),        # column__models
            gr.Button(visible=False),       # btn__load_historical_demo
            gr.Button(visible=False),       # btn__load_future_demo
            self.update__md__forecast_data_info()
        )

    def btn__load_historical_demo__click(
        self
    ):
        """Replace ``self.data`` with the bundled historical demo CSV.

        Returns:
            Tuple of the four data-view component updates.
        """
        self.data = self.__read_timeseries_csv(self.historical_demo_data)
        return self.__data_view_updates()

    def dropdown__chart_view_filter__change(self, options):
        """Redraw the chart view for the selected columns."""
        return self.update__plot__chart_view(options)

    def dropdown__seasonality_decompose__change(self, col):
        """Return the decomposition plot and the ACF/PACF plot of ``col``."""
        return (
            self.update__plot__seasonality_decompose(col),
            self.update__plot_acg_pacf(col))

    # ------------------------ #
    # XGboost Model Operations #
    # ------------------------ #
    def btn__train_xgboost__click(self):
        """Run the XGBoost model on the fitted forecaster data.

        Stores the test prediction and the forecast on the instance.

        Returns:
            Tuple of (result plot, downloadable CSV file, result table).
        """
        (test, forecast, best_params) = self.xgboost.fit_predict(
            y=self.forecaster.y,
            y_train=self.forecaster.y_train,
            window_length=self.forecaster.window_length,
            fh=self.forecaster.fh,
            fh_test=self.forecaster.fh_test,
            params=self.xgboost_params,
            X=self.forecaster.X,
            X_train=self.forecaster.X_train,
            X_test=self.forecaster.X_test,
            X_future=self.forecaster.X_future
        )
        print(test, forecast, best_params)
        self.xgboost_forecast = forecast
        self.xgboost_test = test
        return (
            self.update__plot__xgboost_result(test, forecast),
            self.update__file__xgboost_result(),
            self.update__df__xgboost_result())

    def btn__set_xgboost_params__click(self, text):
        """Parse ``text`` as the XGBoost parameters and store them.

        Single quotes are replaced by double quotes before JSON parsing so
        that a Python-style dict literal with quoted strings is accepted.

        Raises:
            json.JSONDecodeError: If the converted text is not valid JSON.
        """
        self.xgboost_params = json.loads(text.replace("'", '"'))
        return self.update__json_xgboost_params()

    def checkbox__xgboost_round__change(self, val):
        """Store the rounding flag on the XGBoost model wrapper."""
        self.xgboost.round_result = val

    # ----------------------------------- #
    # Prophet Model Operations & Updaters #
    # ----------------------------------- #
    def btn__forecast_with_prophet__click(self):
        """Run the Prophet model with the configured Prophet settings.

        Returns:
            Tuple of (result plot, downloadable CSV file, result table).
        """
        self.prophet.fit_predict(
            self.forecaster.y_train,
            self.forecaster.y,
            self.forecaster.fh,
            self.forecaster.fh_test,
            self.forecaster.period,
            self.forecaster.freq,
            X=self.forecaster.exog,
            seasonality_mode=self.prophet__seasonality_mode,
            add_country_holidays=self.prophet__add_country_holidays,
            yearly_seasonality=self.prophet__yearly_seasonality,
            weekly_seasonality=self.prophet__weekly_seasonality,
            daily_seasonality=self.prophet__daily_seasonality,
            round_val=self.round_results)
        return (
            self.update__plot__prophet_result(),
            self.update__file__prophet_result(),
            self.update__df__prophet_result())

    def update__plot__prophet_result(self):
        """Plot train (last two periods), test, test prediction and forecast."""
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        plot_series(
            self.forecaster.y_train[-2 * self.forecaster.period:],
            self.forecaster.y_test,
            self.prophet.predict,
            self.prophet.forecast,
            pred_interval=self.prophet.forecast_interval,
            labels=['Train', 'Test', 'Predicted - Test', 'Forecast'],
            ax=ax)
        ax.set_title('Prophet Forecast Result')
        ax.legend(loc='upper left')
        fig.tight_layout()
        return gr.Plot(fig)

    def update__file__prophet_result(self):
        """Write the Prophet forecast to a temp CSV and return it as a file."""
        prophet_forecast_df = pd.DataFrame(self.prophet.forecast)
        path = self.__create_temp_csv_file(prophet_forecast_df)
        return gr.File(path)

    def update__df__prophet_result(self):
        """Return the Prophet forecast as a table with the index as column."""
        prophet_forecast_df = self.prophet.forecast.reset_index()
        return gr.Dataframe(value=prophet_forecast_df)

    # =============================== #
    # || Gradio Component Updaters || #
    # =============================== #
    @staticmethod
    def __annotate_change_points(ax, start, change_points, predictability):
        """Draw change-point lines on ``ax`` and label every segment.

        ``predictability[0]`` labels the segment beginning at ``start``;
        ``predictability[i + 1]`` labels the segment that begins at
        ``change_points[i]``.
        """
        ax.text(start,
                ax.get_ylim()[1]*0.9,
                predictability[0],
                fontsize=20)
        for i, point in enumerate(change_points):
            ax.axvline(x=point)
            ax.text(point,
                    ax.get_ylim()[1]*0.9,
                    predictability[i+1],
                    fontsize=20)

    def update__plot__changepoints(self):
        """Plot quantity and intermittent change points of the ``y`` column."""
        # NOTE(review): the column is hard-coded to 'y' rather than
        # self.target_column -- confirm which column the analyser works on.
        fig, axs = plt.subplots(2, 1, figsize=(20, 8))
        start = self.data.index[0]

        axs[0].plot(self.data[['y']])
        self.__annotate_change_points(
            axs[0],
            start,
            self.analyser.quantity_change_points,
            self.analyser.quantity_predictability)

        axs[1].plot(self.data[['y']])
        self.__annotate_change_points(
            axs[1],
            start,
            self.analyser.intermittent_change_points,
            self.analyser.intermittent_predictability)

        axs[0].set_title('Quantity Change Points & Predictability')
        axs[1].set_title('Intermittent Change Points & Predictability')
        fig.tight_layout()
        return gr.Plot(fig)

    def update__md__profiling(self):
        """Return the analyser's profiling result as markdown text."""
        analyser = self.analyser
        quantity_points = (
            analyser.quantity_change_points.astype(str).tolist())
        intermittent_points = (
            analyser.intermittent_change_points.astype(str).tolist())
        # Sections are separated by blank lines so each one renders as its
        # own markdown block.
        sections = [
            '### Data Characteristic:',
            f' # {analyser.characteristic}',
            ' ---',
            f'### Quantity Change Points: {quantity_points}',
            f'### Quantity Predictability: '
            f'{analyser.quantity_predictability}',
            f'### Intermittent Change Points: {intermittent_points}',
            f'### Intermittent Predictability: '
            f'{analyser.intermittent_predictability}',
        ]
        return '\n\n'.join(sections)

    def update__md__forecast_data_info(self):
        """Return a markdown summary of the fitted forecasting horizon."""
        timestamps = self.forecaster.fh.to_pandas().astype(str).tolist()
        # Two trailing spaces before each newline make markdown render a
        # hard line break between the three fields.
        text = (
            f'**Forecasting for these timestamps**: {timestamps}  '
            f'\n **Data Period**: {self.forecaster.period}  '
            f'\n **Data Frequency**: {self.forecaster.freq} '
        )
        return gr.Markdown(value=text)

    def update__plot__correlation(self):
        """Plot the lower triangle of the numeric-column correlation matrix."""
        fig, ax = plt.subplots(figsize=(20, 8))
        corr = self.data.corr(numeric_only=True)
        # Hide the upper triangle (including the diagonal)
        mask = np.triu(np.ones_like(corr, dtype=bool))
        sns.heatmap(
            corr,
            mask=mask,
            square=True,
            annot=True,
            cmap='coolwarm',
            linewidths=.5,
            cbar_kws={"shrink": .5},
            ax=ax)
        fig.tight_layout()
        return gr.Plot(fig)

    def update__df__table_view(
        self
    ):
        """Return the loaded data as a table with the index as a column."""
        return gr.Dataframe(value=self.data.reset_index())

    def update__number__n_predict(
        self
    ):
        """Return a read-only number component showing ``n_predict``."""
        return gr.Number(self.n_predict, interactive=False)

    def update__dropdown__chart_view_filter(self):
        """Return a dropdown of all data columns, all of them selected."""
        options = self.data.columns.tolist()
        return gr.Dropdown(options, value=options)

    def update__dropdown__seasonality_decompose(self):
        """Return a dropdown listing all data columns."""
        options = self.data.columns.tolist()
        return gr.Dropdown(options)

    def update__plot__seasonality_decompose(self, col):
        """Plot the seasonal decomposition of column ``col`` (NaNs dropped)."""
        seasonal = seasonal_decompose(self.data[col].dropna())
        fig = seasonal.plot()
        return gr.Plot(fig)

    def update__plot_acg_pacf(self, col):
        """Plot ACF and PACF of column ``col`` (NaNs dropped, lag 0 hidden)."""
        series = self.data[col].dropna()
        fig, axs = plt.subplots(2, 1, sharex=True, sharey=True)
        plot_acf(series, ax=axs[0], zero=False)
        plot_pacf(series, ax=axs[1], zero=False)
        axs[0].set_title('Auto Correlation')
        axs[1].set_title('Partial Auto Correlation')
        return gr.Plot(fig)

    # ---------------------- #
    # Update XGboost Results #
    # ---------------------- #
    def update__json_xgboost_params(self):
        """Return the current XGBoost parameters as a JSON component."""
        return gr.JSON(value=self.xgboost_params)

    def update__plot__xgboost_result(self, test, predict):
        """Plot train (last two periods), test, test prediction and forecast.

        Args:
            test: Model prediction over the test horizon.
            predict: Model forecast over the future horizon.
        """
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        plot_series(
            self.forecaster.y_train[-2*self.forecaster.period:],
            self.forecaster.y_test,
            test,
            predict,
            labels=["y_train (part)", "y_test", "y_pred", 'y_forecast'],
            x_label='Date',
            ax=ax)
        # Rotate through tick_params so the label text is still produced by
        # the axis formatter at draw time instead of being frozen here.
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        return gr.Plot(fig)

    def update__plot__chart_view(self, cols=None):
        """Plot the given columns of the data (all columns when ``None``)."""
        fig, ax = plt.subplots(figsize=self.plot_figsize_full_screen)
        _cols = self.data.columns if cols is None else cols
        print('[update__plot__chart_view]')
        for col in _cols:
            ax.plot(self.data[[col]], label=col)
        fig.legend()
        fig.tight_layout()
        return gr.Plot(fig)

    def update__file__xgboost_result(self):
        """Write the XGBoost forecast to a temp CSV and return it as a file."""
        path = self.__create_temp_csv_file(self.xgboost_forecast)
        return gr.File(path)

    def update__df__xgboost_result(self):
        """Return the XGBoost forecast as a two-column table."""
        # xgboost_forecast is actually a Series instead of proper DataFrame
        # Re constructing a proper dataframe for gradio to take
        data = pd.DataFrame(
            {"datetime": self.xgboost_forecast.index,
             "y": self.xgboost_forecast.values})
        return gr.Dataframe(value=data)

    # ------------- #
    # Util Function #
    # ------------- #
    def __create_temp_csv_file(self, df) -> str:
        """Write ``df`` to ``temp/<timestamp>.csv`` and purge stale files.

        Files in ``temp`` whose timestamp name is older than
        ``self.delete_file_old_than_n_minutes`` minutes are deleted first.

        Args:
            df: Object exposing ``to_csv(path)`` (DataFrame or Series).

        Returns:
            Relative path of the CSV file that was written.
        """
        time_format = "%Y%m%d%H%M%S"
        directory = 'temp'
        now = datetime.datetime.now()
        max_age = datetime.timedelta(
            minutes=self.delete_file_old_than_n_minutes)

        # A fresh checkout/container may not have the directory yet
        os.makedirs(directory, exist_ok=True)

        # Check if there are old files, remove them #
        for filename in os.listdir(directory):
            try:
                file_time = datetime.datetime.strptime(
                    filename.split('.')[0], time_format)
            except ValueError:
                # Not one of our timestamp-named files (e.g. .gitkeep):
                # leave it alone instead of crashing.
                continue
            # If the file is older than the configured age, delete the file
            if now > max_age + file_time:
                print('deleting olde file: ', filename)
                os.remove(os.path.join(directory, filename))

        new_file_path = os.path.join(
            directory, now.strftime(time_format) + '.csv')
        df.to_csv(new_file_path)
        return new_file_path