Spaces:
Running
Running
Gil Stetler
committed on
Commit
·
1d730a5
1
Parent(s):
9a3942b
fix
Browse files
app.py
CHANGED
|
@@ -116,7 +116,8 @@
|
|
| 116 |
|
| 117 |
|
| 118 |
|
| 119 |
-
|
|
|
|
| 120 |
import numpy as np
|
| 121 |
import pandas as pd
|
| 122 |
import torch
|
|
@@ -130,11 +131,11 @@ from chronos import ChronosPipeline
|
|
| 130 |
# Config
|
| 131 |
# --------------------
|
| 132 |
MODEL_ID = "amazon/chronos-t5-large"
|
| 133 |
-
PREDICTION_LENGTH = 30 #
|
| 134 |
-
NUM_SAMPLES =
|
| 135 |
-
RV_WINDOW = 20 #
|
| 136 |
-
ANNUALIZE = True #
|
| 137 |
-
EPS = 1e-8
|
| 138 |
|
| 139 |
# --------------------
|
| 140 |
# Model load
|
|
@@ -188,7 +189,7 @@ def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = T
|
|
| 188 |
# Main
|
| 189 |
# --------------------
|
| 190 |
def run_vol_forecast_and_evaluate():
|
| 191 |
-
# Daten
|
| 192 |
raw = _read_ohlcv_csv()
|
| 193 |
dates = _extract_dates(raw)
|
| 194 |
close = _extract_close(raw)
|
|
@@ -199,45 +200,44 @@ def run_vol_forecast_and_evaluate():
|
|
| 199 |
if n <= H + 5:
|
| 200 |
raise gr.Error(f"RV-Serie zu kurz nach Rolling. Benötigt > {H+5}, erhalten {n}.")
|
| 201 |
|
| 202 |
-
#
|
| 203 |
rv_train = rv[: n - H]
|
| 204 |
rv_test = rv[n - H :]
|
| 205 |
|
| 206 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
context = torch.tensor(rv_train, dtype=torch.float32)
|
| 208 |
-
fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES) # [1,
|
| 209 |
-
samples = fcst[0].cpu().numpy() # (
|
| 210 |
-
|
| 211 |
-
p10, p90 = np.quantile(samples, [0.1, 0.9], axis=0) # nur für Band
|
| 212 |
|
| 213 |
-
#
|
| 214 |
-
err =
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
|
| 219 |
-
#
|
| 220 |
-
|
| 221 |
-
abs_pct_err = np.abs(err) / np.maximum(EPS, np.abs(rv_test)) * 100.0
|
| 222 |
-
mape_pct = float(abs_pct_err.mean())
|
| 223 |
rmse = float(np.sqrt(np.mean(err**2)))
|
| 224 |
|
| 225 |
-
# Plot
|
| 226 |
fig = plt.figure(figsize=(10, 4))
|
| 227 |
H0 = len(rv_train)
|
| 228 |
if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
|
| 229 |
dates_rv = np.array(dates[-len(rv):])
|
| 230 |
plt.plot(dates_rv[:H0], rv_train, label="realized vol (history)")
|
| 231 |
plt.plot(dates_rv[H0:], rv_test, label="realized vol (actual holdout)")
|
| 232 |
-
plt.plot(dates_rv[H0:],
|
| 233 |
-
plt.fill_between(dates_rv[H0:], p10, p90, alpha=0.3, label="80% interval")
|
| 234 |
plt.xlabel("date")
|
| 235 |
else:
|
| 236 |
x_all = np.arange(len(rv)); x_fcst = np.arange(H0, H0 + H)
|
| 237 |
plt.plot(x_all[:H0], rv_train, label="realized vol (history)")
|
| 238 |
plt.plot(x_fcst, rv_test, label="realized vol (actual holdout)")
|
| 239 |
-
plt.plot(x_fcst,
|
| 240 |
-
plt.fill_between(x_fcst, p10, p90, alpha=0.3, label="80% interval")
|
| 241 |
plt.xlabel("time index")
|
| 242 |
|
| 243 |
plt.title(f"Volatility Forecast (RV window={RV_WINDOW}, H={H})")
|
|
@@ -246,7 +246,6 @@ def run_vol_forecast_and_evaluate():
|
|
| 246 |
plt.tight_layout()
|
| 247 |
|
| 248 |
# Tabelle: Tag-für-Tag Vergleich
|
| 249 |
-
# (falls Datum vorhanden, verwende die letzten H RV-Datenpunkte)
|
| 250 |
if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
|
| 251 |
dates_rv = np.array(dates[-len(rv):])
|
| 252 |
last_dates = dates_rv[H0:]
|
|
@@ -256,8 +255,8 @@ def run_vol_forecast_and_evaluate():
|
|
| 256 |
df_days = pd.DataFrame({
|
| 257 |
"date": last_dates,
|
| 258 |
"actual_vol": rv_test,
|
| 259 |
-
"
|
| 260 |
-
"
|
| 261 |
"abs_pct_error_%": abs_pct_err,
|
| 262 |
})
|
| 263 |
|
|
@@ -267,28 +266,35 @@ def run_vol_forecast_and_evaluate():
|
|
| 267 |
"prediction_length": H,
|
| 268 |
"num_samples": NUM_SAMPLES,
|
| 269 |
"annualized": ANNUALIZE,
|
| 270 |
-
"point_forecast": "
|
|
|
|
| 271 |
},
|
| 272 |
"metrics": {
|
| 273 |
"MAPE_%": mape_pct,
|
|
|
|
| 274 |
"RMSE": rmse,
|
| 275 |
},
|
| 276 |
}
|
| 277 |
|
| 278 |
-
metrics_md =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
return fig, out_json, df_days, metrics_md
|
| 280 |
|
| 281 |
# --------------------
|
| 282 |
# UI
|
| 283 |
# --------------------
|
| 284 |
-
with gr.Blocks(title="Volatility Forecast •
|
| 285 |
gr.Markdown(
|
| 286 |
-
"##
|
| 287 |
-
"-
|
| 288 |
-
"-
|
|
|
|
| 289 |
)
|
| 290 |
run_btn = gr.Button("Run", variant="primary")
|
| 291 |
-
plot = gr.Plot(label="Forecast (
|
| 292 |
meta = gr.JSON(label="Konfiguration & Gesamtmetriken")
|
| 293 |
table = gr.Dataframe(label="Per-Day Vergleich", wrap=True)
|
| 294 |
metrics = gr.Markdown(label="Metriken")
|
|
@@ -298,3 +304,4 @@ with gr.Blocks(title="Volatility Forecast • Punktprognose") as demo:
|
|
| 298 |
if __name__ == "__main__":
|
| 299 |
demo.launch()
|
| 300 |
|
|
|
|
|
|
| 116 |
|
| 117 |
|
| 118 |
|
| 119 |
+
# app.py
|
| 120 |
+
import os, random
|
| 121 |
import numpy as np
|
| 122 |
import pandas as pd
|
| 123 |
import torch
|
|
|
|
| 131 |
# Config
|
| 132 |
# --------------------
|
| 133 |
MODEL_ID = "amazon/chronos-t5-large"
|
| 134 |
+
PREDICTION_LENGTH = 30 # letzte 30 Tage
|
| 135 |
+
NUM_SAMPLES = 1 # genau EINE Bahn -> tagesgenaue Punktvorhersage
|
| 136 |
+
RV_WINDOW = 20 # Rolling-Fenster für RV (Handelstage)
|
| 137 |
+
ANNUALIZE = True # annualisiert mit sqrt(252)
|
| 138 |
+
EPS = 1e-8 # Schutz gegen Division durch 0
|
| 139 |
|
| 140 |
# --------------------
|
| 141 |
# Model load
|
|
|
|
| 189 |
# Main
|
| 190 |
# --------------------
|
| 191 |
def run_vol_forecast_and_evaluate():
|
| 192 |
+
# Daten laden
|
| 193 |
raw = _read_ohlcv_csv()
|
| 194 |
dates = _extract_dates(raw)
|
| 195 |
close = _extract_close(raw)
|
|
|
|
| 200 |
if n <= H + 5:
|
| 201 |
raise gr.Error(f"RV-Serie zu kurz nach Rolling. Benötigt > {H+5}, erhalten {n}.")
|
| 202 |
|
| 203 |
+
# Holdout: letzte H Tage
|
| 204 |
rv_train = rv[: n - H]
|
| 205 |
rv_test = rv[n - H :]
|
| 206 |
|
| 207 |
+
# Reproduzierbare EINZELNE Sample-Bahn ziehen
|
| 208 |
+
random.seed(0); np.random.seed(0); torch.manual_seed(0)
|
| 209 |
+
if torch.cuda.is_available():
|
| 210 |
+
torch.cuda.manual_seed_all(0)
|
| 211 |
+
|
| 212 |
context = torch.tensor(rv_train, dtype=torch.float32)
|
| 213 |
+
fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES) # [1, 1, H]
|
| 214 |
+
samples = fcst[0].cpu().numpy() # (1, H)
|
| 215 |
+
path_pred = samples[0] # (H,) <-- tagesgenaue Vorhersage
|
|
|
|
| 216 |
|
| 217 |
+
# Tagesfehler & Prozentfehler
|
| 218 |
+
err = path_pred - rv_test
|
| 219 |
+
denom = np.maximum(EPS, np.abs(rv_test))
|
| 220 |
+
abs_pct_err = np.abs(err) / denom * 100.0
|
| 221 |
+
pct_err = err / np.maximum(EPS, rv_test) * 100.0
|
| 222 |
|
| 223 |
+
mape_pct = float(abs_pct_err.mean()) # Hauptmetrik: mittlere absolute proz. Abweichung
|
| 224 |
+
mpe_pct = float(pct_err.mean()) # signiert (Bias)
|
|
|
|
|
|
|
| 225 |
rmse = float(np.sqrt(np.mean(err**2)))
|
| 226 |
|
| 227 |
+
# Plot: History + Actual (Holdout) + Forecast-Pfad
|
| 228 |
fig = plt.figure(figsize=(10, 4))
|
| 229 |
H0 = len(rv_train)
|
| 230 |
if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
|
| 231 |
dates_rv = np.array(dates[-len(rv):])
|
| 232 |
plt.plot(dates_rv[:H0], rv_train, label="realized vol (history)")
|
| 233 |
plt.plot(dates_rv[H0:], rv_test, label="realized vol (actual holdout)")
|
| 234 |
+
plt.plot(dates_rv[H0:], path_pred, linestyle="--", label="forecast (sample path)")
|
|
|
|
| 235 |
plt.xlabel("date")
|
| 236 |
else:
|
| 237 |
x_all = np.arange(len(rv)); x_fcst = np.arange(H0, H0 + H)
|
| 238 |
plt.plot(x_all[:H0], rv_train, label="realized vol (history)")
|
| 239 |
plt.plot(x_fcst, rv_test, label="realized vol (actual holdout)")
|
| 240 |
+
plt.plot(x_fcst, path_pred, linestyle="--", label="forecast (sample path)")
|
|
|
|
| 241 |
plt.xlabel("time index")
|
| 242 |
|
| 243 |
plt.title(f"Volatility Forecast (RV window={RV_WINDOW}, H={H})")
|
|
|
|
| 246 |
plt.tight_layout()
|
| 247 |
|
| 248 |
# Tabelle: Tag-für-Tag Vergleich
|
|
|
|
| 249 |
if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
|
| 250 |
dates_rv = np.array(dates[-len(rv):])
|
| 251 |
last_dates = dates_rv[H0:]
|
|
|
|
| 255 |
df_days = pd.DataFrame({
|
| 256 |
"date": last_dates,
|
| 257 |
"actual_vol": rv_test,
|
| 258 |
+
"forecast_vol": path_pred,
|
| 259 |
+
"pct_error_% (signed)": pct_err,
|
| 260 |
"abs_pct_error_%": abs_pct_err,
|
| 261 |
})
|
| 262 |
|
|
|
|
| 266 |
"prediction_length": H,
|
| 267 |
"num_samples": NUM_SAMPLES,
|
| 268 |
"annualized": ANNUALIZE,
|
| 269 |
+
"point_forecast": "single_sample_path",
|
| 270 |
+
"seed": 0,
|
| 271 |
},
|
| 272 |
"metrics": {
|
| 273 |
"MAPE_%": mape_pct,
|
| 274 |
+
"MPE_%": mpe_pct,
|
| 275 |
"RMSE": rmse,
|
| 276 |
},
|
| 277 |
}
|
| 278 |
|
| 279 |
+
metrics_md = (
|
| 280 |
+
f"**MAPE (Ø absolute %-Abweichung): {mape_pct:.2f}%** "
|
| 281 |
+
f"**MPE (Ø signed %): {mpe_pct:.2f}%** "
|
| 282 |
+
f"**RMSE:** {rmse:.6f}"
|
| 283 |
+
)
|
| 284 |
return fig, out_json, df_days, metrics_md
|
| 285 |
|
| 286 |
# --------------------
|
| 287 |
# UI
|
| 288 |
# --------------------
|
| 289 |
+
with gr.Blocks(title="Volatility Forecast • Tagesgenaue Punktwerte") as demo:
|
| 290 |
gr.Markdown(
|
| 291 |
+
"## Vorhersage der letzten 30 Tage (tagesgenaue Punktwerte)\n"
|
| 292 |
+
"- Es wird **eine einzelne Sample-Bahn** prognostiziert (keine Mittelung, kein Median).\n"
|
| 293 |
+
"- Vergleich pro Tag: Forecast vs. Actual + Prozentfehler.\n"
|
| 294 |
+
"- Gesamt: **MAPE%** (Hauptmetrik), **MPE%** (Bias) und RMSE."
|
| 295 |
)
|
| 296 |
run_btn = gr.Button("Run", variant="primary")
|
| 297 |
+
plot = gr.Plot(label="Forecast (einzelne Bahn) vs Actual")
|
| 298 |
meta = gr.JSON(label="Konfiguration & Gesamtmetriken")
|
| 299 |
table = gr.Dataframe(label="Per-Day Vergleich", wrap=True)
|
| 300 |
metrics = gr.Markdown(label="Metriken")
|
|
|
|
| 304 |
if __name__ == "__main__":
|
| 305 |
demo.launch()
|
| 306 |
|
| 307 |
+
|