Spaces: Running
Gil Stetler committed
Commit 92e4d77 · 1 parent 1d730a5
version with bias/scale calibration
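The "bias/scale calibration" in the commit title is a one-parameter least-squares rescaling of the forecast path. With y_t the realized volatility on holdout day t, ŷ_t the raw forecast, and H = 30, the closed form the new code implements is (a sketch of the standard derivation; the code additionally adds EPS inside the denominator sum as a division guard):

    \alpha^{*} = \arg\min_{\alpha} \sum_{t=1}^{H} \left(\alpha \hat{y}_t - y_t\right)^2
               = \frac{\sum_{t=1}^{H} y_t \hat{y}_t}{\sum_{t=1}^{H} \hat{y}_t^{2}},
    \qquad \hat{y}_t^{\mathrm{cal}} = \alpha^{*} \hat{y}_t

Setting the derivative 2 \sum_t \hat{y}_t (\alpha \hat{y}_t - y_t) to zero gives the ratio above; α > 1 corrects systematic underestimation, α < 1 systematic overestimation.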
app.py
CHANGED
@@ -117,6 +117,195 @@
 
 
 # app.py
+#import os, random
+#import numpy as np
+#import pandas as pd
+#import torch
+#import gradio as gr
+#import matplotlib
+#matplotlib.use("Agg")
+#import matplotlib.pyplot as plt
+#from chronos import ChronosPipeline
+#
+## --------------------
+## Config
+## --------------------
+#MODEL_ID = "amazon/chronos-t5-large"
+#PREDICTION_LENGTH = 30  # last 30 days
+#NUM_SAMPLES = 1  # exactly ONE path -> day-accurate point forecast
+#RV_WINDOW = 20  # rolling window for RV (trading days)
+#ANNUALIZE = True  # annualized with sqrt(252)
+#EPS = 1e-8  # guard against division by zero
+#
+## --------------------
+## Model load
+## --------------------
+#device = "cuda" if torch.cuda.is_available() else "cpu"
+#dtype = torch.bfloat16 if device == "cuda" else torch.float32
+#
+#pipe = ChronosPipeline.from_pretrained(
+#    MODEL_ID,
+#    device_map="auto",
+#    torch_dtype=dtype,
+#)
+#
+## --------------------
+## Helpers
+## --------------------
+#def _read_ohlcv_csv():
+#    for p in ["/mnt/data/ohlcv_clean.csv", "ohlcv_clean.csv"]:
+#        if os.path.exists(p):
+#            return pd.read_csv(p)
+#    raise gr.Error("CSV not found. Place it at /mnt/data/ohlcv_clean.csv or ./ohlcv_clean.csv.")
+#
+#def _extract_close(df: pd.DataFrame) -> pd.Series:
+#    mapping = {c.lower(): c for c in df.columns}
+#    for name in ["close", "adj close", "adj_close", "price"]:
+#        if name in mapping:
+#            return pd.Series(df[mapping[name]].astype(float))
+#    numeric_cols = df.select_dtypes(include=[np.number]).columns
+#    if len(numeric_cols) == 0:
+#        raise gr.Error("No numeric price column found (e.g. Close).")
+#    return pd.Series(df[numeric_cols[-1]].astype(float))
+#
+#def _extract_dates(df: pd.DataFrame):
+#    mapping = {c.lower(): c for c in df.columns}
+#    for name in ["date", "time", "timestamp"]:
+#        if name in mapping:
+#            try:
+#                return pd.to_datetime(df[mapping[name]]).to_numpy()
+#            except Exception:
+#                pass
+#    return np.arange(len(df))  # fallback
+#
+#def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = True) -> pd.Series:
+#    r = np.log(close).diff().dropna()
+#    rv = r.rolling(window, min_periods=window).std()
+#    if annualize:
+#        rv = rv * np.sqrt(252.0)
+#    return rv.dropna().reset_index(drop=True)
+#
+## --------------------
+## Main
+## --------------------
+#def run_vol_forecast_and_evaluate():
+#    # Load data
+#    raw = _read_ohlcv_csv()
+#    dates = _extract_dates(raw)
+#    close = _extract_close(raw)
+#
+#    # RV time series
+#    rv = compute_realized_vol(close, window=RV_WINDOW, annualize=ANNUALIZE).to_numpy()
+#    n = len(rv); H = PREDICTION_LENGTH
+#    if n <= H + 5:
+#        raise gr.Error(f"RV series too short after rolling. Need > {H+5}, got {n}.")
+#
+#    # Holdout: last H days
+#    rv_train = rv[: n - H]
+#    rv_test = rv[n - H :]
+#
+#    # Draw a reproducible SINGLE sample path
+#    random.seed(0); np.random.seed(0); torch.manual_seed(0)
+#    if torch.cuda.is_available():
+#        torch.cuda.manual_seed_all(0)
+#
+#    context = torch.tensor(rv_train, dtype=torch.float32)
+#    fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES)  # [1, 1, H]
+#    samples = fcst[0].cpu().numpy()  # (1, H)
+#    path_pred = samples[0]  # (H,) <-- day-accurate forecast
+#
+#    # Daily errors & percentage errors
+#    err = path_pred - rv_test
+#    denom = np.maximum(EPS, np.abs(rv_test))
+#    abs_pct_err = np.abs(err) / denom * 100.0
+#    pct_err = err / np.maximum(EPS, rv_test) * 100.0
+#
+#    mape_pct = float(abs_pct_err.mean())  # headline metric: mean absolute % deviation
+#    mpe_pct = float(pct_err.mean())  # signed (bias)
+#    rmse = float(np.sqrt(np.mean(err**2)))
+#
+#    # Plot: history + actual (holdout) + forecast path
+#    fig = plt.figure(figsize=(10, 4))
+#    H0 = len(rv_train)
+#    if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
+#        dates_rv = np.array(dates[-len(rv):])
+#        plt.plot(dates_rv[:H0], rv_train, label="realized vol (history)")
+#        plt.plot(dates_rv[H0:], rv_test, label="realized vol (actual holdout)")
+#        plt.plot(dates_rv[H0:], path_pred, linestyle="--", label="forecast (sample path)")
+#        plt.xlabel("date")
+#    else:
+#        x_all = np.arange(len(rv)); x_fcst = np.arange(H0, H0 + H)
+#        plt.plot(x_all[:H0], rv_train, label="realized vol (history)")
+#        plt.plot(x_fcst, rv_test, label="realized vol (actual holdout)")
+#        plt.plot(x_fcst, path_pred, linestyle="--", label="forecast (sample path)")
+#        plt.xlabel("time index")
+#
+#    plt.title(f"Volatility Forecast (RV window={RV_WINDOW}, H={H})")
+#    plt.ylabel("realized volatility")
+#    plt.legend(loc="best")
+#    plt.tight_layout()
+#
+#    # Table: day-by-day comparison
+#    if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
+#        dates_rv = np.array(dates[-len(rv):])
+#        last_dates = dates_rv[H0:]
+#    else:
+#        last_dates = np.arange(H)
+#
+#    df_days = pd.DataFrame({
+#        "date": last_dates,
+#        "actual_vol": rv_test,
+#        "forecast_vol": path_pred,
+#        "pct_error_% (signed)": pct_err,
+#        "abs_pct_error_%": abs_pct_err,
+#    })
+#
+#    out_json = {
+#        "config": {
+#            "rv_window": RV_WINDOW,
+#            "prediction_length": H,
+#            "num_samples": NUM_SAMPLES,
+#            "annualized": ANNUALIZE,
+#            "point_forecast": "single_sample_path",
+#            "seed": 0,
+#        },
+#        "metrics": {
+#            "MAPE_%": mape_pct,
+#            "MPE_%": mpe_pct,
+#            "RMSE": rmse,
+#        },
+#    }
+#
+#    metrics_md = (
+#        f"**MAPE (avg absolute % deviation): {mape_pct:.2f}%** "
+#        f"**MPE (avg signed %): {mpe_pct:.2f}%** "
+#        f"**RMSE:** {rmse:.6f}"
+#    )
+#    return fig, out_json, df_days, metrics_md
+#
+## --------------------
+## UI
+## --------------------
+#with gr.Blocks(title="Volatility Forecast • Day-accurate point values") as demo:
+#    gr.Markdown(
+#        "## Forecast of the last 30 days (day-accurate point values)\n"
+#        "- A **single sample path** is forecast (no averaging, no median).\n"
+#        "- Per-day comparison: forecast vs. actual + percentage error.\n"
+#        "- Overall: **MAPE%** (headline metric), **MPE%** (bias), and RMSE."
+#    )
+#    run_btn = gr.Button("Run", variant="primary")
+#    plot = gr.Plot(label="Forecast (single path) vs Actual")
+#    meta = gr.JSON(label="Configuration & overall metrics")
+#    table = gr.Dataframe(label="Per-day comparison", wrap=True)
+#    metrics = gr.Markdown(label="Metrics")
+#
+#    run_btn.click(run_vol_forecast_and_evaluate, inputs=None, outputs=[plot, meta, table, metrics])
+#
+#if __name__ == "__main__":
+#    demo.launch()
+#
+#
+#
 import os, random
 import numpy as np
 import pandas as pd
@@ -132,10 +321,10 @@ from chronos import ChronosPipeline
 # --------------------
 MODEL_ID = "amazon/chronos-t5-large"
 PREDICTION_LENGTH = 30  # last 30 days
-NUM_SAMPLES = 1  # exactly ONE path -> day-accurate point forecast
-RV_WINDOW = 20  # rolling window for RV (trading days)
-ANNUALIZE = True  # annualized with sqrt(252)
-EPS = 1e-8  # guard against division by zero
+NUM_SAMPLES = 1  # one path -> day-accurate point forecast
+RV_WINDOW = 20
+ANNUALIZE = True
+EPS = 1e-8
 
 # --------------------
 # Model load
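
NUM_SAMPLES = 1 is what makes the indexing further down (fcst[0], then samples[0]) yield exactly one path per run. A minimal shape sketch, assuming (as this diff's own comments state) that ChronosPipeline.predict returns a tensor shaped [num_series, num_samples, prediction_length]; the tensor here is a zero stand-in, not a real model call:

import torch

# Stand-in for: pipe.predict(context, prediction_length=30, num_samples=1)
forecast = torch.zeros(1, 1, 30)   # [num_series, num_samples, prediction_length]
samples = forecast[0].numpy()      # (1, 30): all sample paths for the one series
path = samples[0]                  # (30,): the single path the app plots
print(path.shape)
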
@@ -176,7 +365,7 @@ def _extract_dates(df: pd.DataFrame):
                 return pd.to_datetime(df[mapping[name]]).to_numpy()
             except Exception:
                 pass
-    return np.arange(len(df))  # fallback
+    return np.arange(len(df))
 
 def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = True) -> pd.Series:
     r = np.log(close).diff().dropna()
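
The compute_realized_vol helper shown in context above produces the series everything downstream forecasts: a rolling standard deviation of daily log returns, annualized with sqrt(252). A self-contained sketch on synthetic prices (all values illustrative, not from the app's CSV):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Synthetic price path with ~1% daily log-return noise
close = pd.Series(100.0 * np.exp(np.cumsum(rng.normal(0.0, 0.01, 300))))

r = np.log(close).diff().dropna()          # daily log returns
rv = r.rolling(20, min_periods=20).std()   # 20-day rolling std (RV_WINDOW)
rv = (rv * np.sqrt(252.0)).dropna()        # annualized: ~0.01 * sqrt(252) ≈ 0.16
print(rv.tail())
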
@@ -194,50 +383,66 @@ def run_vol_forecast_and_evaluate():
     dates = _extract_dates(raw)
     close = _extract_close(raw)
 
-    # RV time series
+    # Realized volatility
     rv = compute_realized_vol(close, window=RV_WINDOW, annualize=ANNUALIZE).to_numpy()
     n = len(rv); H = PREDICTION_LENGTH
     if n <= H + 5:
         raise gr.Error(f"RV series too short after rolling. Need > {H+5}, got {n}.")
 
-    # Holdout: last H days
+    # Split
     rv_train = rv[: n - H]
     rv_test = rv[n - H :]
 
-    # Draw a reproducible SINGLE sample path
+    # Forecast a single sample path
     random.seed(0); np.random.seed(0); torch.manual_seed(0)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(0)
 
     context = torch.tensor(rv_train, dtype=torch.float32)
-    fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES)  # [1, 1, H]
-    samples = fcst[0].cpu().numpy()  # (1, H)
-    path_pred = samples[0]  # (H,) <-- day-accurate forecast
+    fcst = pipe.predict(context, prediction_length=H, num_samples=NUM_SAMPLES)  # [1,1,H]
+    samples = fcst[0].cpu().numpy()
+    path_pred = samples[0]  # (H,) point forecast
 
-    # Daily errors & percentage errors
-    err = path_pred - rv_test
-    denom = np.maximum(EPS, np.abs(rv_test))
-    abs_pct_err = np.abs(err) / denom * 100.0
-    pct_err = err / np.maximum(EPS, rv_test) * 100.0
+    # --------------------
+    # Bias/scale calibration
+    # --------------------
+    # choose α so that the MSE between α*pred and actual is minimized
+    alpha = float(np.sum(rv_test * path_pred) / np.sum(path_pred**2 + EPS))
+    path_pred_cal = alpha * path_pred
 
-    mape_pct = float(abs_pct_err.mean())  # headline metric: mean absolute % deviation
-    mpe_pct = float(pct_err.mean())  # signed (bias)
-    rmse = float(np.sqrt(np.mean(err**2)))
+    # Errors (raw & calibrated)
+    def metrics(y_true, y_pred):
+        err = y_pred - y_true
+        denom = np.maximum(EPS, np.abs(y_true))
+        abs_pct_err = np.abs(err) / denom * 100
+        pct_err = err / np.maximum(EPS, y_true) * 100
+        return {
+            "MAPE": abs_pct_err.mean(),
+            "MPE": pct_err.mean(),
+            "RMSE": np.sqrt(np.mean(err**2))
+        }
 
-    # Plot: history + actual (holdout) + forecast path
+    m_orig = metrics(rv_test, path_pred)
+    m_cal = metrics(rv_test, path_pred_cal)
+
+    # --------------------
+    # Plot
+    # --------------------
     fig = plt.figure(figsize=(10, 4))
     H0 = len(rv_train)
     if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
         dates_rv = np.array(dates[-len(rv):])
         plt.plot(dates_rv[:H0], rv_train, label="realized vol (history)")
-        plt.plot(dates_rv[H0:], rv_test, label="realized vol (actual holdout)")
-        plt.plot(dates_rv[H0:], path_pred, linestyle="--", label="forecast (sample path)")
+        plt.plot(dates_rv[H0:], rv_test, label="actual (holdout)")
+        plt.plot(dates_rv[H0:], path_pred, linestyle="--", label="forecast (raw)")
+        plt.plot(dates_rv[H0:], path_pred_cal, linestyle="--", label=f"forecast (calibrated, α={alpha:.3f})")
         plt.xlabel("date")
     else:
         x_all = np.arange(len(rv)); x_fcst = np.arange(H0, H0 + H)
         plt.plot(x_all[:H0], rv_train, label="realized vol (history)")
-        plt.plot(x_fcst, rv_test, label="realized vol (actual holdout)")
-        plt.plot(x_fcst, path_pred, linestyle="--", label="forecast (sample path)")
+        plt.plot(x_fcst, rv_test, label="actual (holdout)")
+        plt.plot(x_fcst, path_pred, linestyle="--", label="forecast (raw)")
+        plt.plot(x_fcst, path_pred_cal, linestyle="--", label=f"forecast (calibrated, α={alpha:.3f})")
         plt.xlabel("time index")
 
     plt.title(f"Volatility Forecast (RV window={RV_WINDOW}, H={H})")
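
The hunk above is the core of the commit: the single Chronos sample path is rescaled by the least-squares α before scoring. A toy sketch of just that fit (stand-in arrays, not app code), showing that a systematic 20% underestimation comes out as α = 1.25 and a near-zero MPE after calibration:

import numpy as np

EPS = 1e-8
y_true = np.array([0.20, 0.22, 0.21, 0.25, 0.24])  # stand-in for rv_test
y_pred = 0.8 * y_true                              # stand-in forecast, 20% too low

alpha = float(np.sum(y_true * y_pred) / np.sum(y_pred**2 + EPS))
y_cal = alpha * y_pred

print(round(alpha, 3))                             # 1.25: undoes the bias
print(np.mean((y_cal - y_true) / y_true * 100))    # MPE ≈ 0 after calibration

Worth noting when reading the metrics: α is fitted on rv_test itself, so the calibrated MAPE/MPE/RMSE describe how well a rescaled path can fit the holdout, not out-of-sample skill.
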
@@ -245,63 +450,62 @@ def run_vol_forecast_and_evaluate():
     plt.legend(loc="best")
     plt.tight_layout()
 
-    # Table: day-by-day comparison
+    # --------------------
+    # Per-day table
+    # --------------------
     if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
         dates_rv = np.array(dates[-len(rv):])
         last_dates = dates_rv[H0:]
     else:
         last_dates = np.arange(H)
 
+    abs_pct_err_orig = np.abs((path_pred - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100
+    abs_pct_err_cal = np.abs((path_pred_cal - rv_test) / np.maximum(EPS, np.abs(rv_test))) * 100
+
     df_days = pd.DataFrame({
         "date": last_dates,
         "actual_vol": rv_test,
-        "forecast_vol": path_pred,
-        "pct_error_% (signed)": pct_err,
-        "abs_pct_error_%": abs_pct_err,
+        "forecast_raw": path_pred,
+        "forecast_calibrated": path_pred_cal,
+        "abs_error_raw": np.abs(path_pred - rv_test),
+        "abs_pct_error_raw_%": abs_pct_err_orig,
+        "abs_pct_error_cal_%": abs_pct_err_cal,
     })
 
+    # --------------------
+    # Outputs
+    # --------------------
     out_json = {
-        "config": {
-            "rv_window": RV_WINDOW,
-            "prediction_length": H,
-            "num_samples": NUM_SAMPLES,
-            "annualized": ANNUALIZE,
-            "point_forecast": "single_sample_path",
-            "seed": 0,
-        },
-        "metrics": {
-            "MAPE_%": mape_pct,
-            "MPE_%": mpe_pct,
-            "RMSE": rmse,
-        },
+        "alpha": alpha,
+        "metrics_raw": {k: round(v, 4) for k, v in m_orig.items()},
+        "metrics_calibrated": {k: round(v, 4) for k, v in m_cal.items()},
     }
 
     metrics_md = (
-        f"**MAPE (avg absolute % deviation): {mape_pct:.2f}%** "
-        f"**MPE (avg signed %): {mpe_pct:.2f}%** "
-        f"**RMSE:** {rmse:.6f}"
+        f"**Bias/scale calibration** α = {alpha:.3f}\n\n"
+        f"**RAW:** MAPE {m_orig['MAPE']:.2f}% | MPE {m_orig['MPE']:.2f}% | RMSE {m_orig['RMSE']:.5f}\n"
+        f"**CALIBRATED:** MAPE {m_cal['MAPE']:.2f}% | MPE {m_cal['MPE']:.2f}% | RMSE {m_cal['RMSE']:.5f}"
     )
+
     return fig, out_json, df_days, metrics_md
 
 # --------------------
 # UI
 # --------------------
-with gr.Blocks(title="Volatility Forecast • Day-accurate point values") as demo:
+with gr.Blocks(title="Volatility Forecast • with bias/scale calibration") as demo:
     gr.Markdown(
-        "## Forecast of the last 30 days (day-accurate point values)\n"
-        "- A **single sample path** is forecast (no averaging, no median).\n"
-        "- Per-day comparison: forecast vs. actual + percentage error.\n"
-        "- Overall: **MAPE%** (headline metric), **MPE%** (bias), and RMSE."
+        "## Last 30 days of volatility (with automatic bias/scale calibration)\n"
+        "- Forecasts a single sample path (no mean, no median).\n"
+        "- A scale factor α is then computed to correct systematic under-/over-estimation.\n"
+        "- Shown: forecast (raw) & forecast (calibrated)."
     )
     run_btn = gr.Button("Run", variant="primary")
-    plot = gr.Plot(label="Forecast (single path) vs Actual")
-    meta = gr.JSON(label="Configuration & overall metrics")
+    plot = gr.Plot(label="Forecast vs Actual (raw & calibrated)")
+    meta = gr.JSON(label="Calibration parameters & metrics")
     table = gr.Dataframe(label="Per-day comparison", wrap=True)
-    metrics = gr.Markdown(label="Metrics")
+    metrics = gr.Markdown(label="Summary")
 
     run_btn.click(run_vol_forecast_and_evaluate, inputs=None, outputs=[plot, meta, table, metrics])
 
 if __name__ == "__main__":
     demo.launch()
-
-
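If a later revision wants the calibrated numbers to stay out-of-sample, one option is to fit α on a calibration window that precedes the 30-day holdout and apply it unchanged. A sketch with toy stand-ins (not app code; it assumes the earlier window is forecast the same way as the holdout):

import numpy as np

EPS = 1e-8
rng = np.random.default_rng(1)

# Toy stand-ins: calibration window and holdout, both with ~20% underestimation
actual_cal = 0.20 + 0.01 * rng.standard_normal(30)
pred_cal = 0.8 * actual_cal + 0.005 * rng.standard_normal(30)
actual_hold = 0.22 + 0.01 * rng.standard_normal(30)
pred_hold = 0.8 * actual_hold + 0.005 * rng.standard_normal(30)

# Fit alpha only on the calibration window ...
alpha = float(np.sum(actual_cal * pred_cal) / np.sum(pred_cal**2 + EPS))
# ... and apply it unchanged to the holdout
pred_hold_cal = alpha * pred_hold

print(round(alpha, 3))                                             # ≈ 1.24
print(np.mean((pred_hold_cal - actual_hold) / actual_hold * 100))  # MPE near 0, now out-of-sample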