import os
import time
import inspect
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import gradio as gr
import torch
import plotly.graph_objects as go

from chronos import Chronos2Pipeline

MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
DATA_DIR = "data"
OUT_DIR = "/tmp"
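# Note: exports go to /tmp on the assumption (typical for Hugging Face Spaces
# containers) that it is a reliably writable location.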

# -------------------------
# Data
# -------------------------
def available_test_csv() -> List[str]:
    if not os.path.isdir(DATA_DIR):
        return []
    return sorted([f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv")])


def pick_device(ui_choice: str) -> str:
    return "cuda" if (ui_choice or "").startswith("cuda") and torch.cuda.is_available() else "cpu"

def make_sample_series(n: int, seed: int, trend: float, season_period: int, season_amp: float, noise: float) -> np.ndarray:
    rng = np.random.default_rng(int(seed))
    t = np.arange(int(n), dtype=np.float32)
    y = (
        trend * t
        + season_amp * np.sin(2 * np.pi * t / max(1, int(season_period)))
        + rng.normal(0, noise, size=int(n))
    ).astype(np.float32)
    if float(np.min(y)) < 0:
        y -= float(np.min(y))
    return y
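
# Example (the UI defaults below): make_sample_series(300, 42, 0.03, 14, 3.0, 0.8)
# yields 300 points with a mild upward trend, a 14-step seasonal cycle, Gaussian
# noise, and a final shift that keeps the series non-negative.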

def load_series_from_csv(csv_path: str, column: Optional[str]) -> Tuple[np.ndarray, str]:
    df = pd.read_csv(csv_path)
    col = (column or "").strip()
    if not col:
        numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
        if not numeric_cols:
            # Fall back to coercion: accept any column with at least one parsable number.
            for c in df.columns:
                coerced = pd.to_numeric(df[c], errors="coerce")
                if coerced.notna().sum() > 0:
                    numeric_cols.append(c)
        if not numeric_cols:
            raise ValueError("No numeric columns found in the CSV.")
        col = numeric_cols[0]
    if col not in df.columns:
        raise ValueError(f"Column '{col}' not found. Available: {list(df.columns)}")
    y = pd.to_numeric(df[col], errors="coerce").dropna().astype(np.float32).to_numpy()
    if len(y) < 10:
        raise ValueError("Series is too short (need at least 10 points).")
    return y, col
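
# Example usage (hypothetical file and column names):
#   y, col = load_series_from_csv("data/example.csv", "value")  # explicit column
#   y, col = load_series_from_csv("data/example.csv", None)     # auto-pick first numeric column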

# -------------------------
# Model cache
# -------------------------
_PIPE = None
_META = {"model_id": None, "device": None}


def get_pipeline(model_id: str, device: str) -> Chronos2Pipeline:
    global _PIPE, _META
    model_id = (model_id or MODEL_ID_DEFAULT).strip()
    device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"
    if _PIPE is None or _META["model_id"] != model_id or _META["device"] != device:
        _PIPE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
        _META = {"model_id": model_id, "device": device}
    return _PIPE
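
# Example: repeated calls with the same (model_id, device) reuse the cached
# pipeline, so only changing the model or device in the UI triggers a reload.
#   pipe_a = get_pipeline("amazon/chronos-2", "cpu")  # loads the weights once
#   pipe_b = get_pipeline("amazon/chronos-2", "cpu")  # cache hit: pipe_b is pipe_a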

# -------------------------
# Predict (STABLE)
# -------------------------
def _to_numpy(x: Any) -> np.ndarray:
    if isinstance(x, np.ndarray):
        return x
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _extract_samples(raw: Any) -> np.ndarray:
    if isinstance(raw, dict):
        for k in ["samples", "predictions", "prediction", "output"]:
            if k in raw:
                return _to_numpy(raw[k])
        if len(raw) > 0:
            return _to_numpy(next(iter(raw.values())))
        return np.asarray([], dtype=np.float32)
    return _to_numpy(raw)
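
# The dict probing in _extract_samples is defensive: depending on the installed
# chronos version, predict() may return a tensor/array directly or a dict keyed
# by a name such as "samples" or "predictions"; unknown keys fall back to the
# first value.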

def chronos2_predict(pipe: Chronos2Pipeline, y: np.ndarray, horizon: int, requested_samples: int) -> Tuple[np.ndarray, bool, str]:
    """
    Returns:
        samples: array of shape (S, H)
        multi: whether S > 1 is real (not replicated)
        note: debug note
    """
    sig = inspect.signature(pipe.predict)
    params = sig.parameters

    # Input format: ALWAYS a batch = [series].
    inputs = [y.tolist()]

    # Keyword for the horizon (the name varies across versions).
    horizon_kw = None
    for cand in ["prediction_length", "horizon", "steps", "n_steps", "pred_len"]:
        if cand in params:
            horizon_kw = cand
            break

    # Keyword for the sample count (many versions don't expose one!).
    sample_kw = None
    for cand in ["n_samples", "num_return_sequences", "num_samples"]:
        if cand in params:
            sample_kw = cand
            break

    kwargs: Dict[str, Any] = {}
    if horizon_kw:
        kwargs[horizon_kw] = int(horizon)
    else:
        # Worst case: no recognized horizon kwarg; fall back to the most common
        # name and let predict() raise if it is unsupported.
        kwargs["prediction_length"] = int(horizon)
    if sample_kw:
        kwargs[sample_kw] = int(requested_samples)

    # Call predict(), matching whether the signature takes an `inputs` keyword.
    raw = pipe.predict(inputs=inputs, **kwargs) if "inputs" in params else pipe.predict(inputs, **kwargs)
    arr = _extract_samples(raw).astype(np.float32, copy=False)

    # Normalize shape -> (S, H). Drop a leading batch dimension first, so that a
    # horizon of 1 cannot be mistaken for a sample axis.
    if arr.ndim == 3:
        # (B, S, H) or (S, B, H); safest bet: take the first slice on axis 0.
        arr = arr[0]
    if arr.ndim == 1:
        # Could be (H,) or (S,); assume a single sample over the horizon.
        arr = arr[None, :]

    # Enforce the horizon length: trim if too long, pad with the last value if too short.
    if arr.shape[-1] != horizon:
        if arr.shape[-1] > horizon:
            arr = arr[..., :horizon]
        else:
            pad = horizon - arr.shape[-1]
            last = arr[..., -1:]
            arr = np.concatenate([arr, np.repeat(last, pad, axis=-1)], axis=-1)

    # With a single sample we can still plot the median, but a band is not meaningful.
    real_multi = arr.shape[0] > 1
    note = f"predict_signature={sig} | used_horizon_kw={horizon_kw} | used_sample_kw={sample_kw} | got_shape={tuple(arr.shape)}"
    return arr, real_multi, note
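
# Shape sketch (hypothetical numbers): with a 300-point history, horizon=30 and
# requested_samples=200, `samples` comes back as (200, 30) when the installed
# predict() exposes a sample-count kwarg; otherwise it degrades to (1, 30) with
# multi=False, and run_all() below disables the uncertainty band accordingly.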

# -------------------------
# Plotly
# -------------------------
def plot_forecast(y, median, low, high, title, show_band: bool, band_label: str) -> go.Figure:
    t_hist = np.arange(len(y))
    t_fcst = np.arange(len(y), len(y) + len(median))
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=t_hist, y=y, mode="lines", name="History"))
    fig.add_trace(go.Scatter(x=t_fcst, y=median, mode="lines", name="Forecast (median)"))
    fig.add_vline(x=len(y) - 1, line_width=1, line_dash="dash", opacity=0.6)
    if show_band:
        # Draw the invisible upper bound first, then fill from the lower bound up to it.
        fig.add_trace(go.Scatter(x=t_fcst, y=high, mode="lines", line=dict(width=0),
                                 showlegend=False, hoverinfo="skip"))
        fig.add_trace(go.Scatter(
            x=t_fcst, y=low, mode="lines", fill="tonexty",
            line=dict(width=0), name=band_label
        ))
    fig.update_layout(
        title=title,
        hovermode="x unified",
        margin=dict(l=10, r=10, t=55, b=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0),
        xaxis_title="t",
        yaxis_title="value",
    )
    return fig

def kpi_card(label: str, value: str, hint: str = "") -> str:
    hint_html = f"<div style='opacity:.75;font-size:12px;margin-top:6px;'>{hint}</div>" if hint else ""
    return f"""
    <div style="border:1px solid rgba(255,255,255,.12); border-radius:16px; padding:14px 16px;
                background: rgba(255,255,255,.04);">
      <div style="font-size:12px;opacity:.8;">{label}</div>
      <div style="font-size:22px;font-weight:700;margin-top:4px;">{value}</div>
      {hint_html}
    </div>
    """


def kpi_grid(cards: List[str]) -> str:
    return f"<div style='display:grid; grid-template-columns: repeat(6, minmax(0, 1fr)); gap:12px;'>{''.join(cards)}</div>"

def explain(y, median, low, high, band_enabled: bool, q_low: float, q_high: float, extra: str) -> str:
    horizon = len(median)
    base = float(np.mean(y))
    delta = float(median[-1] - median[0])
    pct = (delta / max(1e-6, base)) * 100.0
    if abs(pct) < 2:
        trend_txt = "essentially stable"
    elif pct > 0:
        trend_txt = "increasing"
    else:
        trend_txt = "decreasing"
    txt = f"""
### 🧠 Explanation

Over the next **{horizon} steps** the median forecast is **{trend_txt}** (change ≈ **{pct:+.1f}%** relative to the historical mean level).

- **Last predicted median value:** **{median[-1]:.2f}**
"""
    if band_enabled:
        txt += f"- **Band [{q_low:.0%}–{q_high:.0%}] (last step):** **[{low[-1]:.2f} – {high[-1]:.2f}]**\n"
    else:
        txt += "- **Uncertainty band:** disabled (this Chronos2 version does not return multiple samples with the available parameters).\n"
    txt += f"\n<details><summary>Debug</summary>\n\n`{extra}`\n\n</details>\n"
    return txt

# -------------------------
# Run
# -------------------------
def run_all(
    input_mode, test_csv_name, upload_csv, csv_column,
    n, seed, trend, season_period, season_amp, noise,
    prediction_length, requested_samples, q_low, q_high,
    device_ui, model_id,
):
    if q_low >= q_high:
        raise gr.Error("Quantile low must be < quantile high.")

    device = pick_device(device_ui)
    pipe = get_pipeline(model_id, device)

    # Data
    if input_mode == "Test CSV":
        if not test_csv_name:
            raise gr.Error("Select a Test CSV.")
        path = os.path.join(DATA_DIR, test_csv_name)
        y, used_col = load_series_from_csv(path, csv_column)
        source = f"Test CSV: {test_csv_name} • col={used_col}"
    elif input_mode == "Upload CSV":
        if upload_csv is None:
            raise gr.Error("Upload a CSV.")
        y, used_col = load_series_from_csv(upload_csv.name, csv_column)
        source = f"Upload CSV • col={used_col}"
    else:
        y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
        source = "Sample series"

    t0 = time.time()
    samples, real_multi, note = chronos2_predict(pipe, y, int(prediction_length), int(requested_samples))
    latency = time.time() - t0

    median = np.quantile(samples, 0.50, axis=0)
    band_enabled = real_multi and samples.shape[0] > 2
    if band_enabled:
        low = np.quantile(samples, float(q_low), axis=0)
        high = np.quantile(samples, float(q_high), axis=0)
    else:
        # With too few samples a band would be meaningless; collapse it onto the median.
        low = median.copy()
        high = median.copy()
    # KPI
    cards = [
        kpi_card("Device", device.upper(), f"cuda_available={torch.cuda.is_available()}"),
        kpi_card("Latency", f"{latency:.2f}s", "predict()"),
        kpi_card("Samples", str(samples.shape[0]), "returned by model"),
        kpi_card("Band", "ON" if band_enabled else "OFF", "needs multiple samples"),
        kpi_card("Horizon", str(prediction_length)),
        kpi_card("Model", (model_id or MODEL_ID_DEFAULT)),
    ]
    kpis_html = kpi_grid(cards)

    # Plot
    fig = plot_forecast(
        y=y,
        median=median,
        low=low,
        high=high,
        title=f"Forecast — {source}",
        show_band=band_enabled,
        band_label=f"Band [{q_low:.2f}, {q_high:.2f}]",
    )

    # Table + export
    t_fcst = np.arange(len(y), len(y) + int(prediction_length))
    out_df = pd.DataFrame({
        "t": t_fcst,
        "median": median,
    })
    if band_enabled:
        out_df[f"q{q_low:.2f}"] = low
        out_df[f"q{q_high:.2f}"] = high
    out_path = os.path.join(OUT_DIR, "chronos2_forecast.csv")
    out_df.to_csv(out_path, index=False)

    explanation_md = explain(y, median, low, high, band_enabled, q_low, q_high, note)

    info = {
        "source": source,
        "history_points": int(len(y)),
        "prediction_length": int(prediction_length),
        "requested_samples": int(requested_samples),
        "returned_samples": int(samples.shape[0]),
        "band_enabled": bool(band_enabled),
        "predict_signature": str(inspect.signature(pipe.predict)),
        "debug_note": note,
    }
    return kpis_html, explanation_md, fig, out_df, out_path, info

# -------------------------
# UI
# -------------------------
css = """.gradio-container { max-width: 1200px !important; }"""

with gr.Blocks(title="Chronos-2 • Pro Dashboard (Stable)", css=css) as demo:
    gr.Markdown("# ⏱️ Chronos-2 Forecast Dashboard — Stable Edition")

    with gr.Row():
        with gr.Column(scale=1, min_width=360):
            input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Input")
            test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV (data/)")
            upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
            csv_column = gr.Textbox(label="Numeric column (optional)", placeholder="e.g. value")
            device_ui = gr.Dropdown(
                ["cpu", "cuda (if available)"],
                value="cuda (if available)" if torch.cuda.is_available() else "cpu",
                label="Device",
            )
            model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")

            with gr.Accordion("Sample generator", open=False):
                n = gr.Slider(60, 2000, value=300, step=10, label="History length")
                seed = gr.Number(value=42, precision=0, label="Seed")
                trend = gr.Slider(0.0, 0.2, value=0.03, step=0.005, label="Trend")
                season_period = gr.Slider(2, 240, value=14, step=1, label="Season period")
                season_amp = gr.Slider(0.0, 12.0, value=3.0, step=0.1, label="Season amplitude")
                noise = gr.Slider(0.0, 6.0, value=0.8, step=0.05, label="Noise")

            prediction_length = gr.Slider(1, 365, value=30, step=1, label="Prediction length")
            requested_samples = gr.Slider(1, 800, value=200, step=25, label="Requested samples (best effort)")
            q_low = gr.Slider(0.01, 0.49, value=0.10, step=0.01, label="Quantile low")
            q_high = gr.Slider(0.51, 0.99, value=0.90, step=0.01, label="Quantile high")
            run_btn = gr.Button("Run", variant="primary")

        with gr.Column(scale=2):
            kpis = gr.HTML()
            with gr.Tabs():
                with gr.Tab("Forecast"):
                    forecast_plot = gr.Plot()
                    forecast_table = gr.Dataframe(interactive=False)
                with gr.Tab("Explanation"):
                    explanation = gr.Markdown()
                with gr.Tab("Export"):
                    download = gr.File()
                with gr.Tab("Info"):
                    info = gr.JSON()

    run_btn.click(
        fn=run_all,
        inputs=[
            input_mode, test_csv_name, upload_csv, csv_column,
            n, seed, trend, season_period, season_amp, noise,
            prediction_length, requested_samples, q_low, q_high,
            device_ui, model_id,
        ],
        outputs=[kpis, explanation, forecast_plot, forecast_table, download, info],
    )

demo.queue()
demo.launch(ssr_mode=False)