Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,7 @@ from neuralforecast.models import NHITS, TimesNet, LSTM, TFT
|
|
| 7 |
from neuralforecast.losses.pytorch import HuberMQLoss
|
| 8 |
from neuralforecast.utils import AirPassengersDF
|
| 9 |
import time
|
|
|
|
| 10 |
|
| 11 |
@st.cache_resource
|
| 12 |
def load_model(path, freq):
|
|
@@ -225,6 +226,15 @@ def transfer_learning_forecasting():
|
|
| 225 |
frequency = determine_frequency(df)
|
| 226 |
st.sidebar.write(f"Detected frequency: {frequency}")
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
# Load pre-trained models
|
| 229 |
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
| 230 |
forecast_results = {}
|
|
@@ -239,6 +249,14 @@ def transfer_learning_forecasting():
|
|
| 239 |
elif model_choice == "TFT":
|
| 240 |
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
for model_name, forecast_df in forecast_results.items():
|
| 243 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
|
| 244 |
|
|
|
|
| 7 |
from neuralforecast.losses.pytorch import HuberMQLoss
|
| 8 |
from neuralforecast.utils import AirPassengersDF
|
| 9 |
import time
|
| 10 |
+
from st_aggrid import AgGrid
|
| 11 |
|
| 12 |
@st.cache_resource
|
| 13 |
def load_model(path, freq):
|
|
|
|
| 226 |
frequency = determine_frequency(df)
|
| 227 |
st.sidebar.write(f"Detected frequency: {frequency}")
|
| 228 |
|
| 229 |
+
df_grid = df.drop(columns="unique_id")
|
| 230 |
+
grid_table = AgGrid(
|
| 231 |
+
df_grid,
|
| 232 |
+
editable=False,
|
| 233 |
+
theme="streamlit",
|
| 234 |
+
fit_columns_on_grid_load=True,
|
| 235 |
+
height=360,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
# Load pre-trained models
|
| 239 |
nhits_model, timesnet_model, lstm_model, tft_model = select_model_based_on_frequency(frequency, nhits_models, timesnet_models, lstm_models, tft_models)
|
| 240 |
forecast_results = {}
|
|
|
|
| 249 |
elif model_choice == "TFT":
|
| 250 |
forecast_results['TFT'] = generate_forecast(tft_model, df)
|
| 251 |
|
| 252 |
+
df_grid = df.drop(columns="unique_id")
|
| 253 |
+
grid_table = AgGrid(
|
| 254 |
+
df_grid,
|
| 255 |
+
editable=False,
|
| 256 |
+
theme="streamlit",
|
| 257 |
+
fit_columns_on_grid_load=True,
|
| 258 |
+
height=360,
|
| 259 |
+
)
|
| 260 |
for model_name, forecast_df in forecast_results.items():
|
| 261 |
plot_forecasts(forecast_df, df, f'{model_name} Forecast for {y_col}')
|
| 262 |
|