Spaces:
Running
Running
update app.py, add choice button for VisionTSpp base and large
Browse files
app.py
CHANGED
|
@@ -221,10 +221,6 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
|
|
| 221 |
# pred_range = np.arange(context_len, context_len + pred_len)
|
| 222 |
pred_range = np.arange(context_len-1, context_len + pred_len)
|
| 223 |
|
| 224 |
-
print(true_data[:, i].shape)
|
| 225 |
-
print(pred_median[:, i].shape)
|
| 226 |
-
print(pred_range.shape)
|
| 227 |
-
|
| 228 |
pred_median_visual = np.concatenate([true_data[context_len-1:context_len, i], pred_median[:, i]])
|
| 229 |
print(pred_median_visual.shape)
|
| 230 |
|
|
@@ -236,11 +232,6 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
|
|
| 236 |
lower_quantile_pred_visual = np.concatenate([true_data[context_len-1:context_len, i], lower_quantile_pred])
|
| 237 |
upper_quantile_pred_visual = np.concatenate([true_data[context_len-1:context_len, i], upper_quantile_pred])
|
| 238 |
|
| 239 |
-
print(lower_quantile_pred.shape)
|
| 240 |
-
print(upper_quantile_pred.shape)
|
| 241 |
-
print(lower_quantile_pred_visual.shape)
|
| 242 |
-
print(upper_quantile_pred_visual.shape)
|
| 243 |
-
|
| 244 |
q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
|
| 245 |
ax.fill_between(pred_range, lower_quantile_pred_visual, upper_quantile_pred_visual, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
|
| 246 |
# ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
|
|
@@ -387,40 +378,39 @@ def get_session_dir(session_id: gr.State):
|
|
| 387 |
def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
|
| 388 |
session_dir = get_session_dir(session_id)
|
| 389 |
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
# return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
|
| 424 |
|
| 425 |
|
| 426 |
with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
|
|
|
|
| 221 |
# pred_range = np.arange(context_len, context_len + pred_len)
|
| 222 |
pred_range = np.arange(context_len-1, context_len + pred_len)
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
pred_median_visual = np.concatenate([true_data[context_len-1:context_len, i], pred_median[:, i]])
|
| 225 |
print(pred_median_visual.shape)
|
| 226 |
|
|
|
|
| 232 |
lower_quantile_pred_visual = np.concatenate([true_data[context_len-1:context_len, i], lower_quantile_pred])
|
| 233 |
upper_quantile_pred_visual = np.concatenate([true_data[context_len-1:context_len, i], upper_quantile_pred])
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
|
| 236 |
ax.fill_between(pred_range, lower_quantile_pred_visual, upper_quantile_pred_visual, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
|
| 237 |
# ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
|
|
|
|
| 378 |
def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
|
| 379 |
session_dir = get_session_dir(session_id)
|
| 380 |
|
| 381 |
+
try:
|
| 382 |
+
if data_source == "Upload CSV":
|
| 383 |
+
if upload_file is None:
|
| 384 |
+
raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
|
| 385 |
+
uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
|
| 386 |
+
shutil.copy(upload_file.name, uploaded_file_path)
|
| 387 |
+
df = pd.read_csv(uploaded_file_path)
|
| 388 |
+
else:
|
| 389 |
+
df = load_preset_data(data_source)
|
| 390 |
+
|
| 391 |
+
index, context_len, pred_len = int(index), int(context_len), int(pred_len)
|
| 392 |
+
# --- Pass model_size to predict_at_index ---
|
| 393 |
+
result = predict_at_index(df, index, context_len, pred_len, session_dir, model_size)
|
| 394 |
+
|
| 395 |
+
final_index = min(index, result.total_samples - 1)
|
| 396 |
+
|
| 397 |
+
return (
|
| 398 |
+
result.ts_fig,
|
| 399 |
+
result.input_img_fig,
|
| 400 |
+
result.recon_img_fig,
|
| 401 |
+
result.csv_path,
|
| 402 |
+
gr.update(maximum=result.total_samples - 1, value=final_index),
|
| 403 |
+
gr.update(value=result.inferred_freq),
|
| 404 |
+
session_dir
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
except Exception as e:
|
| 408 |
+
print(f"Error during forecast: {e}")
|
| 409 |
+
error_fig = plt.figure(figsize=(10, 5))
|
| 410 |
+
plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
|
| 411 |
+
plt.axis('off')
|
| 412 |
+
plt.close(error_fig)
|
| 413 |
+
return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
|
|
|
|
| 414 |
|
| 415 |
|
| 416 |
with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
|