Lefei commited on
Commit
c93002b
·
verified ·
1 Parent(s): cce3958

update app.py, add choice button for VisionTSpp base and large

Browse files
Files changed (1) hide show
  1. app.py +33 -43
app.py CHANGED
@@ -221,10 +221,6 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
221
  # pred_range = np.arange(context_len, context_len + pred_len)
222
  pred_range = np.arange(context_len-1, context_len + pred_len)
223
 
224
- print(true_data[:, i].shape)
225
- print(pred_median[:, i].shape)
226
- print(pred_range.shape)
227
-
228
  pred_median_visual = np.concatenate([true_data[context_len-1:context_len, i], pred_median[:, i]])
229
  print(pred_median_visual.shape)
230
 
@@ -236,11 +232,6 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
236
  lower_quantile_pred_visual = np.concatenate([true_data[context_len-1:context_len, i], lower_quantile_pred])
237
  upper_quantile_pred_visual = np.concatenate([true_data[context_len-1:context_len, i], upper_quantile_pred])
238
 
239
- print(lower_quantile_pred.shape)
240
- print(upper_quantile_pred.shape)
241
- print(lower_quantile_pred_visual.shape)
242
- print(upper_quantile_pred_visual.shape)
243
-
244
  q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
245
  ax.fill_between(pred_range, lower_quantile_pred_visual, upper_quantile_pred_visual, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
246
  # ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
@@ -387,40 +378,39 @@ def get_session_dir(session_id: gr.State):
387
  def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
388
  session_dir = get_session_dir(session_id)
389
 
390
- # try:
391
-
392
- if data_source == "Upload CSV":
393
- if upload_file is None:
394
- raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
395
- uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
396
- shutil.copy(upload_file.name, uploaded_file_path)
397
- df = pd.read_csv(uploaded_file_path)
398
- else:
399
- df = load_preset_data(data_source)
400
-
401
- index, context_len, pred_len = int(index), int(context_len), int(pred_len)
402
- # --- Pass model_size to predict_at_index ---
403
- result = predict_at_index(df, index, context_len, pred_len, session_dir, model_size)
404
-
405
- final_index = min(index, result.total_samples - 1)
406
-
407
- return (
408
- result.ts_fig,
409
- result.input_img_fig,
410
- result.recon_img_fig,
411
- result.csv_path,
412
- gr.update(maximum=result.total_samples - 1, value=final_index),
413
- gr.update(value=result.inferred_freq),
414
- session_dir
415
- )
416
-
417
- # except Exception as e:
418
- # print(f"Error during forecast: {e}")
419
- # error_fig = plt.figure(figsize=(10, 5))
420
- # plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
421
- # plt.axis('off')
422
- # plt.close(error_fig)
423
- # return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
424
 
425
 
426
  with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo:
 
221
  # pred_range = np.arange(context_len, context_len + pred_len)
222
  pred_range = np.arange(context_len-1, context_len + pred_len)
223
 
 
 
 
 
224
  pred_median_visual = np.concatenate([true_data[context_len-1:context_len, i], pred_median[:, i]])
225
  print(pred_median_visual.shape)
226
 
 
232
  lower_quantile_pred_visual = np.concatenate([true_data[context_len-1:context_len, i], lower_quantile_pred])
233
  upper_quantile_pred_visual = np.concatenate([true_data[context_len-1:context_len, i], upper_quantile_pred])
234
 
 
 
 
 
 
235
  q_low, q_high = quantile_vals[j], quantile_vals[-(j+1)]
236
  ax.fill_between(pred_range, lower_quantile_pred_visual, upper_quantile_pred_visual, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
237
  # ax.fill_between(pred_range, lower_quantile_pred, upper_quantile_pred, color=quantile_colors[j], alpha=0.7, label=f'{int(q_low*100)}-{int(q_high*100)}% Quantile')
 
378
  def run_forecast(data_source, upload_file, index, context_len, pred_len, model_size, session_id: gr.State):
379
  session_dir = get_session_dir(session_id)
380
 
381
+ try:
382
+ if data_source == "Upload CSV":
383
+ if upload_file is None:
384
+ raise gr.Error("Please upload a CSV file when 'Upload CSV' is selected.")
385
+ uploaded_file_path = Path(session_dir) / Path(upload_file.name).name
386
+ shutil.copy(upload_file.name, uploaded_file_path)
387
+ df = pd.read_csv(uploaded_file_path)
388
+ else:
389
+ df = load_preset_data(data_source)
390
+
391
+ index, context_len, pred_len = int(index), int(context_len), int(pred_len)
392
+ # --- Pass model_size to predict_at_index ---
393
+ result = predict_at_index(df, index, context_len, pred_len, session_dir, model_size)
394
+
395
+ final_index = min(index, result.total_samples - 1)
396
+
397
+ return (
398
+ result.ts_fig,
399
+ result.input_img_fig,
400
+ result.recon_img_fig,
401
+ result.csv_path,
402
+ gr.update(maximum=result.total_samples - 1, value=final_index),
403
+ gr.update(value=result.inferred_freq),
404
+ session_dir
405
+ )
406
+
407
+ except Exception as e:
408
+ print(f"Error during forecast: {e}")
409
+ error_fig = plt.figure(figsize=(10, 5))
410
+ plt.text(0.5, 0.5, f"An error occurred:\n{str(e)}", ha='center', va='center', wrap=True, color='red', fontsize=12)
411
+ plt.axis('off')
412
+ plt.close(error_fig)
413
+ return error_fig, None, None, None, gr.update(), gr.update(value="Error"), session_id
 
414
 
415
 
416
  with gr.Blocks(title="VisionTS++ Advanced Forecasting Platform", theme=gr.themes.Soft()) as demo: