Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| from chronos import BaseChronosPipeline | |
# Page chrome: browser-tab title, narrow layout, and the heading/caption shown at the top.
_APP_TITLE = "Chronos-Bolt Zero-Shot Forecast"
st.set_page_config(page_title=_APP_TITLE, layout="centered")
st.title(_APP_TITLE)
st.caption(
    "Zero-shot probabilistic forecasting (q10/q50/q90) using amazon/chronos-bolt-* models."
)
# -------------------- Indicator helpers (no pandas-ta needed) --------------------
def ema(series, length=20):
    """Return the exponential moving average of *series* with span ``length``.

    Uses the recursive form (``adjust=False``), i.e. smoothing factor
    ``2 / (length + 1)``. The input is cast to float64 and the result is a
    Series that keeps the input's index when *series* is already a Series.
    """
    smoother = pd.Series(series).astype("float64").ewm(span=length, adjust=False)
    return smoother.mean()
def rsi(series, length=14):
    """Return Wilder's Relative Strength Index of *series* (0-100).

    Average gains and losses are smoothed with an EWM of ``alpha = 1/length``
    (recursive form). By definition RSI is 100 when the average loss is zero
    but the average gain is positive, so such rows are set to 100 instead of
    NaN. Rows with neither gains nor losses (a perfectly flat window), and the
    first row (no previous value to diff against), remain NaN.
    """
    s = pd.Series(series).astype("float64")
    delta = s.diff()
    gain = delta.clip(lower=0).ewm(alpha=1 / length, adjust=False).mean()
    loss = (-delta.clip(upper=0)).ewm(alpha=1 / length, adjust=False).mean()
    # Avoid dividing by zero; the zero-loss case is resolved explicitly below.
    rs = gain / loss.replace(0, np.nan)
    out = 100 - (100 / (1 + rs))
    # Only gains so far -> maximal strength. Previously NaN, which made callers
    # that dropna() silently discard every row of a steadily rising stretch.
    return out.mask((loss == 0) & (gain > 0), 100.0)
| def stochastic_kd(high, low, close, k=14, d=3, smooth_k=3): | |
| h = pd.Series(high).astype("float64") | |
| l = pd.Series(low).astype("float64") | |
| c = pd.Series(close).astype("float64") | |
| hh = h.rolling(k).max() | |
| ll = l.rolling(k).min() | |
| raw_k = 100 * (c - ll) / (hh - ll) | |
| k_smoothed = raw_k.rolling(smooth_k).mean() | |
| d_line = k_smoothed.rolling(d).mean() | |
| return k_smoothed, d_line | |
# -------------------- Model options --------------------
# Display label -> Hugging Face model id. Insertion order matters: the model
# selectbox lists these keys in order and preselects the first one.
MODEL_CHOICES = dict(
    [
        ("Bolt Mini (CPU-friendly)", "amazon/chronos-bolt-mini"),
        ("Bolt Small (better; GPU if available)", "amazon/chronos-bolt-small"),
    ]
)
@st.cache_resource(show_spinner=False)
def load_pipeline(model_id: str):
    """Load the Chronos pipeline for ``model_id`` once and reuse it afterwards.

    Streamlit re-executes this script on every widget interaction. Without
    caching, every "Forecast" click re-instantiated the model from the
    checkpoint. ``st.cache_resource`` keeps one pipeline per ``model_id`` for
    the lifetime of the server process, shared by all sessions. The caller
    already shows its own spinner, hence ``show_spinner=False``.

    Runs on CUDA with bfloat16 weights when a GPU is available, otherwise on
    CPU with float32.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    return BaseChronosPipeline.from_pretrained(model_id, device_map=device, torch_dtype=dtype)
| # -------------------- Data loaders (always return 1-D) -------------------- | |
| def _force_1d(a): | |
| a = pd.Series(a, dtype="float32").replace([np.inf, -np.inf], np.nan).dropna() | |
| return a.to_numpy().reshape(-1) | |
def load_ticker_series(ticker: str, period: str = "2y"):
    """Fetch daily auto-adjusted closing prices for *ticker* via yfinance.

    Returns a 1-D float32 array cleaned by ``_force_1d``. If yfinance returns
    no rows, an empty float32 array is returned instead.
    """
    import yfinance as yf

    prices = yf.download(ticker, period=period, interval="1d", auto_adjust=True, progress=False)
    if prices.empty:
        return np.asarray([], dtype="float32")
    # yfinance may return 2-level column labels, in which case "Close" is a
    # DataFrame rather than a Series; keep its first column.
    closes = prices["Close"]
    closes = closes.iloc[:, 0] if isinstance(closes, pd.DataFrame) else closes
    return _force_1d(closes)
def parse_pasted_series(txt: str):
    """Parse free-form pasted numbers into a clean 1-D float32 array.

    Tokens are separated by commas and/or whitespace (newlines included).
    Tokens that are not valid floats, such as a header word, are skipped on
    purpose. The surviving values go through ``_force_1d``, which also drops
    NaN/inf.
    """
    import re

    vals = []
    for tok in re.split(r"[,\s]+", txt.strip()):
        if not tok:
            continue
        try:
            vals.append(float(tok))
        except ValueError:
            # Best-effort parse: ignore non-numeric tokens. Catch only
            # ValueError (what float(str) raises) so Ctrl-C / SystemExit are
            # not swallowed as a bare `except:` would.
            continue
    return _force_1d(vals)
def load_csv_series(file, column=None):
    """Read a CSV and extract one numeric column as a clean 1-D float32 array.

    Args:
        file: path or file-like object accepted by ``pd.read_csv``. It is read
            from its current position.
        column: column to extract. When None, the first column with a NumPy
            numeric dtype is used.

    Returns:
        ``(values, frame, column_used)``. When no numeric column exists,
        ``values`` is an empty float32 array and ``column_used`` is None.
    """
    frame = pd.read_csv(file)
    if column is None:
        numeric_names = [
            name for name in frame.columns if np.issubdtype(frame[name].dtype, np.number)
        ]
        column = next(iter(numeric_names), None)
    if column is None:
        return np.asarray([], dtype="float32"), frame, None
    return _force_1d(frame[column]), frame, column
# -------------------- UI --------------------
c1, c2 = st.columns(2)
with c1:
    model_label = st.selectbox("Model", list(MODEL_CHOICES.keys()), index=0)
with c2:
    pred_len = st.number_input("Prediction length (steps)", 1, 365, 30)

src = st.radio("Data source", ["Ticker (yfinance)", "Paste numbers", "Upload CSV"], horizontal=True)

# Streamlit re-executes the whole script on every interaction, and st.button()
# is True only on the rerun triggered by its own click. A plain local `series`
# was therefore lost as soon as "Forecast" was clicked (that rerun saw the
# "Load ..." button as False), so the forecast could never run. Keep the loaded
# series in session_state, one slot per data source.
series_key = f"series::{src}"

if src == "Ticker (yfinance)":
    t1, t2 = st.columns([2, 1])
    with t1:
        ticker = st.text_input("Ticker (e.g., AAPL, SPY, BTC-USD)", "AAPL")
    with t2:
        period = st.selectbox("History window", ["6mo", "1y", "2y", "5y"], index=2)
    if st.button("Load data"):
        loaded = load_ticker_series(ticker.strip(), period)
        st.session_state[series_key] = loaded
        if loaded.size == 0:
            st.error("No data returned. Try another ticker/window.")
elif src == "Paste numbers":
    txt = st.text_area("One value per line (or comma/space separated)", "1\n2\n3\n4\n5\n6\n7\n8\n9\n10")
    if st.button("Use pasted data"):
        st.session_state[series_key] = parse_pasted_series(txt)
else:
    uploaded = st.file_uploader("Upload CSV", type=["csv"])
    if uploaded is None:
        # File removed from the uploader: forget the series loaded from it.
        st.session_state.pop(series_key, None)
    else:
        df = pd.read_csv(uploaded)
        numeric_cols = [c for c in df.columns if np.issubdtype(df[c].dtype, np.number)]
        col = st.selectbox("Pick numeric column", numeric_cols) if numeric_cols else None
        if st.button("Load CSV column") and col:
            # The read_csv above left the upload buffer at EOF; rewind it, or
            # the second read fails with "No columns to parse from file".
            uploaded.seek(0)
            loaded, _, _ = load_csv_series(uploaded, column=col)
            st.session_state[series_key] = loaded
        elif not numeric_cols:
            st.error("No numeric columns found in CSV.")

series = st.session_state.get(series_key)
# -------------------- Plot + Forecast --------------------
if series is not None and series.size > 5:
    st.write(f"Loaded {series.size} points.")
    st.line_chart(pd.DataFrame(series, columns=["value"]))  # always 1-D -> no error
    if st.button("Forecast"):
        with st.spinner("Running Chronos-Bolt..."):
            pipe = load_pipeline(MODEL_CHOICES[model_label])
            ctx = torch.tensor(series, dtype=torch.float32)
            q_levels = [0.10, 0.50, 0.90]
            # The context goes in positionally: its keyword name differs across
            # chronos-forecasting releases (`context` in older ones, `inputs` in
            # newer ones), but its position does not. The mean is unused.
            quantiles, _mean = pipe.predict_quantiles(
                ctx,
                prediction_length=int(pred_len),
                quantile_levels=q_levels,
            )
        # quantiles[0] is the only series in the batch, expected shape
        # [pred_len, len(q_levels)]. Cast to float32 first: .numpy() rejects
        # bfloat16, the dtype load_pipeline picks on CUDA.
        q_np = quantiles[0].float().cpu().numpy()
        lo, med, hi = q_np[:, 0], q_np[:, 1], q_np[:, 2]

        import matplotlib.pyplot as plt

        hist_x = np.arange(len(series))
        fut_x = np.arange(len(series), len(series) + len(med))
        fig, ax = plt.subplots(figsize=(9, 4.5))
        ax.plot(hist_x, series, label="history")
        ax.plot(fut_x, med, label="median forecast")
        ax.fill_between(fut_x, lo, hi, alpha=0.3, label="q10–q90 band")
        ax.legend()
        ax.grid(True, alpha=0.3)
        st.pyplot(fig)
        # pyplot keeps every figure alive until closed; without this each
        # forecast leaks a figure in the long-running Streamlit process.
        plt.close(fig)

        out = pd.DataFrame({"t": fut_x, "q10": lo, "q50": med, "q90": hi})
        st.download_button(
            "Download forecast CSV",
            out.to_csv(index=False).encode("utf-8"),
            file_name="chronos_forecast.csv",
            mime="text/csv",
        )
else:
    st.info("Load a ticker, paste values, or upload a CSV to begin.")
# ================================
# Train with RSI / EMA / Stochastic (AutoGluon) — no pandas-ta
# ================================
with st.expander("Train with Indicators (RSI, EMA, Stochastic)"):
    st.write("Fine-tune Chronos-Bolt on one ticker using indicator covariates (past-only).")
    tcol1, tcol2, tcol3 = st.columns([2, 1, 1])
    with tcol1:
        ft_ticker = st.text_input("Ticker", "SPY")
    with tcol3:
        ft_interval = st.selectbox("Interval", ["1d", "60m", "30m", "15m"], index=0)
    # Allowed lookbacks depend on interval. The Interval box is rendered first
    # so the Lookback options below can follow it.
    if ft_interval == "1d":
        allowed_periods = ["6mo", "1y", "2y", "5y"]
        default_idx = 2
    else:
        allowed_periods = ["5d", "30d", "60d"]
        default_idx = 1
    with tcol2:
        ft_period = st.selectbox("Lookback", allowed_periods, index=default_idx)
    ft_steps = st.slider("Fine-tune steps", 100, 1500, 300, step=50)
    run_ft = st.button("Train fine-tuned model")

    if run_ft:
        import tempfile

        import matplotlib.pyplot as plt
        import yfinance as yf
        from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

        # Normalise the ticker once so the download symbol, the AutoGluon
        # item_id and the output file name all agree (no stray whitespace).
        symbol = ft_ticker.strip()
        item = symbol.upper()
        if not symbol:
            st.error("Enter a ticker symbol.")
            st.stop()

        def _download_ohlc(period):
            """Download auto-adjusted bars for `symbol` at the selected interval."""
            return yf.download(
                symbol,
                period=period,
                interval=ft_interval,
                auto_adjust=True,
                progress=False,
            )

        with st.spinner("Downloading & computing indicators…"):
            # 1) Load OHLC so we can compute Stochastic (needs High/Low/Close)
            df = _download_ohlc(ft_period)
            # Fallback: if the chosen combo is too long for intraday, clamp and retry
            if df.empty:
                alt_period = "60d" if ft_interval != "1d" else "1y"
                if alt_period != ft_period:
                    df = _download_ohlc(alt_period)
            if df.empty:
                st.error("No data returned. Try a shorter lookback for intraday (e.g., 30d/60d) or use Interval=1d.")
                st.stop()

            # Frequency alias for AutoGluon; make the index tz-naive
            # (tz_localize(None) drops the zone and keeps local wall-clock times).
            freq_alias = {"1d": "B", "60m": "60min", "30m": "30min", "15m": "15min"}.get(ft_interval, "B")
            df.index = pd.DatetimeIndex(df.index).tz_localize(None)

            # Handle MultiIndex columns (yfinance can return 2-level columns)
            if isinstance(df.columns, pd.MultiIndex):
                try:
                    sym = df.columns.get_level_values(1).unique()[0]
                    df = df.xs(sym, axis=1, level=1)
                except Exception:
                    # Fallback: flatten by taking the top-level name (Close/High/Low)
                    df.columns = [c[0] for c in df.columns.to_flat_index()]

            # Keep only needed cols
            df = df[["Close", "High", "Low"]].copy()
            # Ensure each column is 1-D (avoid (N,1) arrays)
            for _c in ["Close", "High", "Low"]:
                if isinstance(df[_c], pd.DataFrame):
                    df[_c] = df[_c].iloc[:, 0]
                df[_c] = pd.Series(np.asarray(df[_c]).reshape(-1), index=df.index)
            df = df.dropna()

            # 2) Indicators (helpers defined near the top of this file)
            df["rsi14"] = rsi(df["Close"], 14)
            df["ema20"] = ema(df["Close"], 20)
            df["stoch_k"], df["stoch_d"] = stochastic_kd(df["High"], df["Low"], df["Close"], 14, 3, 3)
            df = df.dropna().astype("float32")
            if df.empty:
                # Indicator warm-up consumed every row; AutoGluon cannot fit on nothing.
                st.error("No rows left after computing indicators. Try a longer lookback.")
                st.stop()
            if df.shape[0] < 200:
                st.warning("Very short history after indicators; results may be noisy.")

            # 3) Build TimeSeriesDataFrame (target + past covariates)
            ts = df[["Close", "rsi14", "ema20", "stoch_k", "stoch_d"]].copy()
            ts["item_id"] = item
            ts["timestamp"] = ts.index
            ts = ts.rename(columns={"Close": "target"})
            tsdf = TimeSeriesDataFrame.from_data_frame(
                ts, id_column="item_id", timestamp_column="timestamp"
            )
            # Ensure a regular time grid for AutoGluon. Best-effort: on failure,
            # continue with the raw timestamps but tell the user instead of
            # silently ignoring it.
            try:
                tsdf = tsdf.convert_frequency(freq=freq_alias)
            except Exception as err:
                st.warning(f"Could not regularise timestamps to '{freq_alias}' ({err}); using them as-is.")

        with st.spinner("Fine-tuning Chronos-Bolt (small demo)…"):
            # NOTE(review): the indicator columns enter as past covariates (no
            # known_covariates_names is given). Confirm that the installed
            # AutoGluon's Chronos model actually consumes them; if it only uses
            # the target, the indicators do not influence this forecast.
            #
            # Fit into a throw-away directory. Per the AutoGluon docs, the
            # default `path` is a new timestamped folder under the working
            # directory, so every click would leave another fine-tuned
            # checkpoint on disk.
            with tempfile.TemporaryDirectory(prefix="ag-chronos-") as ag_dir:
                predictor = TimeSeriesPredictor(
                    prediction_length=int(pred_len),  # reuse the forecast section's pred_len
                    eval_metric="WQL",
                    quantile_levels=[0.1, 0.5, 0.9],
                    freq=freq_alias,
                    path=ag_dir,
                ).fit(
                    train_data=tsdf,
                    enable_ensemble=False,
                    time_limit=300,  # small demo budget; increase offline/GPU
                    hyperparameters={
                        "Chronos": {
                            "model_path": "bolt_mini",  # CPU-friendly; try 'bolt_small' on GPU
                            "fine_tune": True,
                            "fine_tune_steps": int(ft_steps),
                            "fine_tune_lr": 1e-5,
                        }
                    },
                )
                # 4) Forecast with the fine-tuned model (AG starts at series end).
                # Kept inside the temp dir in case the predictor reloads saved models.
                preds = predictor.predict(tsdf)

        yhist = tsdf.loc[item]["target"].to_numpy()
        ypred = preds.loc[item]  # MultiIndex -> rows for horizon
        lo = ypred["0.1"].to_numpy()
        med = ypred["0.5"].to_numpy()
        hi = ypred["0.9"].to_numpy()

        hx = np.arange(len(yhist))
        fx = np.arange(len(yhist), len(yhist) + len(med))
        fig, ax = plt.subplots(figsize=(9, 4.5))
        ax.plot(hx, yhist, label="history")
        ax.plot(fx, med, label="median (fine-tuned)")
        ax.fill_between(fx, lo, hi, alpha=0.3, label="q10–q90")
        ax.legend()
        ax.grid(True, alpha=0.3)
        st.pyplot(fig)
        # Release the figure so repeated runs don't accumulate open pyplot figures.
        plt.close(fig)

        out = pd.DataFrame({"t": fx, "q10": lo, "q50": med, "q90": hi})
        st.download_button(
            "Download fine-tuned forecast CSV",
            out.to_csv(index=False).encode("utf-8"),
            file_name=f"{item}_chronos_finetuned.csv",
            mime="text/csv",
        )