# space/tools/ts_forecast_tool.py
import os
import logging
from typing import Any, Dict, Optional
import torch
import pandas as pd
import numpy as np
from utils.tracing import Tracer
from utils.config import AppConfig
from transformers import AutoModel, AutoConfig
logger = logging.getLogger(__name__)
# Constants
MIN_SERIES_LENGTH = 2
MAX_SERIES_LENGTH = 10000
MIN_HORIZON = 1
MAX_HORIZON = 365
DEFAULT_MODEL_ID = "ibm-granite/granite-timeseries-ttm-r1"
class ForecastToolError(Exception):
"""Custom exception for forecast tool errors."""
pass
class TimeseriesForecastTool:
"""
Lightweight wrapper around Granite Time Series models for zero-shot forecasting.
This wrapper:
- Loads the model with AutoModel.from_pretrained
- Validates input series and horizon
- Attempts multiple inference methods (predict, forward with prediction_length)
    - Returns a pandas DataFrame with a single 'forecast' column
- Provides comprehensive error handling and logging
Expected input:
- series: pd.Series with DatetimeIndex (regular frequency recommended)
- horizon: int, number of future steps to forecast
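    Example (illustrative sketch; assumes a regular daily series and that
    the default model can be downloaded):
        >>> idx = pd.date_range("2024-01-01", periods=512, freq="D")
        >>> series = pd.Series(np.arange(512, dtype=float), index=idx)
        >>> tool = TimeseriesForecastTool(cfg=None, tracer=None)
        >>> forecast_df = tool.zeroshot_forecast(series, horizon=96)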
"""
def __init__(
self,
cfg: Optional[AppConfig],
tracer: Optional[Tracer],
model_id: str = DEFAULT_MODEL_ID,
device: Optional[str] = None,
):
self.cfg = cfg
self.tracer = tracer
self.model_id = model_id
self.model = None
self.config = None
# Determine device
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"TimeseriesForecastTool initialized with device: {self.device}")
# Lazy loading - model loaded on first use
self._initialized = False
def _ensure_loaded(self):
"""Lazy load the model and configuration."""
if self._initialized:
return
try:
logger.info(f"Loading Granite time series model: {self.model_id}")
# Load configuration
try:
self.config = AutoConfig.from_pretrained(self.model_id)
logger.info(f"Model config loaded: {type(self.config).__name__}")
except Exception as e:
logger.warning(f"Could not load model config: {e}")
self.config = None
# Load model
try:
self.model = AutoModel.from_pretrained(
self.model_id,
trust_remote_code=True # Required for some custom models
)
self.model.to(self.device)
self.model.eval()
logger.info(f"Model loaded successfully: {type(self.model).__name__}")
except Exception as e:
raise ForecastToolError(
f"Failed to load model '{self.model_id}': {e}\n"
"Ensure the model is available and transformers is up to date."
) from e
self._initialized = True
except ForecastToolError:
raise
except Exception as e:
raise ForecastToolError(f"Model initialization failed: {e}") from e
def _validate_series(self, series: pd.Series) -> tuple[bool, str]:
"""
Validate input time series.
Returns (is_valid, error_message).
"""
if not isinstance(series, pd.Series):
return False, "Input must be a pandas Series"
if series.empty:
return False, "Series is empty"
if len(series) < MIN_SERIES_LENGTH:
return False, f"Series too short (min {MIN_SERIES_LENGTH} points required)"
if len(series) > MAX_SERIES_LENGTH:
return False, f"Series too long (max {MAX_SERIES_LENGTH} points allowed)"
# Check for nulls
if series.isnull().any():
null_count = series.isnull().sum()
return False, f"Series contains {null_count} null values. Please handle missing data first."
        # Reject non-numeric data before the finiteness check, since
        # np.isfinite raises a TypeError on object dtypes
        if not pd.api.types.is_numeric_dtype(series):
            return False, f"Series must be numeric, got dtype: {series.dtype}"
        # Check for infinite values
        if not np.isfinite(series.to_numpy()).all():
            return False, "Series contains infinite values"
return True, ""
def _validate_horizon(self, horizon: int) -> tuple[bool, str]:
"""
Validate forecast horizon.
Returns (is_valid, error_message).
"""
try:
h = int(horizon)
except (TypeError, ValueError):
return False, f"Horizon must be an integer, got: {horizon}"
if h < MIN_HORIZON:
return False, f"Horizon too small (min {MIN_HORIZON})"
if h > MAX_HORIZON:
return False, f"Horizon too large (max {MAX_HORIZON})"
return True, ""
def _prepare_input_tensor(self, series: pd.Series) -> torch.Tensor:
"""
Convert pandas Series to PyTorch tensor.
Handles type conversion and device placement.
"""
try:
# Convert to float32 numpy array
values = series.astype("float32").to_numpy()
# Create tensor and move to device
tensor = torch.tensor(values, dtype=torch.float32, device=self.device)
# Add batch dimension [1, seq_len]
tensor = tensor.unsqueeze(0)
logger.debug(f"Input tensor shape: {tensor.shape}, device: {tensor.device}")
return tensor
except Exception as e:
raise ForecastToolError(f"Failed to prepare input tensor: {e}") from e
def _try_predict_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
"""
Try using the model's .predict() method.
Returns None if method doesn't exist or fails.
"""
if not hasattr(self.model, "predict"):
logger.debug("Model has no 'predict' method")
return None
try:
logger.info("Attempting forecast with .predict() method")
preds = self.model.predict(x, prediction_length=horizon)
# Convert to tensor if needed
if not isinstance(preds, torch.Tensor):
preds = torch.tensor(preds, device=self.device)
            # Extract a 1-D numpy array (np.atleast_1d guards against the
            # 0-d result that squeeze() produces when horizon == 1)
            output = np.atleast_1d(preds.squeeze().detach().cpu().numpy())
# Validate output shape
if output.shape[-1] != horizon:
logger.warning(
f"Prediction length mismatch: expected {horizon}, got {output.shape[-1]}"
)
logger.info(f"Forecast successful via .predict(): {output.shape}")
return output
except Exception as e:
logger.warning(f"predict() method failed: {e}")
return None
def _try_forward_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
"""
Try using the model's forward() method with prediction_length parameter.
Returns None if method fails.
"""
try:
logger.info("Attempting forecast with forward(prediction_length=...)")
outputs = self.model(x, prediction_length=horizon)
# Try to extract predictions from various possible output formats
prediction_tensor = None
# Check common attribute names
for attr in ("predictions", "prediction", "logits", "forecast", "output"):
if hasattr(outputs, attr):
candidate = getattr(outputs, attr)
# Handle tuple/list outputs
if isinstance(candidate, (tuple, list)):
candidate = candidate[0]
# Convert to tensor if needed
if not isinstance(candidate, torch.Tensor):
candidate = torch.tensor(candidate, device=self.device)
prediction_tensor = candidate
logger.debug(f"Found predictions in attribute: {attr}")
break
# If outputs is directly a tensor
if prediction_tensor is None and isinstance(outputs, torch.Tensor):
prediction_tensor = outputs
logger.debug("Using raw tensor output")
if prediction_tensor is None:
logger.warning("Could not extract predictions from forward() output")
return None
            # Convert to a 1-D-safe numpy array (np.atleast_1d guards against
            # the 0-d result that squeeze() produces when horizon == 1)
            output = np.atleast_1d(prediction_tensor.squeeze().detach().cpu().numpy())
            # Handle multi-dimensional outputs
            if output.ndim > 1:
                if output.shape[0] == horizon:
                    # First axis matches the horizon: treat it as the time axis
                    output = output.flatten()
                else:
                    # Heuristic: keep the last row when rows look like batch
                    # entries (fewer rows than columns), otherwise flatten
                    output = output[-1] if output.shape[0] < output.shape[1] else output.flatten()
# Ensure correct length
if len(output) != horizon:
logger.warning(
f"Output length {len(output)} doesn't match horizon {horizon}. Truncating/padding."
)
if len(output) > horizon:
output = output[:horizon]
else:
# Pad with last value
output = np.pad(output, (0, horizon - len(output)), mode='edge')
logger.info(f"Forecast successful via forward(): {output.shape}")
return output
except TypeError as e:
logger.warning(f"forward() doesn't accept prediction_length: {e}")
return None
except Exception as e:
logger.warning(f"forward() method failed: {e}")
return None
def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
"""
Generate zero-shot forecast for input time series.
Args:
series: Input time series (pd.Series with numeric values)
horizon: Number of periods to forecast (default: 96)
Returns:
DataFrame with 'forecast' column containing predictions
Raises:
ForecastToolError: If forecasting fails
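        Example (illustrative; assumes `tool` and `series` were built as in
        the class-level docstring):
            >>> forecast_df = tool.zeroshot_forecast(series, horizon=30)
            >>> forecast_df["forecast"].head()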
"""
try:
# Validate inputs
is_valid, error_msg = self._validate_series(series)
if not is_valid:
raise ForecastToolError(f"Invalid series: {error_msg}")
is_valid, error_msg = self._validate_horizon(horizon)
if not is_valid:
raise ForecastToolError(f"Invalid horizon: {error_msg}")
# Ensure model is loaded
self._ensure_loaded()
# Log input statistics
logger.info(
f"Forecasting: series_length={len(series)}, "
f"horizon={horizon}, "
f"series_mean={series.mean():.2f}, "
f"series_std={series.std():.2f}"
)
# Prepare input tensor
x = self._prepare_input_tensor(series)
# Try prediction methods in order of preference
output = None
with torch.no_grad():
# Method 1: Try .predict()
output = self._try_predict_method(x, horizon)
# Method 2: Try forward with prediction_length
if output is None:
output = self._try_forward_method(x, horizon)
# If all methods failed
if output is None:
raise ForecastToolError(
"Could not generate forecast using available model methods.\n"
"The model may not support zero-shot forecasting with this interface.\n"
"Suggestions:\n"
" • Check model documentation for correct usage\n"
" • Ensure transformers library is up to date\n"
" • Try a different model or use traditional forecasting (ARIMA, Prophet)\n"
f" • Model type: {type(self.model).__name__}"
)
            # Create output DataFrame (flatten defensively so a multi-dim
            # predict() output still yields a 1-D forecast column)
            result_df = pd.DataFrame({"forecast": np.asarray(output).ravel()})
# Log output statistics
logger.info(
f"Forecast complete: "
f"mean={output.mean():.2f}, "
f"std={output.std():.2f}, "
f"min={output.min():.2f}, "
f"max={output.max():.2f}"
)
# Trace event
if self.tracer:
self.tracer.trace_event("forecast", {
"series_length": len(series),
"horizon": horizon,
"forecast_mean": float(output.mean()),
"forecast_std": float(output.std())
})
return result_df
except ForecastToolError:
raise
except Exception as e:
            error_msg = f"Forecasting failed unexpectedly: {e}"
            logger.error(error_msg, exc_info=True)
if self.tracer:
self.tracer.trace_event("forecast_error", {"error": error_msg})
raise ForecastToolError(error_msg) from e
    def get_model_info(self) -> Dict[str, Any]:
"""Get information about the loaded model."""
self._ensure_loaded()
return {
"model_id": self.model_id,
"model_type": type(self.model).__name__,
"device": str(self.device),
"has_predict": hasattr(self.model, "predict"),
"config": str(self.config) if self.config else None
}
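

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the tool's API).
# Running this module directly forecasts a synthetic random-walk series.
# It assumes network access to download DEFAULT_MODEL_ID; cfg and tracer
# are optional, so None is passed for both.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    idx = pd.date_range("2024-01-01", periods=512, freq="D")
    demo_series = pd.Series(
        np.random.default_rng(0).standard_normal(512).cumsum(), index=idx
    )
    tool = TimeseriesForecastTool(cfg=None, tracer=None)
    try:
        forecast_df = tool.zeroshot_forecast(demo_series, horizon=96)
        print(forecast_df.head())
        print(tool.get_model_info())
    except ForecastToolError as exc:
        print(f"Forecast failed: {exc}")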