# Hugging Face Space status when this file was captured: Runtime error
import math

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure

from src.main import DemandForecasting

from .helpers import reset_index
class GradioApp:
    """Application state and event handlers for the demand-forecasting Gradio UI.

    The instance holds the loaded time series (``ts_data``), the per-SKU model
    table (``model_data``), the latest forecast and the latest model-selection
    result. Event handlers (``*__click``, ``*__upload``, ``*__select``, ...)
    update that state and return component updates built by the ``update__*``
    methods.

    NOTE(review): state lives on the instance and output CSVs are written to
    fixed paths in the working directory; if one instance serves the whole
    app, every session shares them.
    """

    def __init__(self):
        self.forecaster = DemandForecasting()
        self.ts_data = None  # Time series data for model training and forecasting
        self.model_data = None  # One row per SKU: best model, characteristic, scores
        self.skus = None  # Unique SKUs of ts_data, in order of first appearance
        self.forecast_horizon = 1
        # Set by __set_forecast / __set_model_selection_res. They stay None
        # until then so handlers can report a missing result explicitly.
        self.forecast = None
        self.model_selection_res = None

    # ================ #
    # Internal helpers #
    # ================ #
    def _require_ts_data(self):
        """Return the loaded time series, or raise gr.Error if none is loaded."""
        if self.ts_data is None or not self.skus:
            raise gr.Error(
                'Time series data must be loaded before running this step.')
        return self.ts_data

    def _sku_sizes(self):
        """Return the row count per SKU as a Series indexed in ``self.skus`` order.

        SKUs that have no rows in ``ts_data`` are reported as 0.
        """
        return (self.ts_data.groupby('sku', sort=False).size()
                .reindex(self.skus, fill_value=0))

    def _model_for(self, sku):
        """Return ``(best_model, characteristic)`` recorded for *sku*.

        Raises:
            gr.Error: if the model table has no usable entry for *sku*, e.g.
                because model selection has not been run yet.
        """
        rows = self.model_data[self.model_data['sku'] == sku]
        required = {'best_model', 'characteristic'}
        if rows.empty or not required <= set(rows.columns):
            raise gr.Error(f'No model data found for SKU {sku}.')
        model = rows['best_model'].iloc[0]
        if pd.isna(model) or model == '':
            raise gr.Error(
                f'No model selected for SKU {sku}. '
                'Run model selection or load model data first.')
        return model, rows['characteristic'].iloc[0]

    @staticmethod
    def _new_figure():
        """Create a 12x4 figure with a single Axes, without going through pyplot.

        Figures created via ``plt.subplots`` stay registered in pyplot until
        explicitly closed, so a long-running server would accumulate them. A
        bare ``Figure`` is freed once nothing references it any more.
        """
        fig = Figure(figsize=(12, 4))
        return fig, fig.subplots()

    def _ts_data_outputs(self):
        """Component updates shared by every time-series loading handler."""
        return (self.update__df_ts_data(),
                self.update__df_model_data(),
                self.update__file_model_data(),
                self.update__slider_forecast_horizon(),
                self.update__md_ts_data_info())

    def _model_data_outputs(self):
        """Component updates shared by every model-table handler."""
        return (self.update__df_model_data(),
                self.update__file_model_data())

    def _forecast_outputs(self):
        """Component updates shared by every forecast-producing handler."""
        return (self.update__df_forecast(),
                self.update__file_forecast(),
                self.update__dropdown_forecast())

    # ======= #
    # Setters #
    # ======= #
    def __set_ts_data(self, path):
        """Load the time series CSV at *path* and reset the model table.

        The CSV needs a ``datetime`` column (parsed and used as the index) and
        a ``sku`` column. One empty model-table row is created per SKU.
        State is only replaced after the file has been read and validated.

        Raises:
            gr.Error: if the file cannot be parsed or has no ``sku`` column.
        """
        try:
            ts_data = pd.read_csv(
                path,
                index_col='datetime',
                parse_dates=['datetime'])
        except ValueError as err:
            raise gr.Error(f'Could not read time series data: {err}') from err
        if 'sku' not in ts_data.columns:
            raise gr.Error('Time series data must contain a "sku" column.')
        self.ts_data = ts_data
        self.skus = ts_data['sku'].unique().tolist()
        self.model_data = pd.DataFrame({
            'sku': self.skus,
            'best_model': '',
            'characteristic': '',
            'RMSE': '',
            'Intermittent Scores': '',
        })
        print('[__set_ts_data] End')

    def __set_forecast(self, forecast: pd.DataFrame):
        """Store *forecast* indexed by its ``datetime`` column, parsed to datetimes."""
        print('__set_forecast')
        forecast = forecast.set_index('datetime')
        forecast.index = pd.to_datetime(forecast.index)
        self.forecast = forecast

    def __set_model_selection_res(self, model_selection_reses: pd.DataFrame):
        """Store the concatenated per-SKU test predictions of model selection.

        The result is kept only to visualize the model-selection outcome.
        """
        print('__set_model_selection_res')
        self.model_selection_res = model_selection_reses

    def __set_model(self, model_df):
        """Replace the model table with *model_df* after validating its SKUs.

        Raises:
            gr.Error: if no time series is loaded, *model_df* has no ``sku``
                column, or some loaded SKU is missing from *model_df*.
        """
        if self.skus is None:
            raise gr.Error(
                'Incorrect SKUs, time series data must be loaded and SKUs must match.')
        if 'sku' not in model_df.columns:
            raise gr.Error('Model select data must contain a "sku" column.')
        if set(self.skus) - set(model_df['sku']):
            raise gr.Error(
                'SKUs in provided model select data does not match SKUs in timeseries data.'
            )
        self.model_data = model_df

    # ============== #
    # Event handlers #
    # ============== #
    def btn_load_data__click(self):
        """Load the bundled demo time series and refresh the data outputs."""
        print('btn_load_data__click')
        self.__set_ts_data('./data/demand_forecasting_demo_data.csv')
        return self._ts_data_outputs()

    def btn_load_demo_result__click(self):
        """Load the bundled demo forecast and refresh the forecast outputs."""
        self.__set_forecast(
            pd.read_csv('./data/demand_forecasting_demo_result.csv'))
        return self._forecast_outputs()

    def file_upload_data__upload(self, file):
        """Load an uploaded time series CSV (a tempfile object or a path)."""
        self.__set_ts_data(getattr(file, 'name', file))
        return self._ts_data_outputs()

    def file_upload_model_data__upload(self, file):
        """Load an uploaded model table CSV (a tempfile object or a path)."""
        self.__set_model(pd.read_csv(getattr(file, 'name', file)))
        return self._model_data_outputs()

    def btn_load_model_data__click(self):
        """Load the bundled demo model table."""
        self.__set_model(
            pd.read_csv('./data/demand_forecasting_demo_models.csv'))
        return self._model_data_outputs()

    def btn_model_selection__click(self):
        """Run model selection for every SKU and record the result per SKU.

        For each SKU the forecaster runs with ``model='all'`` and
        ``run_test=True``. The first entry of ``res['forecast']`` is recorded
        as the best model, together with its characteristic, RMSE and
        intermittency scores. Its test predictions are kept for plotting.

        Raises:
            gr.Error: if no time series data is loaded.
        """
        print('btn_model_selection__click')
        ts_data = reset_index(self._require_ts_data())
        selection_results = []
        for sku in self.skus:
            print('Selecting model ', sku)
            # Same input as btn_forecast__click: the pipeline does not take
            # the sku column, so drop it before calling the forecaster.
            data = ts_data[ts_data['sku'] == sku].drop(columns='sku')
            res = self.forecaster.forecast(
                data, 0, model='all', run_test=True)
            best = res['forecast'][0]
            row = self.model_data['sku'] == sku
            self.model_data.loc[row, 'characteristic'] = res['characteristic']
            self.model_data.loc[row, 'best_model'] = best['model']
            self.model_data.loc[row, 'RMSE'] = round(best['RMSE'], 2)
            self.model_data.loc[row, 'Intermittent Scores'] = str(
                best['interm_scores'])
            selection = best['test'].drop(
                columns='truth').rename(columns={'test': 'y'})
            selection['sku'] = sku
            selection_results.append(selection)
        self.__set_model_selection_res(pd.concat(selection_results))
        return (*self._model_data_outputs(),
                self.update__accordion_model_selection(),
                self.update__dropdown_model_selection())

    def slider_forecast_horizon__update(self, slider):
        """Remember the selected forecast horizon as a whole number of steps."""
        self.forecast_horizon = int(round(slider))

    def btn_forecast__click(self):
        """Forecast ``self.forecast_horizon`` steps ahead for every SKU.

        Each SKU uses the best model and characteristic recorded in
        ``model_data``.

        Raises:
            gr.Error: if no time series is loaded or a SKU has no selected model.
        """
        # Move the datetime index back into a column (reset_index also
        # formats it as a string).
        ts_data = reset_index(self._require_ts_data())
        forecasts = []
        for sku in self.skus:
            model, characteristic = self._model_for(sku)
            print(f'Forecasting {sku} with model={model!r}, '
                  f'characteristic={characteristic!r}')
            # The pipeline does not take the sku column yet, so drop it first.
            data = ts_data[ts_data['sku'] == sku].drop(columns='sku')
            res = self.forecaster.forecast(
                data, self.forecast_horizon, model=model,
                run_test=False, characteristic=characteristic)
            forecast = pd.DataFrame(
                res['forecast'][0]['forecast'], columns=['datetime', 'y'])
            forecast['sku'] = sku
            forecasts.append(forecast)
        self.__set_forecast(pd.concat(forecasts))
        return self._forecast_outputs()

    def df_ts_data__change(self):
        return self.update__dropdown_ts_data()

    def dropdown_ts_data__select(self, skus):
        return self.update__plot_ts_data(skus)

    def dropdown_forecast__select(self, sku):
        return self.update__plot_forecast(sku)

    def dropdown_model_selection__select(self, sku):
        return self.update__plot_model_selection(sku)

    # ======== #
    # Updaters #
    # ======== #
    def update__file_model_data(self):
        self.model_data.to_csv('./best_models.csv', index=False)
        return gr.File(value='./best_models.csv')

    def update__df_model_data(self):
        return gr.Dataframe(value=self.model_data)

    def update__df_ts_data(self):
        return gr.Dataframe(value=reset_index(self.ts_data))

    def update__df_forecast(self):
        print('update__df_forecast')
        return gr.Dataframe(value=reset_index(self.forecast))

    def update__slider_forecast_horizon(self):
        """Limit the horizon slider to 20% of the shortest SKU series.

        The maximum is kept at least 1, and the stored horizon is clamped so
        the state and the slider agree after the range shrinks.
        NOTE(review): assumes the slider's minimum is <= 1; confirm where the
        Slider component is declared.
        """
        sizes = self._sku_sizes()
        shortest = int(sizes.min()) if not sizes.empty else 0
        max_horizon = max(1, int(shortest * 0.2))
        self.forecast_horizon = min(self.forecast_horizon, max_horizon)
        return gr.Slider(maximum=max_horizon, value=self.forecast_horizon)

    def update__file_forecast(self):
        reset_index(self.forecast).to_csv('./forecast_result.csv', index=False)
        return gr.File(value='./forecast_result.csv')

    def update__md_ts_data_info(self):
        """Describe the loaded data: its columns and the row count per SKU."""
        columns = reset_index(self.ts_data).columns.tolist()
        sizes = ' | '.join(
            f'{sku} : **{count}**' for sku, count in self._sku_sizes().items())
        # Paragraphs are built explicitly: Markdown merges single newlines,
        # and indented lines inside a triple-quoted literal render as code.
        md = ('### Data Description\n\n'
              f'Columns: **{columns}**\n\n'
              f'Size: {sizes}\n')
        return gr.Markdown(md)

    def update__dropdown_ts_data(self):
        return gr.Dropdown(choices=self.skus)

    def update__dropdown_forecast(self):
        skus = self.forecast['sku'].unique().tolist()
        return gr.Dropdown(choices=skus)

    def update__dropdown_model_selection(self):
        return gr.Dropdown(choices=self.skus)

    def update__plot_ts_data(self, skus):
        """Plot the ``y`` series of each selected SKU.

        *skus* may be a list of SKUs, a single SKU string, or None
        (an empty selection).
        """
        if skus is None:
            skus = []
        elif isinstance(skus, str):
            skus = [skus]
        fig, ax = self._new_figure()
        for sku in skus:
            ax.plot(self.ts_data[self.ts_data['sku'] == sku]['y'], label=sku)
        if skus:
            ax.legend(loc='upper left')
        fig.tight_layout()
        return gr.Plot(fig)

    def update__plot_forecast(self, sku):
        """Plot the history and the forecast of *sku* as two joined lines.

        Raises:
            gr.Error: if no forecast has been produced or loaded yet.
        """
        if self.forecast is None:
            raise gr.Error(
                'No forecast available yet. Run forecasting or load the demo result first.')
        fig, ax = self._new_figure()
        forecast = self.forecast[self.forecast['sku'] == sku]
        if self.ts_data is not None:
            history = self.ts_data[self.ts_data['sku'] == sku]
            # Append the first forecast point (already datetime-indexed) to
            # the history so the two lines connect instead of leaving a gap.
            ax.plot(pd.concat([history, forecast.head(1)])['y'],
                    label=f'{sku} - historical')
        ax.plot(forecast['y'], label=f'{sku} - forecast')
        ax.legend(loc='upper left')
        fig.tight_layout()
        return gr.Plot(fig)

    def update__plot_model_selection(self, sku):
        """Plot the ground truth of *sku* against its model-selection predictions.

        Raises:
            gr.Error: if model selection has not run or has no result for *sku*.
        """
        if self.model_selection_res is None:
            raise gr.Error('Run model selection before plotting its result.')
        selection = self.model_selection_res[
            self.model_selection_res['sku'] == sku]
        if selection.empty:
            raise gr.Error(f'No model selection result for SKU {sku}.')
        idx = selection.index
        fig, ax = self._new_figure()
        # Crop the ground truth at the last predicted timestamp: some models
        # (e.g. IDSC) cannot forecast the full required data size.
        truth = self.ts_data[
            (self.ts_data['sku'] == sku) &
            (self.ts_data.index <= idx[-1])
        ]
        ax.plot(truth['y'], label=f'{sku} - ground truth')
        ax.plot(selection['y'], label=f'{sku} - model result')
        # Dashed line at the first timestamp of the model-selection result.
        ax.axvline(x=idx[0], ymin=0.05, ymax=0.95, ls='--')
        ax.legend(loc='upper left')
        fig.tight_layout()
        return gr.Plot(fig)

    def update__accordion_model_selection(self):
        return gr.Accordion(visible=True)