|
|
|
|
|
import logging
import os
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
import torch
from transformers import AutoConfig, AutoModel

from utils.config import AppConfig
from utils.tracing import Tracer
|
|
|
|
|
# Module-level logger; records are emitted under this module's import name.
logger = logging.getLogger(__name__)

# Accepted input series length, in observations (both bounds inclusive).
MIN_SERIES_LENGTH: int = 2
MAX_SERIES_LENGTH: int = 10_000

# Accepted forecast horizon, in future steps (both bounds inclusive).
MIN_HORIZON: int = 1
MAX_HORIZON: int = 365

# Model identifier used when the caller does not pass one explicitly.
DEFAULT_MODEL_ID: str = "ibm-granite/granite-timeseries-ttm-r1"
|
|
|
|
|
|
|
|
class ForecastToolError(Exception):
    """Error raised when the forecast tool cannot validate input, load its model, or produce a forecast."""
|
|
|
|
|
|
|
|
class TimeseriesForecastTool:
    """
    Lightweight wrapper around Granite Time Series models for zero-shot forecasting.

    This wrapper:
    - Lazily loads the model with AutoModel.from_pretrained on first use
    - Validates the input series and forecast horizon
    - Attempts multiple inference methods (predict, then forward with prediction_length)
    - Normalizes the raw model output to exactly ``horizon`` values
    - Returns a pandas DataFrame with a single "forecast" column
    - Reports failures as ForecastToolError and logs each step

    Expected input:
    - series: numeric pd.Series, typically with a DatetimeIndex (regular frequency
      recommended); only the values are sent to the model, the index is not used
    - horizon: int, number of future steps to forecast
    """

    # Attributes probed, in priority order, on the object returned by forward().
    # "prediction_outputs" is checked last so the earlier names keep precedence;
    # it is the field name some transformers time-series prediction heads use
    # (NOTE(review): confirm for the configured model).
    _PREDICTION_ATTRS = (
        "predictions",
        "prediction",
        "logits",
        "forecast",
        "output",
        "prediction_outputs",
    )

    def __init__(
        self,
        cfg: Optional[AppConfig],
        tracer: Optional[Tracer],
        model_id: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
    ):
        """
        Args:
            cfg: Application config; stored on the instance but not read by this class.
            tracer: Optional tracer; when set, forecast and forecast_error events are sent to it.
            model_id: Identifier passed to AutoConfig/AutoModel.from_pretrained.
            device: Torch device string; defaults to "cuda" when available, else "cpu".
        """
        self.cfg = cfg
        self.tracer = tracer
        self.model_id = model_id
        self.model = None
        self.config = None

        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"TimeseriesForecastTool initialized with device: {self.device}")

        # The model is loaded lazily by _ensure_loaded() on first use.
        self._initialized = False

    def _ensure_loaded(self):
        """
        Lazily load the model configuration and model (no-op once loaded).

        The config is only used for logging and get_model_info(), so a failure to
        load it is logged and ignored. A model-loading failure raises and leaves
        the tool unloaded, so the next call retries.

        Raises:
            ForecastToolError: If the model cannot be loaded or moved to the device.
        """
        if self._initialized:
            return

        logger.info(f"Loading Granite time series model: {self.model_id}")

        try:
            self.config = AutoConfig.from_pretrained(self.model_id)
            logger.info(f"Model config loaded: {type(self.config).__name__}")
        except Exception as e:
            logger.warning(f"Could not load model config: {e}")
            self.config = None

        try:
            # NOTE: trust_remote_code=True runs code shipped with the model
            # repository; only point model_id at trusted sources.
            model = AutoModel.from_pretrained(
                self.model_id,
                trust_remote_code=True
            )
            model.to(self.device)
            model.eval()
        except Exception as e:
            raise ForecastToolError(
                f"Failed to load model '{self.model_id}': {e}\n"
                "Ensure the model is available and transformers is up to date."
            ) from e

        # Publish the model only once it is on the device and in eval mode, so a
        # failed load never leaves a half-initialized model on the instance.
        self.model = model
        logger.info(f"Model loaded successfully: {type(self.model).__name__}")
        self._initialized = True

    def _validate_series(self, series: pd.Series) -> tuple[bool, str]:
        """
        Validate the input time series.

        Checks, in order: type, non-emptiness, length bounds, missing values,
        numeric dtype, and finiteness.

        Returns:
            (is_valid, error_message); error_message is "" when valid.
        """
        if not isinstance(series, pd.Series):
            return False, "Input must be a pandas Series"

        if series.empty:
            return False, "Series is empty"

        if len(series) < MIN_SERIES_LENGTH:
            return False, f"Series too short (min {MIN_SERIES_LENGTH} points required)"

        if len(series) > MAX_SERIES_LENGTH:
            return False, f"Series too long (max {MAX_SERIES_LENGTH} points allowed)"

        if series.isnull().any():
            null_count = series.isnull().sum()
            return False, f"Series contains {null_count} null values. Please handle missing data first."

        # The dtype check must run before np.isfinite: isfinite raises TypeError on
        # object/string data, which would otherwise surface as an "unexpected"
        # forecasting failure instead of this validation message.
        if not pd.api.types.is_numeric_dtype(series):
            return False, f"Series must be numeric, got dtype: {series.dtype}"

        if not np.isfinite(series).all():
            return False, "Series contains infinite values"

        return True, ""

    def _validate_horizon(self, horizon: int) -> tuple[bool, str]:
        """
        Validate the forecast horizon.

        Accepts values int() converts without losing information (ints, numpy
        integers, integral floats such as 24.0, numeric strings such as "24").
        Fractional or non-finite floats are rejected rather than silently truncated.

        Returns:
            (is_valid, error_message); error_message is "" when valid.
        """
        if isinstance(horizon, (float, np.floating)) and not float(horizon).is_integer():
            return False, f"Horizon must be an integer, got: {horizon}"

        try:
            h = int(horizon)
        except (TypeError, ValueError, OverflowError):
            return False, f"Horizon must be an integer, got: {horizon}"

        if h < MIN_HORIZON:
            return False, f"Horizon too small (min {MIN_HORIZON})"

        if h > MAX_HORIZON:
            return False, f"Horizon too large (max {MAX_HORIZON})"

        return True, ""

    def _prepare_input_tensor(self, series: pd.Series) -> torch.Tensor:
        """
        Convert the series values to a float32 tensor of shape (1, len(series)) on self.device.

        NOTE(review): some time-series models expect (batch, length, channels);
        confirm this (batch, length) layout matches the configured model.

        Raises:
            ForecastToolError: If the conversion fails.
        """
        try:
            values = series.astype("float32").to_numpy()
            tensor = torch.tensor(values, dtype=torch.float32, device=self.device)
            # Add a leading batch dimension of size 1.
            tensor = tensor.unsqueeze(0)
            logger.debug(f"Input tensor shape: {tensor.shape}, device: {tensor.device}")
            return tensor
        except Exception as e:
            raise ForecastToolError(f"Failed to prepare input tensor: {e}") from e

    @staticmethod
    def _fit_to_horizon(raw: np.ndarray, horizon: int) -> Optional[np.ndarray]:
        """
        Reduce a raw prediction array to a 1-D array of exactly ``horizon`` values.

        Size-1 dimensions are dropped, a remaining multi-dimensional array is
        reduced with the heuristic below, and the result is truncated or padded
        (repeating its last value) to length ``horizon``.

        Returns:
            The 1-D array, or None if the prediction is empty (nothing to pad from).
        """
        # squeeze() turns a (1, 1) prediction (horizon == 1) into a 0-d array;
        # atleast_1d restores a length-1 vector so len() and slicing work.
        output = np.atleast_1d(np.squeeze(raw))

        if output.size == 0:
            return None

        if output.ndim > 1:
            # NOTE(review): heuristic for multi-dimensional output. If the leading
            # axis matches the horizon the array is flattened row-major; otherwise
            # the last row is taken when rows < columns. Which axis holds time vs.
            # channels/samples depends on the model — confirm for the target model.
            if output.shape[0] == horizon:
                output = output.flatten()
            elif output.shape[0] < output.shape[1]:
                output = output[-1]
            else:
                output = output.flatten()
            # Any remaining higher-rank result is flattened so the forecast column is 1-D.
            output = output.reshape(-1)

        if len(output) != horizon:
            logger.warning(
                f"Output length {len(output)} doesn't match horizon {horizon}. Truncating/padding."
            )
            if len(output) > horizon:
                output = output[:horizon]
            else:
                # Repeat the last predicted value to fill the missing steps.
                output = np.pad(output, (0, horizon - len(output)), mode='edge')

        return output

    def _try_predict_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Try forecasting via the model's ``predict(x, prediction_length=...)`` method.

        Returns:
            1-D array of ``horizon`` values, or None if the model has no ``predict``
            method, the call fails, or it returns an empty prediction.
        """
        if not hasattr(self.model, "predict"):
            logger.debug("Model has no 'predict' method")
            return None

        try:
            logger.info("Attempting forecast with .predict() method")
            preds = self.model.predict(x, prediction_length=horizon)
            if not isinstance(preds, torch.Tensor):
                preds = torch.tensor(preds, device=self.device)
            output = self._fit_to_horizon(preds.detach().cpu().numpy(), horizon)
        except Exception as e:
            logger.warning(f"predict() method failed: {e}")
            return None

        if output is None:
            logger.warning("predict() returned an empty prediction")
            return None

        logger.info(f"Forecast successful via .predict(): {output.shape}")
        return output

    def _extract_prediction_tensor(self, outputs) -> Optional[torch.Tensor]:
        """
        Pull the prediction tensor out of a forward() return value.

        Tries each name in ``_PREDICTION_ATTRS`` and then ``outputs`` itself if it
        is a tensor. Non-tensor values are converted with torch.tensor (which may
        raise).

        Returns:
            The prediction tensor, or None if nothing usable was found.
        """
        for attr in self._PREDICTION_ATTRS:
            candidate = getattr(outputs, attr, None)
            if isinstance(candidate, (tuple, list)):
                candidate = candidate[0] if candidate else None
            # An attribute that exists but holds None (or an empty sequence) is
            # skipped so the remaining names can still be tried.
            if candidate is None:
                continue
            if not isinstance(candidate, torch.Tensor):
                candidate = torch.tensor(candidate, device=self.device)
            logger.debug(f"Found predictions in attribute: {attr}")
            return candidate

        if isinstance(outputs, torch.Tensor):
            logger.debug("Using raw tensor output")
            return outputs

        return None

    def _try_forward_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Try calling the model directly (forward) with a ``prediction_length`` kwarg.

        Returns:
            1-D array of ``horizon`` values, or None if the call fails or no usable
            prediction can be extracted from its output.
        """
        logger.info("Attempting forecast with forward(prediction_length=...)")
        try:
            outputs = self.model(x, prediction_length=horizon)
        except TypeError as e:
            # Only the call itself is covered here, so TypeErrors raised while
            # post-processing the output are not misreported as a signature issue.
            logger.warning(f"forward() doesn't accept prediction_length: {e}")
            return None
        except Exception as e:
            logger.warning(f"forward() method failed: {e}")
            return None

        try:
            prediction_tensor = self._extract_prediction_tensor(outputs)
            if prediction_tensor is None:
                logger.warning("Could not extract predictions from forward() output")
                return None
            output = self._fit_to_horizon(prediction_tensor.detach().cpu().numpy(), horizon)
        except Exception as e:
            logger.warning(f"forward() method failed: {e}")
            return None

        if output is None:
            logger.warning("forward() returned an empty prediction")
            return None

        logger.info(f"Forecast successful via forward(): {output.shape}")
        return output

    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
        """
        Generate a zero-shot forecast for the input time series.

        Args:
            series: Input time series (pd.Series with numeric values).
            horizon: Number of periods to forecast (default: 96). Must be integral
                and within [MIN_HORIZON, MAX_HORIZON].

        Returns:
            DataFrame with a 'forecast' column holding exactly ``horizon``
            predictions on a default RangeIndex.

        Raises:
            ForecastToolError: If validation, model loading, or forecasting fails.
        """
        try:
            is_valid, error_msg = self._validate_series(series)
            if not is_valid:
                raise ForecastToolError(f"Invalid series: {error_msg}")

            is_valid, error_msg = self._validate_horizon(horizon)
            if not is_valid:
                raise ForecastToolError(f"Invalid horizon: {error_msg}")
            # Normalize validated inputs such as 24.0, "24" or numpy integers so the
            # model call and the length checks downstream see a plain int.
            horizon = int(horizon)

            self._ensure_loaded()

            logger.info(
                f"Forecasting: series_length={len(series)}, "
                f"horizon={horizon}, "
                f"series_mean={series.mean():.2f}, "
                f"series_std={series.std():.2f}"
            )

            x = self._prepare_input_tensor(series)

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                output = self._try_predict_method(x, horizon)
                if output is None:
                    output = self._try_forward_method(x, horizon)

            if output is None:
                raise ForecastToolError(
                    "Could not generate forecast using available model methods.\n"
                    "The model may not support zero-shot forecasting with this interface.\n"
                    "Suggestions:\n"
                    "  • Check model documentation for correct usage\n"
                    "  • Ensure transformers library is up to date\n"
                    "  • Try a different model or use traditional forecasting (ARIMA, Prophet)\n"
                    f"  • Model type: {type(self.model).__name__}"
                )

            result_df = pd.DataFrame({"forecast": output})

            logger.info(
                f"Forecast complete: "
                f"mean={output.mean():.2f}, "
                f"std={output.std():.2f}, "
                f"min={output.min():.2f}, "
                f"max={output.max():.2f}"
            )

            if self.tracer:
                self.tracer.trace_event("forecast", {
                    "series_length": len(series),
                    "horizon": horizon,
                    "forecast_mean": float(output.mean()),
                    "forecast_std": float(output.std())
                })

            return result_df

        except ForecastToolError:
            raise
        except Exception as e:
            error_msg = f"Forecasting failed unexpectedly: {str(e)}"
            # logger.exception also records the traceback of the unexpected error.
            logger.exception(error_msg)
            if self.tracer:
                self.tracer.trace_event("forecast_error", {"error": error_msg})
            raise ForecastToolError(error_msg) from e

    def get_model_info(self) -> Dict[str, Any]:
        """
        Describe the loaded model, loading it first if necessary.

        Returns:
            Dict with model_id, model_type, device, has_predict, and a string
            rendering of the model config (None if the config could not be loaded).

        Raises:
            ForecastToolError: If the model cannot be loaded.
        """
        self._ensure_loaded()

        return {
            "model_id": self.model_id,
            "model_type": type(self.model).__name__,
            "device": str(self.device),
            "has_predict": hasattr(self.model, "predict"),
            # Explicit None check: only "was a config loaded" matters here, not the
            # config object's own truthiness.
            "config": str(self.config) if self.config is not None else None
        }