# weather_predict / model_utils.py
# Commit 2507ab5 (jeffliulab): Add WeatherViT model (7.4M params, best rain AUC 0.776)
"""
Model loading and inference utilities for the weather forecast demo.
Wraps the existing inference/predict.py logic, adding user-friendly
post-processing (Celsius, wind speed/direction, rain likelihood).
"""
import math
import sys
from pathlib import Path
import numpy as np
import torch
# In HF Space, models/ is in the same directory as this file.
# PROJECT_ROOT is this file's directory: it is prepended to sys.path so that
# `from models import ...` below resolves, and load_model() also uses it as the
# base for the relative checkpoint paths in AVAILABLE_MODELS.
PROJECT_ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))
from models import create_model, get_model_defaults
# ── Model cache (loaded once, reused across requests) ──────────────────
# Populated by load_model(); each cached value is a (model, norm_stats) tuple.
_model_cache: dict = {}
# The model's prediction targets as (variable key "FIELD@level", display label,
# unit of the raw prediction). format_forecast() reads the model's 6 outputs
# by index in this same order: temperature, humidity, U wind, V wind, gust,
# precipitation.
TARGET_VARS = [
    ("TMP@2m_above_ground", "Temperature (2m)", "K"),
    ("RH@2m_above_ground", "Relative Humidity", "%"),
    ("UGRD@10m_above_ground", "U-Wind (10m)", "m/s"),
    ("VGRD@10m_above_ground", "V-Wind (10m)", "m/s"),
    ("GUST@surface", "Wind Gust", "m/s"),
    ("APCP_1hr_acc_fcst@surface", "Precipitation (1hr)", "mm"),
]
# Available models with display info.
# Keys are the model names accepted by load_model()/run_forecast().
# "checkpoint" is resolved relative to PROJECT_ROOT by load_model().
# "display_name" and "params" are not read in this module; presumably they are
# shown in the UI (TODO confirm against the app code).
AVAILABLE_MODELS = {
    "cnn_baseline": {
        "display_name": "CNN Baseline",
        "checkpoint": "checkpoints/cnn_baseline.pt",
        "params": "11.3M",
    },
    "resnet18": {
        "display_name": "ResNet-18",
        "checkpoint": "checkpoints/resnet18.pt",
        "params": "11.2M",
    },
    "vit": {
        "display_name": "WeatherViT",
        "checkpoint": "checkpoints/vit.pt",
        "params": "7.4M",
    },
}
def load_model(model_name: str, device: str = "cpu"):
    """
    Load a trained model from its checkpoint, caching it in memory for reuse.

    The cache is keyed on both the model name and the device. Asking for the
    same model on a different device loads a fresh copy there, instead of
    returning a model whose weights (and norm stats) live on another device.

    Args:
        model_name: Key into AVAILABLE_MODELS (e.g. "cnn_baseline", "vit").
        device: Torch device string; the checkpoint tensors are mapped to it
            and the model is moved to it.

    Returns:
        (model, norm_stats) tuple. ``model`` is in eval mode on ``device``.
        ``norm_stats`` is the checkpoint's "norm_stats" entry, or None if the
        checkpoint has none.

    Raises:
        KeyError: ``model_name`` is not a key of AVAILABLE_MODELS.
        FileNotFoundError: the checkpoint file does not exist.
    """
    # Include the device in the key. A model cached on one device must not be
    # returned to a caller who asked for another, or inference fails with a
    # device-mismatch error.
    cache_key = (model_name, str(device))
    if cache_key in _model_cache:
        return _model_cache[cache_key]

    if model_name not in AVAILABLE_MODELS:
        raise KeyError(
            f"Unknown model {model_name!r}; available: {sorted(AVAILABLE_MODELS)}"
        )
    ckpt_path = PROJECT_ROOT / AVAILABLE_MODELS[model_name]["checkpoint"]
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. It is presumably needed for the "args"/"norm_stats"
    # entries. This is acceptable only because the checkpoints ship with the
    # app; never point this at user-supplied files.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Build the architecture named in the checkpoint's own training args
    # (not the AVAILABLE_MODELS key).
    args = ckpt["args"]
    ckpt_model_name = args["model"]
    defaults = get_model_defaults(ckpt_model_name)
    # A missing (or falsy) n_frames in the saved args falls back to the
    # architecture default.
    n_frames = args.get("n_frames") or defaults["n_frames"]

    model_kwargs = {
        "n_input_channels": 42,  # matches the (H, W, 42) input predict_raw expects
        "n_targets": 6,  # one output per entry in TARGET_VARS
        "base_channels": args.get("base_channels", 64),
    }
    # n_frames is only passed to the constructor when it is greater than 1.
    if n_frames > 1:
        model_kwargs["n_frames"] = n_frames

    model = create_model(ckpt_model_name, **model_kwargs)
    model.load_state_dict(ckpt["model"])
    model.to(device).eval()

    norm_stats = ckpt.get("norm_stats")
    _model_cache[cache_key] = (model, norm_stats)
    return model, norm_stats
def predict_raw(model, norm_stats, input_array: np.ndarray, device: str = "cpu") -> np.ndarray:
    """
    Run inference on a single (H, W, C) input grid, e.g. (450, 449, 42).

    Args:
        model: Network returned by load_model(); expected to already be on
            ``device``.
        norm_stats: Mapping with "input_mean", "input_std", "target_mean" and
            "target_std" entries, or None/empty to skip (de)normalization.
            Entries may be tensors, arrays or scalars. They must broadcast
            against the (1, C, H, W) model input and the squeezed model output,
            respectively. They may live on any device.
        input_array: Array of shape (H, W, C); converted to float32.
        device: Torch device the forward pass runs on.

    Returns:
        np.ndarray of shape (6,) with denormalized physical values.

    Raises:
        ValueError: ``input_array`` is not 3-dimensional.
    """
    # ascontiguousarray also handles negative-stride views (e.g. arr[::-1]),
    # which torch.from_numpy rejects.
    arr = np.ascontiguousarray(input_array)
    if arr.ndim != 3:
        raise ValueError(f"Expected an (H, W, C) array, got shape {arr.shape}")

    # (H, W, C) -> (1, C, H, W). Move to the target device *before*
    # normalizing, because load_model maps norm_stats onto that device.
    x = torch.from_numpy(arr).float().permute(2, 0, 1).unsqueeze(0).to(device)

    if norm_stats:
        mean = torch.as_tensor(norm_stats["input_mean"], dtype=torch.float32, device=x.device)
        std = torch.as_tensor(norm_stats["input_std"], dtype=torch.float32, device=x.device)
        x = (x - mean) / (std + 1e-7)  # epsilon guards against zero std

    with torch.no_grad():
        pred = model(x).squeeze(0).cpu()  # (6,)

    if norm_stats:
        # The target stats may still be on the accelerator; bring them to the
        # CPU alongside the prediction.
        target_mean = torch.as_tensor(
            norm_stats["target_mean"], dtype=torch.float32, device=pred.device
        )
        target_std = torch.as_tensor(
            norm_stats["target_std"], dtype=torch.float32, device=pred.device
        )
        pred = pred * target_std + target_mean

    return pred.numpy()
def _wind_direction_str(degrees: float) -> str:
    """
    Convert a bearing in degrees to a 16-point compass label.

    Each label covers a 22.5° sector centered on its bearing (N = 0°,
    NNE = 22.5°, ..., E = 90°). A bearing exactly on a sector boundary is
    assigned to the clockwise neighbor. Any real value is accepted and
    wrapped modulo 360°.
    """
    dirs = ["N", "NNE", "NE", "ENE", "E", "ESE", "SE", "SSE",
            "S", "SSW", "SW", "WSW", "W", "WNW", "NW", "NNW"]
    # floor(x + 0.5) is consistent round-half-up. The built-in round() uses
    # banker's rounding, which sends some boundary bearings clockwise and
    # others counter-clockwise (11.25° -> "N" but 33.75° -> "NE").
    idx = math.floor(degrees / 22.5 + 0.5) % 16
    return dirs[idx]
def format_forecast(pred: np.ndarray) -> dict:
    """
    Turn the model's six raw physical predictions into a display-ready dict.

    Args:
        pred: Indexable of at least 6 values, in the same order as TARGET_VARS:
            temperature (K), relative humidity (%), U wind (m/s),
            V wind (m/s), gust (m/s), 1-hr precipitation (mm).

    Returns:
        Dict with temperature in K/°C/°F, humidity clamped to [0, 100],
        the raw U/V wind components, wind speed, and wind bearing in degrees
        plus a compass label. It also has gust clamped at 0, precipitation
        clamped at 0, and a coarse rain-status string.
    """
    temp_k, rh, u_wind, v_wind, gust, apcp = [float(pred[i]) for i in range(6)]

    temp_c = temp_k - 273.15
    temp_f = temp_c * 9 / 5 + 32

    wind_speed = math.sqrt(u_wind**2 + v_wind**2)
    # Meteorological convention: the bearing the wind blows FROM, measured
    # clockwise from north, normalized into [0, 360).
    bearing = (math.degrees(math.atan2(-u_wind, -v_wind)) + 360) % 360

    # The regressor can output negative precipitation; treat that as none.
    precip = max(apcp, 0.0)
    rain_levels = (
        (5.0, "Heavy Rain Likely"),
        (2.0, "Rain Likely"),
        (0.5, "Light Rain Possible"),
    )
    rain_status = next(
        (label for limit, label in rain_levels if precip > limit),
        "No Rain Expected",
    )

    return {
        "temperature_k": temp_k,
        "temperature_c": temp_c,
        "temperature_f": temp_f,
        "humidity_pct": max(0.0, min(100.0, rh)),
        "u_wind_ms": u_wind,
        "v_wind_ms": v_wind,
        "wind_speed_ms": wind_speed,
        "wind_dir_deg": bearing,
        "wind_dir_str": _wind_direction_str(bearing),
        "gust_ms": max(gust, 0.0),
        "precipitation_mm": precip,
        "rain_status": rain_status,
    }
def run_forecast(model_name: str, input_array: np.ndarray, device: str = "cpu") -> dict:
    """
    End-to-end forecast: load (or reuse) the model → predict → format results.

    Args:
        model_name: Key into AVAILABLE_MODELS.
        input_array: Input grid passed unchanged to predict_raw().
        device: Torch device used for loading and inference.

    Returns:
        The forecast dict produced by format_forecast().
    """
    net, stats = load_model(model_name, device)
    return format_forecast(predict_raw(net, stats, input_array, device))