Gil Stetler committed on
Commit 1d730a5 · 1 Parent(s): 9a3942b
Files changed (1): app.py +44 -37
app.py CHANGED
@@ -116,7 +116,8 @@
 
 
 
-import os
 import numpy as np
 import pandas as pd
 import torch
@@ -130,11 +131,11 @@ from chronos import ChronosPipeline
 # Config
 # --------------------
 MODEL_ID = "amazon/chronos-t5-large"
-PREDICTION_LENGTH = 30  # forecast horizon (last 30 days)
-NUM_SAMPLES = 100       # >1: more stable point value (mean). For deterministic behavior: 1
-RV_WINDOW = 20          # rolling window for RV (trading days)
-ANNUALIZE = True        # annualize with sqrt(252)
-EPS = 1e-8
 
 # --------------------
 # Model load
@@ -188,7 +189,7 @@ def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = T
 # Main
 # --------------------
 def run_vol_forecast_and_evaluate():
-    # Data
     raw = _read_ohlcv_csv()
     dates = _extract_dates(raw)
     close = _extract_close(raw)
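The hunk context above references compute_realized_vol, whose body lies outside this diff. For orientation, a minimal sketch of such a helper, assuming close-to-close log returns and a rolling standard deviation, with sqrt(252) annualization as the ANNUALIZE comment suggests (the actual implementation in app.py may differ):

```python
import numpy as np
import pandas as pd

def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = True) -> pd.Series:
    # Assumption: RV is built from daily close-to-close log returns.
    log_ret = np.log(close.astype(float)).diff()
    # Rolling standard deviation over `window` trading days.
    rv = log_ret.rolling(window).std()
    if annualize:
        rv = rv * np.sqrt(252)  # scale daily vol to annual terms
    return rv.dropna()
```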
@@ -199,45 +200,44 @@ def run_vol_forecast_and_evaluate():
     if n <= H + 5:
         raise gr.Error(f"RV series too short after rolling. Need > {H+5}, got {n}.")
 
-    # Split: last H days as holdout
     rv_train = rv[: n - H]
     rv_test = rv[n - H :]
 
-    # Forecast (samples) and **point forecast = mean**
     context = torch.tensor(rv_train, dtype=torch.float32)
-    fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES)  # [1, S, H]
-    samples = fcst[0].cpu().numpy()                      # (S, H)
-    mean_pred = samples.mean(axis=0)                     # (H,)  <-- point forecast
-    p10, p90 = np.quantile(samples, [0.1, 0.9], axis=0)  # band only
 
-    # Per-day errors
-    err = mean_pred - rv_test
-    abs_pct_err = np.abs(err) / np.maximum(EPS, np.abs(rv_test)) * 100.0
-    mape_pct = float(abs_pct_err.mean())
-    rmse = float(np.sqrt(np.mean(err**2)))
 
-    # Per-day errors
-    err = mean_pred - rv_test
-    abs_pct_err = np.abs(err) / np.maximum(EPS, np.abs(rv_test)) * 100.0
-    mape_pct = float(abs_pct_err.mean())
     rmse = float(np.sqrt(np.mean(err**2)))
 
-    # Plot
     fig = plt.figure(figsize=(10, 4))
     H0 = len(rv_train)
     if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
         dates_rv = np.array(dates[-len(rv):])
         plt.plot(dates_rv[:H0], rv_train, label="realized vol (history)")
         plt.plot(dates_rv[H0:], rv_test, label="realized vol (actual holdout)")
-        plt.plot(dates_rv[H0:], mean_pred, linestyle="--", label="forecast (point/mean)")
-        plt.fill_between(dates_rv[H0:], p10, p90, alpha=0.3, label="80% interval")
         plt.xlabel("date")
     else:
         x_all = np.arange(len(rv)); x_fcst = np.arange(H0, H0 + H)
         plt.plot(x_all[:H0], rv_train, label="realized vol (history)")
         plt.plot(x_fcst, rv_test, label="realized vol (actual holdout)")
-        plt.plot(x_fcst, mean_pred, linestyle="--", label="forecast (point/mean)")
-        plt.fill_between(x_fcst, p10, p90, alpha=0.3, label="80% interval")
         plt.xlabel("time index")
 
     plt.title(f"Volatility Forecast (RV window={RV_WINDOW}, H={H})")
@@ -246,7 +246,6 @@ def run_vol_forecast_and_evaluate():
     plt.tight_layout()
 
     # Table: day-by-day comparison
-    # (if dates are available, use the last H RV data points)
     if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
         dates_rv = np.array(dates[-len(rv):])
         last_dates = dates_rv[H0:]
@@ -256,8 +255,8 @@ def run_vol_forecast_and_evaluate():
     df_days = pd.DataFrame({
         "date": last_dates,
         "actual_vol": rv_test,
-        "forecast_vol_point": mean_pred,
-        "abs_error": np.abs(err),
         "abs_pct_error_%": abs_pct_err,
     })
 
@@ -267,28 +266,35 @@ def run_vol_forecast_and_evaluate():
             "prediction_length": H,
             "num_samples": NUM_SAMPLES,
             "annualized": ANNUALIZE,
-            "point_forecast": "mean",
         },
         "metrics": {
             "MAPE_%": mape_pct,
             "RMSE": rmse,
         },
     }
 
-    metrics_md = f"**MAPE (avg. % error): {mape_pct:.2f}%**  **RMSE:** {rmse:.6f}"
     return fig, out_json, df_days, metrics_md
 
 # --------------------
 # UI
 # --------------------
-with gr.Blocks(title="Volatility Forecast • Point Forecast") as demo:
     gr.Markdown(
-        "## Forecast the last 30 days of volatility and compare per day\n"
-        "- Point forecast = **mean** of the distribution (not the median).\n"
-        "- Output: plot, MAPE%, RMSE, and a **daily table** (actual vs. forecast + % error)."
     )
     run_btn = gr.Button("Run", variant="primary")
-    plot = gr.Plot(label="Forecast (point) vs Actual")
     meta = gr.JSON(label="Configuration & overall metrics")
     table = gr.Dataframe(label="Per-day comparison", wrap=True)
     metrics = gr.Markdown(label="Metrics")
@@ -298,3 +304,4 @@ with gr.Blocks(title="Volatility Forecast • Point Forecast") as demo:
 if __name__ == "__main__":
     demo.launch()
 
app.py after the commit (added lines marked with +):

+# app.py
+import os, random
 import numpy as np
 import pandas as pd
 import torch
 
 # Config
 # --------------------
 MODEL_ID = "amazon/chronos-t5-large"
+PREDICTION_LENGTH = 30  # last 30 days
+NUM_SAMPLES = 1         # exactly ONE path -> day-by-day point forecast
+RV_WINDOW = 20          # rolling window for RV (trading days)
+ANNUALIZE = True        # annualized with sqrt(252)
+EPS = 1e-8              # guard against division by zero
 
 # --------------------
 # Model load
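The model-load section itself is not touched by this commit and its body is not shown here. For context, a typical Chronos pipeline load looks roughly like the following sketch (device and dtype choices are assumptions, not taken from app.py):

```python
import torch
from chronos import ChronosPipeline

# Sketch only: app.py's actual load code is outside this diff.
pipe = ChronosPipeline.from_pretrained(
    MODEL_ID,  # "amazon/chronos-t5-large"
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
)
```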
 
 # Main
 # --------------------
 def run_vol_forecast_and_evaluate():
+    # Load data
     raw = _read_ohlcv_csv()
     dates = _extract_dates(raw)
     close = _extract_close(raw)
 
     if n <= H + 5:
         raise gr.Error(f"RV series too short after rolling. Need > {H+5}, got {n}.")
 
+    # Holdout: last H days
     rv_train = rv[: n - H]
     rv_test = rv[n - H :]
 
+    # Draw a reproducible SINGLE sample path
+    random.seed(0); np.random.seed(0); torch.manual_seed(0)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(0)
+
     context = torch.tensor(rv_train, dtype=torch.float32)
+    fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES)  # [1, 1, H]
+    samples = fcst[0].cpu().numpy()  # (1, H)
+    path_pred = samples[0]           # (H,)  <-- day-by-day forecast
 
+    # Per-day errors & percentage errors
+    err = path_pred - rv_test
+    denom = np.maximum(EPS, np.abs(rv_test))
+    abs_pct_err = np.abs(err) / denom * 100.0
+    pct_err = err / np.maximum(EPS, rv_test) * 100.0
 
+    mape_pct = float(abs_pct_err.mean())  # main metric: mean absolute % deviation
+    mpe_pct = float(pct_err.mean())       # signed (bias)
     rmse = float(np.sqrt(np.mean(err**2)))
 
+    # Plot: history + actual (holdout) + forecast path
     fig = plt.figure(figsize=(10, 4))
     H0 = len(rv_train)
     if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
         dates_rv = np.array(dates[-len(rv):])
         plt.plot(dates_rv[:H0], rv_train, label="realized vol (history)")
         plt.plot(dates_rv[H0:], rv_test, label="realized vol (actual holdout)")
+        plt.plot(dates_rv[H0:], path_pred, linestyle="--", label="forecast (sample path)")
         plt.xlabel("date")
     else:
         x_all = np.arange(len(rv)); x_fcst = np.arange(H0, H0 + H)
         plt.plot(x_all[:H0], rv_train, label="realized vol (history)")
         plt.plot(x_fcst, rv_test, label="realized vol (actual holdout)")
+        plt.plot(x_fcst, path_pred, linestyle="--", label="forecast (sample path)")
         plt.xlabel("time index")
 
     plt.title(f"Volatility Forecast (RV window={RV_WINDOW}, H={H})")
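The switch above is the core of the commit: the old code averaged NUM_SAMPLES=100 sample paths into a mean point forecast with a 10–90% band, while the new code seeds the RNGs and keeps a single sample path. A small sketch of how both point-forecast styles come out of the same pipe.predict output (illustration only; it reuses pipe, context, and H from app.py):

```python
import numpy as np

# Chronos returns forecasts shaped [batch, num_samples, prediction_length].
fcst = pipe.predict(context, prediction_length=H, num_samples=100)  # [1, 100, H]
samples = fcst[0].cpu().numpy()                                     # (100, H)

mean_pred = samples.mean(axis=0)     # old behavior: mean over sample paths
single_path = samples[0]             # new behavior: one (seeded) sample path
p10, p90 = np.quantile(samples, [0.1, 0.9], axis=0)  # old 80% plot band
```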
 
     plt.tight_layout()
 
     # Table: day-by-day comparison
     if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
         dates_rv = np.array(dates[-len(rv):])
         last_dates = dates_rv[H0:]
 
     df_days = pd.DataFrame({
         "date": last_dates,
         "actual_vol": rv_test,
+        "forecast_vol": path_pred,
+        "pct_error_% (signed)": pct_err,
         "abs_pct_error_%": abs_pct_err,
     })
 
             "prediction_length": H,
             "num_samples": NUM_SAMPLES,
             "annualized": ANNUALIZE,
+            "point_forecast": "single_sample_path",
+            "seed": 0,
         },
         "metrics": {
             "MAPE_%": mape_pct,
+            "MPE_%": mpe_pct,
             "RMSE": rmse,
         },
     }
 
+    metrics_md = (
+        f"**MAPE (mean absolute % deviation): {mape_pct:.2f}%**  "
+        f"**MPE (mean signed %): {mpe_pct:.2f}%**  "
+        f"**RMSE:** {rmse:.6f}"
+    )
     return fig, out_json, df_days, metrics_md
 
 # --------------------
 # UI
 # --------------------
+with gr.Blocks(title="Volatility Forecast • Day-by-Day Point Values") as demo:
     gr.Markdown(
+        "## Forecast of the last 30 days (day-by-day point values)\n"
+        "- A **single sample path** is forecast (no averaging, no median).\n"
+        "- Per-day comparison: forecast vs. actual + percentage error.\n"
+        "- Overall: **MAPE%** (main metric), **MPE%** (bias), and RMSE."
     )
     run_btn = gr.Button("Run", variant="primary")
+    plot = gr.Plot(label="Forecast (single path) vs Actual")
     meta = gr.JSON(label="Configuration & overall metrics")
     table = gr.Dataframe(label="Per-day comparison", wrap=True)
     metrics = gr.Markdown(label="Metrics")
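To make the MAPE% (size of the daily miss) vs. MPE% (signed bias) distinction above concrete, a tiny worked example with made-up numbers (not output of the app):

```python
import numpy as np

# Toy values: actual vs. forecast realized vol on three days.
actual   = np.array([0.20, 0.25, 0.30])
forecast = np.array([0.22, 0.24, 0.33])

err = forecast - actual
abs_pct_err = np.abs(err) / np.abs(actual) * 100.0  # [10.0, 4.0, 10.0]
pct_err = err / actual * 100.0                      # [10.0, -4.0, 10.0]

mape = abs_pct_err.mean()        # 8.0  -> average size of the daily % miss
mpe = pct_err.mean()             # 5.33 -> positive: forecasts run high on average
rmse = np.sqrt(np.mean(err**2))  # ~0.0216, in volatility units
```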
 
 if __name__ == "__main__":
     demo.launch()
 
+