Gil Stetler committed on
Commit 682cd17 · 1 Parent(s): 438b5d2

updated files

Files changed (2)
  1. app.py +39 -32
  2. pipeline_v2.py +50 -175
app.py CHANGED
@@ -512,7 +512,9 @@
 #
 
 
+# app.py
 import os, random
+from typing import Tuple
 import numpy as np
 import pandas as pd
 import torch
@@ -522,8 +524,8 @@ matplotlib.use("Agg")
 import matplotlib.pyplot as plt
 from chronos import ChronosPipeline
 
-# >>> import your pipeline <<<
-import pipeline_v2 as pipe2  # provides update_ticker_csv(...)
+# --- our data pipeline ---
+import pipeline_v2 as pipe2  # update_ticker_csv(...)
 
 # --------------------
 # Config
@@ -551,17 +553,22 @@ pipe = ChronosPipeline.from_pretrained(
 # Helpers
 # --------------------
 def _extract_close(df: pd.DataFrame) -> pd.Series:
+    # Prefer 'Adj Close' > 'Close', else last numeric column
     mapping = {c.lower(): c for c in df.columns}
-    for name in ["close", "adj close", "adj_close", "price"]:
+    for name in ["adj close", "adj_close", "close", "price"]:
         if name in mapping:
             return pd.Series(df[mapping[name]]).astype(float)
-    # fallback: last numeric column
+
     num_cols = df.select_dtypes(include=[np.number]).columns
     if len(num_cols) == 0:
-        raise gr.Error("Could not find a numeric price column (e.g., Close).")
+        raise gr.Error("Could not find a numeric price column (e.g., Close / Adj Close).")
     return pd.Series(df[num_cols[-1]]).astype(float)
 
 def _extract_dates(df: pd.DataFrame):
+    # If index is DatetimeIndex, use it
+    if isinstance(df.index, pd.DatetimeIndex):
+        return df.index.to_numpy()
+    # Else look for a date-like column
     mapping = {c.lower(): c for c in df.columns}
     for name in ["date", "time", "timestamp"]:
         if name in mapping:
@@ -569,12 +576,7 @@ def _extract_dates(df: pd.DataFrame):
                 return pd.to_datetime(df[mapping[name]]).to_numpy()
             except Exception:
                 pass
-    # If the CSV has a Date index, respect that
-    if df.index.name is not None:
-        try:
-            return pd.to_datetime(df.index).to_numpy()
-        except Exception:
-            pass
+    # Fallback to a simple range
    return np.arange(len(df))
 
 def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = True) -> pd.Series:
@@ -584,8 +586,7 @@ def compute_realized_vol(close: pd.Series, window: int = 20, annualize: bool = True) -> pd.Series:
         rv = rv * np.sqrt(252.0)
     return rv.dropna().reset_index(drop=True)
 
-def bias_scale_calibration(y_true: np.ndarray, y_pred: np.ndarray) -> tuple[float, np.ndarray]:
-    """Return alpha and calibrated predictions alpha * y_pred (MSE-optimal scaling)."""
+def bias_scale_calibration(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, np.ndarray]:
     alpha = float(np.sum(y_true * y_pred) / (np.sum(y_pred**2) + EPS))
     return alpha, alpha * y_pred
 
@@ -602,22 +603,29 @@ def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
 # --------------------
 def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
     """
-    tickers: comma/space separated (first is used for plotting/eval)
+    tickers: comma/space separated; we use the FIRST for plotting/eval.
     start: YYYY-MM-DD
-    interval: '1d', '1wk', '1mo' (yfinance-safe)
-    use_calibration: whether to apply bias/scale calibration on the 30-day path
+    interval: '1d', '1wk', '1mo'
     """
-    # parse first ticker
-    tick_list = [t.strip().upper() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
+    # Parse first ticker
+    tick_list = [t.strip() for t in tickers.replace(";", ",").replace("|", ",").split(",") if t.strip()]
     if not tick_list:
-        raise gr.Error("Please enter at least one ticker (e.g., AAPL).")
+        raise gr.Error("Please enter at least one ticker, e.g. AAPL")
     ticker = tick_list[0]
 
-    # 1) Fetch/update CSV via your pipeline
-    csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
+    # 1) Fetch/update CSV via pipeline
+    try:
+        csv_path = pipe2.update_ticker_csv(ticker, start=start, interval=interval)
+    except Exception as e:
+        raise gr.Error(f"Data fetch failed for '{ticker}': {e}")
 
     # 2) Load CSV and build realized vol
-    df = pd.read_csv(csv_path, index_col=0, parse_dates=[0])
+    try:
+        df = pd.read_csv(csv_path, index_col=0, parse_dates=[0])
+    except Exception:
+        # Fallback if index parsing fails
+        df = pd.read_csv(csv_path)
+
     dates = _extract_dates(df)
     close = _extract_close(df)
 
@@ -639,7 +647,7 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
     samples = fcst[0].cpu().numpy()   # (1, H)
     path_pred = samples[0]            # (H,)
 
-    # 4) (Optional) bias/scale calibration
+    # 4) Optional bias/scale calibration
     alpha = None
     if use_calibration:
         alpha, path_pred_cal = bias_scale_calibration(rv_test, path_pred)
@@ -654,9 +662,8 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
     fig = plt.figure(figsize=(10, 4))
     H0 = len(rv_train)
 
-    # choose proper x-axis
-    if isinstance(dates, np.ndarray) and dates.shape[0] >= len(close):
-        # Align dates to rv length (after rolling dropna)
+    # Align dates to rv length if we have real dates
+    if isinstance(dates, np.ndarray) and len(dates) >= len(close):
         dates_rv = np.array(dates[-len(rv):])
         x_hist = dates_rv[:H0]
         x_fcst = dates_rv[H0:]
@@ -672,7 +679,7 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
     if use_calibration:
         plt.plot(x_fcst, path_pred_cal, linestyle="--", label=f"forecast (calibrated, α={alpha:.3f})")
 
-    plt.title(f"{ticker} — Volatility Forecast (RV={RV_WINDOW}, H={H}, interval={interval})")
+    plt.title(f"{ticker.upper()} — Volatility Forecast (RV={RV_WINDOW}, H={H}, interval={interval})")
     plt.xlabel(x_lbl); plt.ylabel("realized volatility")
     plt.legend(loc="best"); plt.tight_layout()
 
@@ -692,7 +699,7 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
 
     # 7) JSON + metrics text
     out = {
-        "ticker": ticker,
+        "ticker": ticker.upper(),
         "csv_path": csv_path,
         "config": {
             "start": start,
@@ -720,13 +727,13 @@ def run_for_ticker(tickers: str, start: str, interval: str, use_calibration: bool):
 with gr.Blocks(title="Volatility Forecast • yfinance pipeline + Chronos") as demo:
     gr.Markdown(
         "### Predict last 30 days of realized volatility for any ticker\n"
-        "- Data fetched via **yfinance** (your `pipeline_v2.update_ticker_csv`).\n"
+        "- Fetches data via **yfinance** using your `pipeline_v2.update_ticker_csv`.\n"
         "- Forecast uses **Chronos-T5-Large** (single path, no mean/median).\n"
-        "- Compare day-by-day to actual RV and see **MAPE/MPE/RMSE**.\n"
-        "- Optional **Bias/Scale Calibration (α)** to remove systematic under/overestimation."
+        "- Compares day-by-day to actual RV and reports **MAPE/MPE/RMSE**.\n"
+        "- Optional **Bias/Scale Calibration (α)** to remove systematic bias."
     )
     with gr.Row():
-        tickers_in = gr.Textbox(value="AAPL", label="Tickers (comma-separated, first is evaluated)")
+        tickers_in = gr.Textbox(value="AAPL", label="Tickers (comma-separated; first is evaluated)")
     with gr.Row():
         start_in = gr.Textbox(value="2015-01-01", label="Start date (YYYY-MM-DD)")
         interval_in = gr.Dropdown(choices=["1d", "1wk", "1mo"], value="1d", label="Interval")
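
Aside on the calibration kept in this diff: the alpha returned by bias_scale_calibration is the least-squares scale factor, i.e. the minimizer of sum((y_true - alpha * y_pred) ** 2), which gives alpha = sum(y_true * y_pred) / sum(y_pred ** 2), matching the expression in app.py. Below is a minimal sketch of how it behaves on toy data; the EPS value and the synthetic numbers are assumptions for illustration (EPS is defined elsewhere in app.py and not shown in this diff), only the formula itself comes from the code above.

import numpy as np

EPS = 1e-12  # stand-in value; app.py's real EPS is not shown in this diff

def bias_scale_calibration(y_true: np.ndarray, y_pred: np.ndarray):
    # Least-squares scale: alpha = argmin_a sum((y_true - a * y_pred) ** 2)
    alpha = float(np.sum(y_true * y_pred) / (np.sum(y_pred ** 2) + EPS))
    return alpha, alpha * y_pred

# Toy check: a forecast that systematically underestimates by ~20%
rng = np.random.default_rng(0)
y_true = np.abs(rng.normal(0.25, 0.05, size=30))           # 30 "days" of realized vol
y_pred = 0.8 * y_true + rng.normal(0.0, 0.005, size=30)

alpha, y_cal = bias_scale_calibration(y_true, y_pred)
print(f"alpha = {alpha:.3f}")                               # roughly 1 / 0.8 = 1.25
print(np.mean((y_true - y_cal) ** 2) <= np.mean((y_true - y_pred) ** 2))  # True: MSE never gets worse

Because alpha = 1 is always a feasible choice, the calibrated path can only match or improve the in-sample MSE; it rescales the whole forecast but cannot fix shape errors.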
pipeline_v2.py CHANGED
@@ -1,189 +1,64 @@
+# pipeline_v2.py
 import os
-from datetime import timedelta
+from typing import Tuple
 import pandas as pd
-import yfinance as yf
 
-os.makedirs("data", exist_ok=True)
-CSV_TEMPLATE = "data/{ticker}_{interval}.csv"
+try:
+    import yfinance as yf
+except Exception as e:
+    raise ImportError(
+        "yfinance is not installed. Add `yfinance>=0.2.40` to requirements.txt."
+    ) from e
 
-DEFAULT_START = "2015-01-01"
-DEFAULT_INTERVAL = "1d"
-DEFAULT_TICKERS = ["SPY", "QQQ", "AAPL", "MSFT", "NVDA", "NESN"]
-MAX_RETRIES = 3
 
-def download_ohlcv(ticker: str, start: str, interval: str, end: str = None) -> pd.DataFrame:
-    print(f"[INFO] Downloading {ticker} from {start} (interval={interval}, end={end})")
-    df = pd.DataFrame()
+def _ensure_dir(path: str) -> None:
+    os.makedirs(path, exist_ok=True)
 
-    for attempt in range(MAX_RETRIES):
-        df = yf.download(
-            ticker,
-            start=start,
-            end=end,               # end is exclusive on yfinance
-            interval=interval,
-            auto_adjust=True,
-            progress=False,
-            threads=True,
-            group_by="column",     # helps avoid MultiIndex columns
-        )
-        if not df.empty:
-            break
-        if attempt < MAX_RETRIES - 1:
-            print(f"[WARN] Empty response for {ticker}, retrying... ({attempt+1}/{MAX_RETRIES})")
 
-    if df.empty:
-        raise ValueError(f"No data returned for {ticker}")
+def _sanitize_ticker(t: str) -> str:
+    return t.strip().upper().replace(" ", "").replace("/", "-").replace(".", "-")
 
-    # --- NEW: collapse MultiIndex columns if present (single ticker) ---
-    if isinstance(df.columns, pd.MultiIndex):
-        # If levels are ['Price','Ticker'] or similar, drop the Ticker level
-        level_names = list(df.columns.names) if df.columns.names else []
-        if 'Ticker' in level_names:
-            df = df.droplevel('Ticker', axis=1)
-        else:
-            # Drop the *second* level by default (the ticker is usually the last level)
-            df = df.droplevel(-1, axis=1)
-    # -----------------------------------------
 
-    # Basic cleaning
-    if interval not in ("1d", "1wk", "1mo"):
-        df.index = pd.to_datetime(df.index, utc=True)
-    # df.index = pd.to_datetime(df.index, utc=True)  # ensure timezone  # Only needed for smaller than 1d intervals
-    df = df[~df.index.duplicated(keep="last")]  # drop duplicate timestamps
-    df = df.sort_index()                        # ensure time order
-
-    # standardize core columns if present
-    cols = [c for c in ["Open","High","Low","Close","Adj Close","Volume"] if c in df.columns]
-    df = df[cols] if cols else df
-    if "Volume" in df.columns:
-        df["Volume"] = pd.to_numeric(df["Volume"], errors="coerce").fillna(0).astype("int64", errors="ignore")
-    return df
-
-def load_cached_csv(path: str) -> pd.DataFrame:
-    if not os.path.exists(path):
-        return pd.DataFrame()
-    df = pd.read_csv(path, index_col=0, parse_dates=[0])  # Date index as datetime64[ns] (naive)
-    # df.index = pd.to_datetime(df.index, utc=True)
-    # tidy just in case
-    df = df[~df.index.duplicated(keep="last")].sort_index()
-    return df
-
-
-def next_start_from_cache(df_cached: pd.DataFrame) -> str:
-    last_day = pd.to_datetime(df_cached.index.max()).date()
-    return (last_day + timedelta(days=1)).isoformat()
-
-def drop_partial_today_daily(df: pd.DataFrame) -> pd.DataFrame:
-    """
-    For daily bars, optionally drop a partial 'today' row if the script runs before the session is complete.
-    This is a policy choice—use it if you want your cache to only contain completed daily bars.
+def update_ticker_csv(
+    ticker: str,
+    start: str = "2015-01-01",
+    interval: str = "1d",
+    dst_dir: str = "/mnt/data"   # HF Spaces writeable path
+) -> str:
     """
-    if df.empty:
-        return df
-    last_day = pd.to_datetime(df.index[-1]).date()
-    today_utc = pd.Timestamp.utcnow().date()
-    return df.iloc[:-1] if last_day >= today_utc else df
-
-def update_ticker_csv(ticker: str, start: str = "2015-01-01", interval: str = "1d") -> str:
-    """
-    Update (or create) a CSV cache for the ticker. Returns the CSV path.
+    Download OHLCV for `ticker` using yfinance and save as CSV.
+    Returns the CSV file path.
+
+    Args:
+        ticker: e.g. "AAPL"
+        start: "YYYY-MM-DD"
+        interval: "1d", "1wk", "1mo"
+        dst_dir: directory to write CSVs (default: /mnt/data for Spaces)
     """
-    out_path = CSV_TEMPLATE.format(ticker=ticker.upper(), interval=interval)
-    cached = load_cached_csv(out_path)
-
-    #if interval in ("1d", "1wk", "1mo"):
-    #    cached = drop_partial_today_daily(cached)
-
-    # --- make fetch_start a date, not a string ---
-    if cached.empty:
-        fetch_start = pd.to_datetime(start).date()
-        print(f"[INFO] No existing cache for {ticker}. Full download from {fetch_start}.")
-    else:
-        # next_start_from_cache currently returns a string -> parse to date
-        fetch_start = pd.to_datetime(next_start_from_cache(cached)).date()
-        print(f"[INFO] Found cache with {len(cached)} rows. Incremental from {fetch_start}.")
-    # ---------------------------------------------
-
-    # ----- NEW: avoid requesting future dates -----
-    today_utc = pd.Timestamp.utcnow().date()
-
-    if interval in ("1d", "1wk", "1mo"):
-        # If fetch_start is in the future, there is nothing to fetch yet
-        if fetch_start > today_utc:
-            print(f"[OK] {ticker}: nothing to fetch yet (next trading day {fetch_start} > today {today_utc}).")
-            df_new = pd.DataFrame(index=pd.DatetimeIndex([], name=cached.index.name or "Date"))
-        else:
-            # Optional: set an 'end' to be safe; yfinance's 'end' is exclusive, so add 1 day
-            end_date = today_utc + pd.Timedelta(days=1)
-            df_new = download_ohlcv(ticker, start=str(fetch_start), interval=interval, end=str(end_date))
-    else:
-        # Intraday: let 'now' be the implicit end
-        df_new = download_ohlcv(ticker, start=str(fetch_start), interval=interval)
-    # ----------------------------------------------
-
-    if cached.empty and df_new.empty:
-        raise ValueError(f"No data returned for {ticker}. Check ticker or start date.")
-
-    if df_new.empty:
-        print(f"[OK] {ticker}: no new rows to add.")
-        merged = cached
+    _ensure_dir(dst_dir)
+    tkr = _sanitize_ticker(ticker)
+
+    df = yf.download(
+        tkr,
+        start=start,
+        interval=interval,
+        auto_adjust=False,   # keep explicit Adj Close; we’ll pick Close / Adj Close later
+        progress=False,
+        threads=True,
+    )
+
+    if df is None or df.empty:
+        raise ValueError(f"No data returned for ticker '{tkr}' with start={start}, interval={interval}.")
+
+    # Ensure a clean, single-index Date column
+    if isinstance(df.index, pd.DatetimeIndex):
+        df = df.copy()
+        df.index.name = "Date"
     else:
-        # merge, drop duplicates, sort
-        merged = pd.concat([cached, df_new], axis=0)
-        merged = merged[~merged.index.duplicated(keep="last")].sort_index()
-        print(f"[OK] {ticker}: added {len(merged) - len(cached)} new rows.")
-
-    # Optional: keep only completed daily bars
-    #if interval in ("1d", "1wk", "1mo"):
-    #    merged = drop_partial_today_daily(merged)
-
-    # Only drop partial 'today' if we fetched something new
-    #fetched_any = not df_new.empty
-
-    #if interval in ("1d", "1wk", "1mo") and fetched_any:
-    #    merged = drop_partial_today_daily(merged)
-
-    #added = len(merged) - len(cached)
-    #if added < 0:
-        # Safety net (shouldn’t happen with the guard above)
-        #added = 0
-    # save
-    merged.to_csv(out_path, date_format="%Y-%m-%d")
-    added = len(merged) - len(cached)
-    print(f"[OK] {ticker}: added {added} new row(s). Now {len(merged)} total.")
-    print(f"[OK] Saved {ticker} → {out_path}")
-
-    return out_path
-
-def update_many(
-    tickers: str = DEFAULT_TICKERS,
-    start: str = DEFAULT_START,
-    interval: str = DEFAULT_INTERVAL,
-) -> dict[str, str]:
-    """
-    Update multiple tickers; continue on errors.
-    Returns dict[ticker] -> csv_path (or None if failed).
-    """
-    results: Dict[str, Optional[str]] = {}
-    for t in [t.strip().upper() for t in tickers if t and t.strip()]:
-        print("\n" + "=" * 60)
-        print(f"[RUN] {t}")
-        try:
-            path = update_ticker_csv(t, start=start, interval=interval)
-            results[t] = path
-        except Exception as e:
-            print(f"[ERR] {t}: {e}")
-            results[t] = None
-    print("\n" + "=" * 60)
-    ok = sum(1 for v in results.values() if v)
-    print(f"[SUMMARY] Completed {ok}/{len(results)} tickers.")
-    return results
-
+        df = df.reset_index().rename(columns={df.columns[0]: "Date"}).set_index("Date")
 
-if __name__ == "__main__":
-    # choose your universe here (or later via CLI)
-    TICKERS = DEFAULT_TICKERS
-    START = DEFAULT_START
-    INTERVAL = DEFAULT_INTERVAL
+    # Save
+    csv_path = os.path.join(dst_dir, f"{tkr}_{interval}.csv")
+    df.to_csv(csv_path)
 
-    update_many(TICKERS, start=START, interval=INTERVAL)
+    return csv_path
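
For context, a minimal usage sketch of the rewritten pipeline as app.py consumes it. The ticker, dates, and the realized-vol formula below are illustrative assumptions (compute_realized_vol's body is not part of this diff); the update_ticker_csv call, the CSV-loading line, and the 'Adj Close' preference mirror the code in this commit.

import numpy as np
import pandas as pd
import pipeline_v2 as pipe2

# 1) Fetch OHLCV and write the CSV (defaults to dst_dir="/mnt/data" on Spaces)
csv_path = pipe2.update_ticker_csv("AAPL", start="2015-01-01", interval="1d")

# 2) Load it the way run_for_ticker does: first column parsed as a Date index
df = pd.read_csv(csv_path, index_col=0, parse_dates=[0])

# 3) Prefer 'Adj Close' over 'Close', mirroring _extract_close's new ordering
col = next(c for c in ("Adj Close", "Close") if c in df.columns)
close = df[col].astype(float)

# 4) A plausible realized-vol series: 20-day rolling std of log returns, annualized
#    with sqrt(252), matching the window and annualization factor visible in app.py
rv = np.log(close).diff().rolling(20).std() * np.sqrt(252.0)
print(csv_path, len(df), float(rv.dropna().iloc[-1]))

Note that the new update_ticker_csv always re-downloads the full history and overwrites the CSV, whereas the removed code kept an incremental cache; callers no longer need a populated data/ directory, only a writeable dst_dir.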