Spaces:
Sleeping
Sleeping
"""Model loading + inference helpers for the HF Spaces app.

Loads the Part 1 CNN-Transformer baseline (1.75 M params, 5.24 % MAPE
on the 2022-12-30/31 self-eval slice) and runs forward on a synthetic
weather tensor + real recent ISO-NE demand history.
"""
| from __future__ import annotations | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from models.cnn_transformer_baseline import CNNTransformerBaselineForecaster # noqa: E402 | |
# ISO-NE load zones, in the column order the model was trained with.
ZONE_COLS = ["ME", "NH", "VT", "CT", "RI", "SEMA", "WCMA", "NEMA_BOST"]
N_ZONES = 8        # len(ZONE_COLS)
CAL_DIM = 44       # calendar-feature dimension per hour
HISTORY_LEN = 24   # hours of demand/weather history fed to the model
FUTURE_LEN = 24    # forecast horizon in hours
# HRRR weather raster dimensions: (height, width, channels).
WEATHER_H, WEATHER_W, WEATHER_C = 450, 449, 7
def load_baseline(ckpt_path, device: str = "cpu"):
    """Build the baseline model from a checkpoint and return (model, norm_stats).

    The checkpoint dict is expected to carry the state dict under
    ``"model"``, training hyper-parameters under ``"args"``, and
    (optionally) ``"norm_stats"``. When the stats are absent we fall back
    to a sibling ``norm_stats.pt`` file next to the checkpoint; if that is
    missing too, a RuntimeError is raised.
    """
    ckpt_path = Path(ckpt_path)
    payload = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Hyper-parameters saved at train time; fall back to module defaults.
    hp = payload.get("args", {})
    net = CNNTransformerBaselineForecaster(
        n_weather_channels=WEATHER_C,
        n_zones=N_ZONES,
        cal_dim=CAL_DIM,
        history_len=hp.get("history_len", HISTORY_LEN),
        embed_dim=hp.get("embed_dim", 128),
        grid_size=hp.get("grid_size", 8),
        n_layers=hp.get("n_layers", 4),
        n_heads=hp.get("n_heads", 4),
        dropout=hp.get("dropout", 0.1),
    )
    net.load_state_dict(payload["model"])
    net = net.to(device).eval()

    stats = payload.get("norm_stats")
    if stats is not None:
        return net, stats

    sibling = ckpt_path.parent / "norm_stats.pt"
    if not sibling.exists():
        raise RuntimeError(
            f"checkpoint {ckpt_path} missing norm_stats and no sibling norm_stats.pt"
        )
    return net, torch.load(sibling, map_location=device, weights_only=False)
def normalize_demand(demand_mwh: np.ndarray, norm_stats: dict) -> np.ndarray:
    """Z-score per-zone demand: (T, 8) MWh -> (T, 8) normalized float32."""
    mu = norm_stats["energy_mean"].cpu().numpy().reshape(-1)
    sigma = norm_stats["energy_std"].cpu().numpy().reshape(-1)
    z = (demand_mwh - mu) / sigma
    return z.astype(np.float32)
def denormalize_demand(z: np.ndarray, norm_stats: dict) -> np.ndarray:
    """Invert normalize_demand: (T, 8) z-scores -> (T, 8) MWh float32."""
    mu = norm_stats["energy_mean"].cpu().numpy().reshape(-1)
    sigma = norm_stats["energy_std"].cpu().numpy().reshape(-1)
    mwh = z * sigma + mu
    return mwh.astype(np.float32)
def normalize_weather(raster: np.ndarray, norm_stats: dict) -> np.ndarray:
    """Z-score a raw HRRR raster using training statistics.

    Args:
        raster: (T, H, W, 7) raw weather values.
        norm_stats: holds per-channel "weather_mean" / "weather_std"
            as (1, 1, 1, 7) tensors.

    Returns:
        (T, H, W, 7) float32 z-scored raster.
    """
    mu = norm_stats["weather_mean"].cpu().numpy().reshape(1, 1, 1, -1)
    sigma = norm_stats["weather_std"].cpu().numpy().reshape(1, 1, 1, -1)
    z = (raster - mu) / sigma
    return z.astype(np.float32)
def synthetic_weather_z(history_len: int = HISTORY_LEN,
                        future_len: int = FUTURE_LEN) -> np.ndarray:
    """Return an all-zeros weather tensor in z-score space.

    Zeros correspond to the training-mean weather after normalization.
    This is the fallback when the live HRRR fetcher fails (no network,
    S3 outage): the model degrades but still produces calibrated output
    from demand + calendar features alone.

    Returns:
        (history_len + future_len, H, W, C) float32 array of zeros.
    """
    total_hours = history_len + future_len
    shape = (total_hours, WEATHER_H, WEATHER_W, WEATHER_C)
    return np.zeros(shape, dtype=np.float32)
def run_forecast(model: torch.nn.Module,
                 hist_demand_mwh: np.ndarray,
                 hist_cal: np.ndarray,
                 future_cal: np.ndarray,
                 norm_stats: dict,
                 hist_weather_raw: np.ndarray,
                 future_weather_raw: np.ndarray,
                 device: str = "cpu") -> np.ndarray:
    """Run the baseline forecast.

    Args:
        model: loaded baseline model, already in eval mode.
        hist_demand_mwh: (24, 8) recent ISO-NE per-zone demand in MWh.
        hist_cal: (24, 44) calendar features for the history window.
        future_cal: (24, 44) calendar features for the next 24 h.
        norm_stats: training normalization statistics (see load_baseline).
        hist_weather_raw: (24, 450, 449, 7) RAW HRRR f00 analyses for the
            history window. Will be z-scored internally.
        future_weather_raw: (24, 450, 449, 7) RAW HRRR f01..f24 forecasts
            (or analyses, if available) for the future window. Will be
            z-scored internally.
        device: torch device for the forward pass.

    Returns:
        (24, 8) forecast in MWh.

    Raises:
        ValueError: if either weather array has the wrong shape.
    """
    if hist_weather_raw.shape != (HISTORY_LEN, WEATHER_H, WEATHER_W, WEATHER_C):
        raise ValueError(
            f"hist_weather_raw shape {hist_weather_raw.shape} != "
            f"({HISTORY_LEN}, {WEATHER_H}, {WEATHER_W}, {WEATHER_C})")
    if future_weather_raw.shape != (FUTURE_LEN, WEATHER_H, WEATHER_W, WEATHER_C):
        raise ValueError(
            f"future_weather_raw shape {future_weather_raw.shape} != "
            f"({FUTURE_LEN}, {WEATHER_H}, {WEATHER_W}, {WEATHER_C})")

    # Normalize inputs and add a leading batch dimension of 1.
    hist_w_z = normalize_weather(hist_weather_raw, norm_stats)
    fut_w_z = normalize_weather(future_weather_raw, norm_stats)
    hist_w = torch.from_numpy(hist_w_z).unsqueeze(0).to(device)
    fut_w = torch.from_numpy(fut_w_z).unsqueeze(0).to(device)
    hist_y_z = normalize_demand(hist_demand_mwh, norm_stats)
    hist_y = torch.from_numpy(hist_y_z).unsqueeze(0).to(device)
    hist_c = torch.from_numpy(hist_cal.astype(np.float32)).unsqueeze(0).to(device)
    fut_c = torch.from_numpy(future_cal.astype(np.float32)).unsqueeze(0).to(device)

    # Inference only: no_grad skips building the autograd graph (lower
    # memory) and guarantees .numpy() below succeeds — calling .numpy()
    # on a tensor that requires grad raises a RuntimeError, which would
    # happen here because model parameters require grad even in eval mode.
    with torch.no_grad():
        pred_z = model(hist_w, hist_y, hist_c, fut_w, fut_c)  # (1, 24, 8) z-space

    pred_mwh = denormalize_demand(pred_z.squeeze(0).cpu().numpy(), norm_stats)
    return pred_mwh
# =====================================================================
# Foundation-model ensemble (Chronos-Bolt-mini, zero-shot)
# =====================================================================
#
# Per Table 10 of the report, chronos-bolt-mini (21 M params) gives the
# best per-zone ensemble (4.21 % test MAPE) on the 2-day 2022 self-eval
# slice — actually slightly better than chronos-bolt-base (205 M, 4.33 %).
# Smaller weights => faster cold start + lower memory on the HF Spaces
# free tier (16 GB RAM, 2 vCPU). We hard-code the per-zone alpha that the
# offline grid search returned for the mini variant:
#
#   alpha[z] = weight on the BASELINE prediction for zone z;
#   (1 - alpha[z]) goes to the Chronos zero-shot prediction.
#
# Higher alpha => baseline dominates (good for small, weather-driven zones
# like ME / NH / VT). alpha = 0 => baseline is dropped entirely (good for
# the dense urban-coastal zones CT / SEMA / NEMA_BOST that Chronos models
# better from demand history alone).
CHRONOS_MODEL_CARD = "amazon/chronos-bolt-mini"
CHRONOS_CONTEXT = 672   # 4 weeks of hourly history per zone
CHRONOS_QUANTILE = 0.5  # use median for the point forecast
# Offline grid-search alphas for the mini variant; keys must match ZONE_COLS.
ALPHA_PER_ZONE_MINI: dict[str, float] = {
    "ME": 0.30,
    "NH": 0.30,
    "VT": 0.80,
    "CT": 0.00,
    "RI": 0.10,
    "SEMA": 0.00,
    "WCMA": 0.05,
    "NEMA_BOST": 0.00,
}
def load_chronos(model_card: str = CHRONOS_MODEL_CARD, device: str = "cpu"):
    """Build a Chronos-Bolt pipeline for zero-shot forecasting.

    The chronos import is deferred to call time so the baseline-only
    code path works without chronos-forecasting installed.
    """
    from chronos import BaseChronosPipeline  # noqa: WPS433

    return BaseChronosPipeline.from_pretrained(
        model_card,
        device_map=device,
        torch_dtype=torch.float32,
    )
def run_chronos_zeroshot(pipeline,
                         hist_demand_mwh_long: np.ndarray) -> np.ndarray:
    """Zero-shot 24-h Chronos-Bolt forecast for each of the 8 zones
    independently.

    Args:
        pipeline: a Chronos pipeline exposing ``predict_quantiles``.
        hist_demand_mwh_long: (T, 8) per-zone demand history in MWh, with
            T >= CHRONOS_CONTEXT. Only the last CHRONOS_CONTEXT rows are
            used; if shorter, we pad by repeating the earliest available
            sample (same fallback the baseline uses when the live API is
            short).

    Returns:
        (24, 8) zero-shot median forecast in MWh.
    """
    n_rows, _n_zones = hist_demand_mwh_long.shape
    deficit = CHRONOS_CONTEXT - n_rows
    if deficit > 0:
        # Front-pad with copies of the oldest row so every zone has a
        # full context window.
        front = np.repeat(hist_demand_mwh_long[:1], deficit, axis=0)
        hist_demand_mwh_long = np.concatenate([front, hist_demand_mwh_long], axis=0)

    window = hist_demand_mwh_long[-CHRONOS_CONTEXT:]        # (672, 8)
    context = torch.from_numpy(window.T).to(torch.float32)  # (8, 672)

    quantiles, _mean = pipeline.predict_quantiles(
        context=context,
        prediction_length=FUTURE_LEN,
        quantile_levels=[CHRONOS_QUANTILE],
    )
    # quantiles: (8 zones, 24 hours, 1 quantile) -> (24, 8)
    return quantiles[:, :, 0].cpu().numpy().T.astype(np.float32)
def per_zone_ensemble(baseline_mwh: np.ndarray,
                      chronos_mwh: np.ndarray,
                      alpha_per_zone: dict[str, float] = ALPHA_PER_ZONE_MINI) -> np.ndarray:
    """Late-fusion ensemble of the two forecasts:

        y_ens[h, z] = alpha[z] * y_baseline[h, z] + (1 - alpha[z]) * y_chronos[h, z]
    """
    # Build the (1, 8) weight row in ZONE_COLS order so it broadcasts
    # across the 24 forecast hours.
    weights = np.asarray([alpha_per_zone[zone] for zone in ZONE_COLS],
                         dtype=np.float32)[None, :]
    return weights * baseline_mwh + (1 - weights) * chronos_mwh