File size: 14,312 Bytes
852fd6f 94db2b6 4cad9bd 94db2b6 113176c 852fd6f 4cad9bd 94db2b6 4cad9bd 852fd6f 4cad9bd 94db2b6 4cad9bd 94db2b6 4cad9bd 852fd6f 94db2b6 4cad9bd 94db2b6 4cad9bd 113176c 4cad9bd 94db2b6 4cad9bd 94db2b6 4cad9bd 94db2b6 4cad9bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 |
# space/tools/ts_forecast_tool.py
import logging
import os
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
import torch
from transformers import AutoConfig, AutoModel

from utils.config import AppConfig
from utils.tracing import Tracer
logger = logging.getLogger(__name__)

# --- Validation limits (inclusive) -------------------------------------------
# Accepted number of observations in the input series.
MIN_SERIES_LENGTH = 2
MAX_SERIES_LENGTH = 10_000

# Accepted number of future steps a caller may request.
MIN_HORIZON = 1
MAX_HORIZON = 365

# --- Model --------------------------------------------------------------------
# Identifier handed to ``from_pretrained`` when no model_id is supplied.
DEFAULT_MODEL_ID = "ibm-granite/granite-timeseries-ttm-r1"
class ForecastToolError(Exception):
    """Raised for any failure inside the forecasting tool: invalid input, model loading, or inference."""
class TimeseriesForecastTool:
    """
    Lightweight wrapper around Granite Time Series models for zero-shot forecasting.

    This wrapper:
    - Lazily loads the model with ``AutoModel.from_pretrained`` on first use
    - Validates the input series and horizon
    - Attempts multiple inference methods (``predict``, then ``forward`` with
      ``prediction_length``)
    - Coerces the raw model output to a 1-D array of exactly ``horizon`` values
    - Returns a pandas DataFrame with a ``forecast`` column
    - Logs each step and wraps failures in ``ForecastToolError``

    Expected input:
    - series: numeric pd.Series (a regular-frequency DatetimeIndex is recommended,
      but only the values are passed to the model)
    - horizon: int, number of future steps to forecast
    """

    # Attribute names probed, in order, on the object returned by forward().
    # NOTE(review): "prediction_outputs" is assumed to be the field used by Granite
    # TTM prediction outputs; it is checked last so earlier names keep priority.
    _OUTPUT_ATTRS = ("predictions", "prediction", "logits", "forecast", "output", "prediction_outputs")

    def __init__(
        self,
        cfg: Optional[AppConfig],
        tracer: Optional[Tracer],
        model_id: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
    ):
        """
        Args:
            cfg: Application config; stored as ``self.cfg`` (not read elsewhere in this class).
            tracer: Optional tracer; when set, ``trace_event`` is called for successful
                forecasts and for unexpected failures.
            model_id: Identifier passed to ``AutoConfig``/``AutoModel.from_pretrained``.
            device: Torch device string; defaults to "cuda" if available, else "cpu".
        """
        self.cfg = cfg
        self.tracer = tracer
        self.model_id = model_id
        self.model = None
        self.config = None
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"TimeseriesForecastTool initialized with device: {self.device}")
        # Lazy loading: the model is only fetched on first use (see _ensure_loaded).
        self._initialized = False

    def _ensure_loaded(self) -> None:
        """
        Load the model (and, best-effort, its config) on the first call; no-op afterwards.

        Raises:
            ForecastToolError: If the model cannot be loaded. A config load failure is
                only logged and leaves ``self.config`` as None.
        """
        if self._initialized:
            return
        logger.info(f"Loading Granite time series model: {self.model_id}")

        # SECURITY: trust_remote_code=True executes Python code shipped with the model
        # repository. Only pass model_id values that come from a trusted source.
        # The config is used for reporting only (get_model_info), so failure is tolerated.
        try:
            self.config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
            logger.info(f"Model config loaded: {type(self.config).__name__}")
        except Exception as e:
            logger.warning(f"Could not load model config: {e}")
            self.config = None

        try:
            model = AutoModel.from_pretrained(
                self.model_id,
                trust_remote_code=True,  # Required for some custom models
            )
            model.to(self.device)
            model.eval()
        except Exception as e:
            raise ForecastToolError(
                f"Failed to load model '{self.model_id}': {e}\n"
                "Ensure the model is available and transformers is up to date."
            ) from e

        # Publish the model only once it is on-device and in eval mode, so a failed
        # load never leaves a half-initialised model on the instance.
        self.model = model
        logger.info(f"Model loaded successfully: {type(self.model).__name__}")
        self._initialized = True

    def _validate_series(self, series: pd.Series) -> tuple[bool, str]:
        """
        Validate the input time series.

        Returns:
            ``(True, "")`` when valid, otherwise ``(False, reason)``.
        """
        if not isinstance(series, pd.Series):
            return False, "Input must be a pandas Series"
        if series.empty:
            return False, "Series is empty"
        if len(series) < MIN_SERIES_LENGTH:
            return False, f"Series too short (min {MIN_SERIES_LENGTH} points required)"
        if len(series) > MAX_SERIES_LENGTH:
            return False, f"Series too long (max {MAX_SERIES_LENGTH} points allowed)"

        if series.isnull().any():
            null_count = series.isnull().sum()
            return False, f"Series contains {null_count} null values. Please handle missing data first."

        # The dtype check must run before np.isfinite: isfinite raises TypeError on
        # object/string data rather than returning a mask.
        if not pd.api.types.is_numeric_dtype(series):
            return False, f"Series must be numeric, got dtype: {series.dtype}"

        if not np.isfinite(series).all():
            return False, "Series contains infinite values"

        return True, ""

    def _validate_horizon(self, horizon: int) -> tuple[bool, str]:
        """
        Validate the forecast horizon.

        Accepts anything ``int()`` converts without losing information (ints, numpy
        integers, integral floats, numeric strings) within [MIN_HORIZON, MAX_HORIZON].

        Returns:
            ``(True, "")`` when valid, otherwise ``(False, reason)``.
        """
        try:
            h = int(horizon)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")).
            return False, f"Horizon must be an integer, got: {horizon}"

        # int() truncates (int(3.7) == 3); reject fractional values instead of silently
        # forecasting a different horizon than requested.
        if not isinstance(horizon, (str, bytes)) and h != horizon:
            return False, f"Horizon must be a whole number, got: {horizon}"

        if h < MIN_HORIZON:
            return False, f"Horizon too small (min {MIN_HORIZON})"
        if h > MAX_HORIZON:
            return False, f"Horizon too large (max {MAX_HORIZON})"
        return True, ""

    def _prepare_input_tensor(self, series: pd.Series) -> torch.Tensor:
        """
        Convert the series values to a float32 tensor of shape [1, seq_len] on self.device.

        NOTE(review): Granite TTM models are documented to take ``past_values`` shaped
        [batch, context_length, num_channels]; confirm this 2-D layout suits the model.

        Raises:
            ForecastToolError: If the conversion fails.
        """
        try:
            values = series.astype("float32").to_numpy()
            # torch.tensor copies, so the model never aliases the caller's series buffer.
            tensor = torch.tensor(values, dtype=torch.float32, device=self.device)
            # Add batch dimension: [seq_len] -> [1, seq_len]
            tensor = tensor.unsqueeze(0)
        except Exception as e:
            raise ForecastToolError(f"Failed to prepare input tensor: {e}") from e
        logger.debug(f"Input tensor shape: {tensor.shape}, device: {tensor.device}")
        return tensor

    @staticmethod
    def _coerce_to_horizon(prediction: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Reduce a raw prediction tensor to a 1-D numpy array of exactly ``horizon`` values.

        Longer outputs are truncated; shorter ones are padded by repeating the last value.

        Returns:
            The coerced array, or None if the prediction holds no values.
        """
        # atleast_1d: squeeze() turns a single-value forecast (horizon == 1) into a 0-d
        # array, which has neither len() nor shape[-1].
        output = np.atleast_1d(prediction.squeeze().detach().cpu().numpy())
        if output.size == 0:
            return None

        if output.ndim > 1:
            # Heuristic reduction of multi-dimensional output to a single series.
            # NOTE(review): axis meaning is model-specific; for a [horizon, channels]
            # output, flatten() interleaves channels — confirm this is intended.
            if output.shape[0] == horizon:
                output = output.flatten()
            else:
                output = output[-1] if output.shape[0] < output.shape[1] else output.flatten()
            # output[-1] of a >2-D array is still multi-dimensional.
            output = np.ravel(output)

        if len(output) != horizon:
            logger.warning(
                f"Output length {len(output)} doesn't match horizon {horizon}. Truncating/padding."
            )
            if len(output) > horizon:
                output = output[:horizon]
            else:
                # Pad with last value
                output = np.pad(output, (0, horizon - len(output)), mode="edge")
        return output

    def _try_predict_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Forecast via ``self.model.predict(x, prediction_length=horizon)``.

        Returns:
            A 1-D array of ``horizon`` values, or None if the model has no ``predict``
            method or the call/conversion fails (the failure is logged).
        """
        if not hasattr(self.model, "predict"):
            logger.debug("Model has no 'predict' method")
            return None
        try:
            logger.info("Attempting forecast with .predict() method")
            preds = self.model.predict(x, prediction_length=horizon)
            if not isinstance(preds, torch.Tensor):
                preds = torch.tensor(preds, device=self.device)
            output = self._coerce_to_horizon(preds, horizon)
        except Exception as e:
            logger.warning(f"predict() method failed: {e}")
            return None
        if output is None:
            logger.warning("predict() returned an empty forecast")
            return None
        logger.info(f"Forecast successful via .predict(): {output.shape}")
        return output

    def _extract_prediction_tensor(self, outputs) -> Optional[torch.Tensor]:
        """
        Pull the prediction tensor out of a forward() result.

        Probes ``_OUTPUT_ATTRS`` in order, then falls back to ``outputs`` itself when it
        is a tensor. Returns None when nothing usable is found.
        """
        for attr in self._OUTPUT_ATTRS:
            candidate = getattr(outputs, attr, None)
            if candidate is None:
                # ModelOutput-style objects expose unset fields as None; skip them
                # instead of failing on torch.tensor(None).
                continue
            if isinstance(candidate, (tuple, list)):
                candidate = candidate[0]
            if not isinstance(candidate, torch.Tensor):
                candidate = torch.tensor(candidate, device=self.device)
            logger.debug(f"Found predictions in attribute: {attr}")
            return candidate
        if isinstance(outputs, torch.Tensor):
            logger.debug("Using raw tensor output")
            return outputs
        return None

    def _try_forward_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Forecast via ``self.model(x, prediction_length=horizon)``.

        Returns:
            A 1-D array of ``horizon`` values, or None if the call fails or no
            predictions can be extracted (the failure is logged).
        """
        logger.info("Attempting forecast with forward(prediction_length=...)")
        try:
            outputs = self.model(x, prediction_length=horizon)
        except TypeError as e:
            # Typically the forward() signature has no prediction_length keyword.
            logger.warning(f"forward() doesn't accept prediction_length: {e}")
            return None
        except Exception as e:
            logger.warning(f"forward() method failed: {e}")
            return None

        try:
            prediction_tensor = self._extract_prediction_tensor(outputs)
            if prediction_tensor is None:
                logger.warning("Could not extract predictions from forward() output")
                return None
            output = self._coerce_to_horizon(prediction_tensor, horizon)
        except Exception as e:
            logger.warning(f"forward() method failed: {e}")
            return None

        if output is None:
            logger.warning("forward() returned an empty forecast")
            return None
        logger.info(f"Forecast successful via forward(): {output.shape}")
        return output

    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
        """
        Generate a zero-shot forecast for the input time series.

        Args:
            series: Input time series (pd.Series with numeric values).
            horizon: Number of periods to forecast (default: 96); must be a whole
                number in [MIN_HORIZON, MAX_HORIZON].

        Returns:
            DataFrame with a single 'forecast' column of ``horizon`` rows and a default
            RangeIndex.

        Raises:
            ForecastToolError: If validation, model loading or forecasting fails.
        """
        try:
            is_valid, error_msg = self._validate_series(series)
            if not is_valid:
                raise ForecastToolError(f"Invalid series: {error_msg}")
            is_valid, error_msg = self._validate_horizon(horizon)
            if not is_valid:
                raise ForecastToolError(f"Invalid horizon: {error_msg}")
            # Normalise to a plain int: the model calls and the slicing/padding in
            # _coerce_to_horizon need a real integer, not e.g. 96.0 or "96".
            horizon = int(horizon)

            self._ensure_loaded()

            logger.info(
                f"Forecasting: series_length={len(series)}, "
                f"horizon={horizon}, "
                f"series_mean={series.mean():.2f}, "
                f"series_std={series.std():.2f}"
            )

            x = self._prepare_input_tensor(series)

            # Try prediction methods in order of preference.
            with torch.no_grad():
                output = self._try_predict_method(x, horizon)
                if output is None:
                    output = self._try_forward_method(x, horizon)

            if output is None:
                raise ForecastToolError(
                    "Could not generate forecast using available model methods.\n"
                    "The model may not support zero-shot forecasting with this interface.\n"
                    "Suggestions:\n"
                    "  • Check model documentation for correct usage\n"
                    "  • Ensure transformers library is up to date\n"
                    "  • Try a different model or use traditional forecasting (ARIMA, Prophet)\n"
                    f"  • Model type: {type(self.model).__name__}"
                )

            result_df = pd.DataFrame({"forecast": output})

            logger.info(
                f"Forecast complete: "
                f"mean={output.mean():.2f}, "
                f"std={output.std():.2f}, "
                f"min={output.min():.2f}, "
                f"max={output.max():.2f}"
            )

            if self.tracer:
                self.tracer.trace_event("forecast", {
                    "series_length": len(series),
                    "horizon": horizon,
                    "forecast_mean": float(output.mean()),
                    "forecast_std": float(output.std()),
                })

            return result_df

        except ForecastToolError:
            raise
        except Exception as e:
            error_msg = f"Forecasting failed unexpectedly: {str(e)}"
            # logger.exception keeps the traceback, which the message alone loses.
            logger.exception(error_msg)
            if self.tracer:
                self.tracer.trace_event("forecast_error", {"error": error_msg})
            raise ForecastToolError(error_msg) from e

    def get_model_info(self) -> Dict[str, Any]:
        """
        Return a summary of the model, loading it first if necessary.

        Raises:
            ForecastToolError: If the model cannot be loaded.
        """
        self._ensure_loaded()
        return {
            "model_id": self.model_id,
            "model_type": type(self.model).__name__,
            "device": str(self.device),
            "has_predict": hasattr(self.model, "predict"),
            "config": str(self.config) if self.config is not None else None,
        }