Lefei commited on
Commit
4d14e5a
·
verified ·
1 Parent(s): ef5d89c

update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -82
app.py CHANGED
@@ -18,15 +18,15 @@ REPO_ID = "Lefei/VisionTSpp"
18
  LOCAL_DIR = "./hf_models/VisionTSpp"
19
  CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
20
 
21
- ARCH = 'mae_base' # 可选: 'mae_base', 'mae_large', 'mae_huge'
22
 
23
- # 下载模型(Space 构建时执行一次)
24
  if not os.path.exists(CKPT_PATH):
25
  os.makedirs(LOCAL_DIR, exist_ok=True)
26
  print("Downloading model from Hugging Face Hub...")
27
  snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
28
 
29
- # 加载模型(全局加载一次)
30
  model = VisionTSpp(
31
  ARCH,
32
  ckpt_path=CKPT_PATH,
@@ -37,15 +37,41 @@ model = VisionTSpp(
37
  ).to(DEVICE)
38
  print(f"Model loaded on {DEVICE}")
39
 
 
 
 
 
 
40
  # ========================
41
- # 核心预测与可视化函数
42
  # ========================
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
45
  """
46
- 可视化真实值 vs 预测值
47
- true: [T, nvars]
48
- preds: [T, nvars],与 true 对齐
49
  """
50
  if isinstance(true, torch.Tensor):
51
  true = true.cpu().numpy()
@@ -53,7 +79,6 @@ def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
53
  preds = preds.cpu().numpy()
54
 
55
  nvars = true.shape[1]
56
-
57
  FIG_WIDTH = 12
58
  FIG_HEIGHT_PER_VAR = 1.8
59
  FONT_S = 10
@@ -73,7 +98,6 @@ def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
73
  ax.plot(np.arange(lookback_len, len(true)), preds[lookback_len:, i],
74
  label='Prediction (Median)', color='blue', linewidth=1.8)
75
 
76
- # 分隔线
77
  y_min, y_max = ax.get_ylim()
78
  ax.vlines(x=lookback_len, ymin=y_min, ymax=y_max,
79
  colors='gray', linestyles='--', alpha=0.7, linewidth=1)
@@ -82,12 +106,10 @@ def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
82
  ax.set_xticks([])
83
  ax.text(0.005, 0.8, f'Var {i+1}', transform=ax.transAxes, fontsize=FONT_S, weight='bold')
84
 
85
- # 图例
86
  if preds is not None:
87
  handles, labels = axes[0].get_legend_handles_labels()
88
  fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 0.9), prop={'size': FONT_S})
89
 
90
- # 计算 MSE/MAE
91
  if preds is not None:
92
  true_eval = true[-pred_len:]
93
  pred_eval = preds[-pred_len:]
@@ -96,77 +118,81 @@ def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
96
  fig.suptitle(f'MSE: {mse:.4f}, MAE: {mae:.4f}', fontsize=12, y=0.95)
97
 
98
  plt.subplots_adjust(hspace=0)
99
- return fig # 返回 matplotlib figure
 
100
 
101
 
102
- def predict_and_visualize(df, context_len=960, pred_len=394, freq="15Min"):
 
 
 
 
103
  """
104
- 输入: df (pandas.DataFrame),必须包含 'date' 列和其他数值列
105
- 输出: matplotlib 图像
 
106
  """
107
  if 'date' in df.columns:
108
- df['date'] = pd.to_datetime(df['date'])
109
- df = df.set_index('date')
110
- else:
111
- # 如果没有 date 列,假设是纯数值序列
112
- df = df.copy()
113
 
114
  data = df.values # [T, nvars]
115
  nvars = data.shape[1]
 
116
 
117
- if data.shape[0] < context_len + pred_len:
118
- raise ValueError(f"数据太短,至少需要 {context_len + pred_len} 行,当前只有 {data.shape[0]} 行。")
 
 
119
 
120
- # 归一化(使用训练集前 70% 的统计量)
121
  train_len = int(len(data) * 0.7)
122
  x_mean = data[:train_len].mean(axis=0, keepdims=True)
123
  x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
124
  data_norm = (data - x_mean) / x_std
125
 
126
- # 取最后一段作为测试窗口
127
- end_idx = len(data_norm)
128
- start_idx = end_idx - (context_len + pred_len)
129
- x = data_norm[start_idx:start_idx + context_len] # [context_len, nvars]
130
- y_true = data_norm[start_idx + context_len:end_idx] # [pred_len, nvars]
131
 
132
- # 设置周期性
133
  periodicity_list = freq_to_seasonality_list(freq)
134
  periodicity = periodicity_list[0] if periodicity_list else 1
135
- color_list = [i % 3 for i in range(nvars)] # RGB 循环着色
136
-
137
- # 更新模型配置
138
- model.update_config(
139
- context_len=context_len,
140
- pred_len=pred_len,
141
- periodicity=periodicity,
142
- num_patch_input=7,
143
- padding_mode='constant'
144
- )
145
 
146
- # 转为 tensor
147
  x_tensor = torch.FloatTensor(x).unsqueeze(0).to(DEVICE) # [1, T, N]
148
- y_true_tensor = torch.FloatTensor(y_true).unsqueeze(0).to(DEVICE)
149
 
150
- # 预测
151
  with torch.no_grad():
152
- y_pred, _, _, _, _ = model.forward(x_tensor, export_image=True, color_list=color_list)
 
 
153
  y_pred_median = y_pred[0] # median prediction
154
 
155
  # 反归一化
156
  y_true_original = y_true * x_std + x_mean
157
  y_pred_original = y_pred_median[0].cpu().numpy() * x_std + x_mean
158
 
159
- # 构造完整序列用于可视化
160
  full_true = np.concatenate([x * x_std + x_mean, y_true_original], axis=0)
161
  full_pred = np.concatenate([x * x_std + x_mean, y_pred_original], axis=0)
162
 
163
- # 可视化
164
- fig = visual_ts(true=full_true, preds=full_pred, lookback_len_visual=context_len, pred_len=pred_len)
165
- return fig
 
 
 
 
 
 
 
 
166
 
167
 
168
  # ========================
169
- # 默认数据加载
170
  # ========================
171
  def load_default_data():
172
  data_path = "./datasets/ETTm1.csv"
@@ -181,63 +207,84 @@ def load_default_data():
181
 
182
 
183
  # ========================
184
- # Gradio 界面
185
  # ========================
186
- def run_forecast(file_input, context_len, pred_len, freq):
187
  if file_input is not None:
188
  df = pd.read_csv(file_input.name)
189
- title = "Uploaded Data Prediction"
190
  else:
191
  df = load_default_data()
192
- title = "Default ETTm1 Dataset Prediction"
193
 
194
  try:
195
- fig = predict_and_visualize(df, context_len=int(context_len), pred_len=int(pred_len), freq=freq)
196
- fig.suptitle(title, fontsize=14, y=0.98)
197
- plt.close(fig) # 防止重复显示
198
- return fig
 
 
 
 
199
  except Exception as e:
200
- # 返回错误信息图像
201
- fig, ax = plt.subplots()
202
- ax.text(0.5, 0.5, f"Error: {str(e)}", ha='center', va='center', wrap=True)
203
- ax.axis('off')
204
- plt.close(fig)
205
- return fig
 
206
 
 
207
 
 
 
208
  # Gradio UI
209
- with gr.Blocks(title="VisionTS++ 时间序列预测") as demo:
 
210
  gr.Markdown("# 🕰️ VisionTS++ 时间序列预测平台")
211
- gr.Markdown("上传你的多变量时间序列 CSV 文件,或使用默认 ETTm1 数据进行预���。")
212
 
213
  with gr.Row():
214
- file_input = gr.File(label="上传 CSV 文件(含 date 列或纯数值)", file_types=['.csv'])
215
- with gr.Column():
216
- context_len = gr.Number(label="历史长度 (context_len)", value=960)
217
- pred_len = gr.Number(label="预测长度 (pred_len)", value=394)
218
- freq = gr.Textbox(label="时间频率 (如 15Min, H)", value="15Min")
219
-
220
- btn = gr.Button("🚀 开始预测")
221
-
222
- output_plot = gr.Plot(label="预测结果")
223
-
 
 
 
 
 
 
224
  btn.click(
225
  fn=run_forecast,
226
- inputs=[file_input, context_len, pred_len, freq],
227
- outputs=output_plot
 
 
 
 
 
 
 
228
  )
229
 
230
- # 示例:使用默认数据
231
  gr.Examples(
232
  examples=[
233
  [None, 960, 394, "15Min"]
234
  ],
235
  inputs=[file_input, context_len, pred_len, freq],
236
- outputs=output_plot,
237
- fn=run_forecast,
238
- label="点击运行默认示例"
239
  )
240
 
241
  # 启动
242
- if __name__ == "__main__":
243
- demo.launch()
 
18
  LOCAL_DIR = "./hf_models/VisionTSpp"
19
  CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
20
 
21
+ ARCH = 'mae_base'
22
 
23
+ # 下载模型
24
  if not os.path.exists(CKPT_PATH):
25
  os.makedirs(LOCAL_DIR, exist_ok=True)
26
  print("Downloading model from Hugging Face Hub...")
27
  snapshot_download(repo_id=REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=False)
28
 
29
+ # 加载模型
30
  model = VisionTSpp(
31
  ARCH,
32
  ckpt_path=CKPT_PATH,
 
37
  ).to(DEVICE)
38
  print(f"Model loaded on {DEVICE}")
39
 
40
+ # Image normalization constants
41
+ imagenet_mean = np.array([0.485, 0.456, 0.406])
42
+ imagenet_std = np.array([0.229, 0.224, 0.225])
43
+
44
+
45
  # ========================
46
+ # 可视化函数
47
  # ========================
48
 
49
+ def show_image_tensor(image, title='', cur_nvars=1, cur_color_list=None):
50
+ """
51
+ image: [H, W, 3] tensor
52
+ 返回 matplotlib figure
53
+ """
54
+ cur_image = torch.zeros_like(image)
55
+ height_per_var = image.shape[0] // cur_nvars
56
+
57
+ for i in range(cur_nvars):
58
+ cur_color = cur_color_list[i]
59
+ cur_image[i*height_per_var:(i+1)*height_per_var, :, cur_color] = \
60
+ (image[i*height_per_var:(i+1)*height_per_var, :, cur_color] * imagenet_std[cur_color] + imagenet_mean[cur_color]) * 255
61
+
62
+ cur_image = torch.clamp(cur_image, 0, 255).cpu().int()
63
+
64
+ fig, ax = plt.subplots(figsize=(6, 6))
65
+ ax.imshow(cur_image.numpy())
66
+ ax.set_title(title, fontsize=14)
67
+ ax.axis('off')
68
+ plt.close(fig)
69
+ return fig
70
+
71
+
72
  def visual_ts(true, preds=None, lookback_len_visual=300, pred_len=96):
73
  """
74
+ 绘制时间序列预测图(多变量)
 
 
75
  """
76
  if isinstance(true, torch.Tensor):
77
  true = true.cpu().numpy()
 
79
  preds = preds.cpu().numpy()
80
 
81
  nvars = true.shape[1]
 
82
  FIG_WIDTH = 12
83
  FIG_HEIGHT_PER_VAR = 1.8
84
  FONT_S = 10
 
98
  ax.plot(np.arange(lookback_len, len(true)), preds[lookback_len:, i],
99
  label='Prediction (Median)', color='blue', linewidth=1.8)
100
 
 
101
  y_min, y_max = ax.get_ylim()
102
  ax.vlines(x=lookback_len, ymin=y_min, ymax=y_max,
103
  colors='gray', linestyles='--', alpha=0.7, linewidth=1)
 
106
  ax.set_xticks([])
107
  ax.text(0.005, 0.8, f'Var {i+1}', transform=ax.transAxes, fontsize=FONT_S, weight='bold')
108
 
 
109
  if preds is not None:
110
  handles, labels = axes[0].get_legend_handles_labels()
111
  fig.legend(handles, labels, loc='upper right', bbox_to_anchor=(0.9, 0.9), prop={'size': FONT_S})
112
 
 
113
  if preds is not None:
114
  true_eval = true[-pred_len:]
115
  pred_eval = preds[-pred_len:]
 
118
  fig.suptitle(f'MSE: {mse:.4f}, MAE: {mae:.4f}', fontsize=12, y=0.95)
119
 
120
  plt.subplots_adjust(hspace=0)
121
+ plt.close(fig)
122
+ return fig
123
 
124
 
125
+ # ========================
126
+ # 数据预处理与预测
127
+ # ========================
128
+
129
+ def predict_at_index(df, index, context_len=960, pred_len=394, freq="15Min"):
130
  """
131
+ 在指定 index 处预测
132
+ index: index 个样本(从 0 开始)
133
+ 返回: (ts_fig, input_img_fig, recon_img_fig)
134
  """
135
  if 'date' in df.columns:
136
+ df = df.set_index(pd.to_datetime(df['date'])).drop(columns=['date'])
 
 
 
 
137
 
138
  data = df.values # [T, nvars]
139
  nvars = data.shape[1]
140
+ total_samples = len(data) - context_len - pred_len + 1
141
 
142
+ if total_samples <= 0:
143
+ raise ValueError(f"数据太短,无法构造任何样本(需要至少 {context_len + pred_len} 行)")
144
+ if index >= total_samples:
145
+ raise ValueError(f"索引越界,最大允许索引为 {total_samples - 1}")
146
 
147
+ # 归一化
148
  train_len = int(len(data) * 0.7)
149
  x_mean = data[:train_len].mean(axis=0, keepdims=True)
150
  x_std = data[:train_len].std(axis=0, keepdims=True) + 1e-8
151
  data_norm = (data - x_mean) / x_std
152
 
153
+ # 提取当前样本
154
+ start_idx = index
155
+ x = data_norm[start_idx:start_idx + context_len] # [context_len, nvars]
156
+ y_true = data_norm[start_idx + context_len:start_idx + context_len + pred_len] # [pred_len, nvars]
 
157
 
158
+ # 周期性
159
  periodicity_list = freq_to_seasonality_list(freq)
160
  periodicity = periodicity_list[0] if periodicity_list else 1
161
+ color_list = [i % 3 for i in range(nvars)]
162
+
163
+ model.update_config(context_len=context_len, pred_len=pred_len, periodicity=periodicity)
 
 
 
 
 
 
 
164
 
 
165
  x_tensor = torch.FloatTensor(x).unsqueeze(0).to(DEVICE) # [1, T, N]
 
166
 
 
167
  with torch.no_grad():
168
+ y_pred, input_image, reconstructed_image, nvars_out, color_list_out = model.forward(
169
+ x_tensor, export_image=True, color_list=color_list
170
+ )
171
  y_pred_median = y_pred[0] # median prediction
172
 
173
  # 反归一化
174
  y_true_original = y_true * x_std + x_mean
175
  y_pred_original = y_pred_median[0].cpu().numpy() * x_std + x_mean
176
 
177
+ # 完整序列(用于可视化)
178
  full_true = np.concatenate([x * x_std + x_mean, y_true_original], axis=0)
179
  full_pred = np.concatenate([x * x_std + x_mean, y_pred_original], axis=0)
180
 
181
+ # === 可视化 ===
182
+ ts_fig = visual_ts(true=full_true, preds=full_pred, lookback_len_visual=context_len, pred_len=pred_len)
183
+
184
+ input_img_fig = show_image_tensor(
185
+ input_image[0, 0], title=f'Input Image (Sample {index})', cur_nvars=nvars, cur_color_list=color_list
186
+ )
187
+ recon_img_fig = show_image_tensor(
188
+ reconstructed_image[0, 0], title=f'Reconstructed Image', cur_nvars=nvars, cur_color_list=color_list
189
+ )
190
+
191
+ return ts_fig, input_img_fig, recon_img_fig, total_samples
192
 
193
 
194
  # ========================
195
+ # 默认数据
196
  # ========================
197
  def load_default_data():
198
  data_path = "./datasets/ETTm1.csv"
 
207
 
208
 
209
  # ========================
210
+ # Gradio 接口
211
  # ========================
212
+ def run_forecast(file_input, sample_index, context_len, pred_len, freq):
213
  if file_input is not None:
214
  df = pd.read_csv(file_input.name)
215
+ title_prefix = "Uploaded Data"
216
  else:
217
  df = load_default_data()
218
+ title_prefix = "ETTm1 Dataset"
219
 
220
  try:
221
+ ts_fig, input_img_fig, recon_img_fig, total_samples = predict_at_index(
222
+ df, int(sample_index), context_len=int(context_len), pred_len=int(pred_len), freq=freq
223
+ )
224
+
225
+ # 修改标题
226
+ ts_fig.suptitle(f"{title_prefix} - Sample {int(sample_index)}", fontsize=14, y=0.98)
227
+
228
+ return ts_fig, input_img_fig, recon_img_fig, gr.update(maximum=total_samples - 1, value=total_samples - 1)
229
  except Exception as e:
230
+ # 错误图
231
+ def error_fig(msg):
232
+ fig, ax = plt.subplots()
233
+ ax.text(0.5, 0.5, msg, ha='center', va='center', wrap=True)
234
+ ax.axis('off')
235
+ plt.close(fig)
236
+ return fig
237
 
238
+ return error_fig("Error"), error_fig("Error"), error_fig("Error"), gr.Number()
239
 
240
+
241
+ # ========================
242
  # Gradio UI
243
+ # ========================
244
+ with gr.Blocks(title="VisionTS++ 多变量预测") as demo:
245
  gr.Markdown("# 🕰️ VisionTS++ 时间序列预测平台")
246
+ gr.Markdown("上传 CSV 或使用默认 ETTm1 数据。滑动选择不同样本进行预测,并查看原始图像表示。")
247
 
248
  with gr.Row():
249
+ with gr.Column(scale=2):
250
+ file_input = gr.File(label="上传 CSV 文件", file_types=['.csv'])
251
+ context_len = gr.Number(label="历史长度", value=960)
252
+ pred_len = gr.Number(label="预测长度", value=394)
253
+ freq = gr.Textbox(label="频率 (如 15Min)", value="15Min")
254
+ sample_index = gr.Slider(label="样本索引", minimum=0, maximum=100, step=1, value=0)
255
+
256
+ with gr.Column(scale=3):
257
+ ts_plot = gr.Plot(label="时间序列预测")
258
+ with gr.Row():
259
+ input_img_plot = gr.Plot(label="Input Image")
260
+ recon_img_plot = gr.Plot(label="Reconstructed Image")
261
+
262
+ btn = gr.Button("🚀 更新预测")
263
+
264
+ # 点击按钮或滑动条变化时更新
265
  btn.click(
266
  fn=run_forecast,
267
+ inputs=[file_input, sample_index, context_len, pred_len, freq],
268
+ outputs=[ts_plot, input_img_plot, recon_img_plot, sample_index]
269
+ )
270
+
271
+ # 滑动条变化时也触发(但只在点击后才允许滑动)
272
+ # 我们用 sample_index.change 依赖于前一次运行的结果
273
+ demo.load(
274
+ fn=lambda: gr.update(maximum=100, value=0),
275
+ outputs=sample_index
276
  )
277
 
278
+ # 示例
279
  gr.Examples(
280
  examples=[
281
  [None, 960, 394, "15Min"]
282
  ],
283
  inputs=[file_input, context_len, pred_len, freq],
284
+ outputs=[ts_plot, input_img_plot, recon_img_plot, sample_index],
285
+ fn=lambda f, i, c, p, fr: run_forecast(f, 0, c, p, fr), # 默认 index=0
286
+ label="运行默认示例"
287
  )
288
 
289
  # 启动
290
+ demo.launch()