# BizIntel_AI / tools/forecaster.py
# mgbam's picture
# Update tools/forecaster.py
# afa778c verified
# tools/forecaster.py
# ------------------------------------------------------------
# Fits an ARIMA(1,1,1) model to any (date, value) series,
# forecasts the next `periods` steps, plots history + forecast,
# and saves a hi‑res PNG copy to /tmp (or custom output_dir).
import os
import tempfile
from typing import Tuple, Union
import pandas as pd
import plotly.graph_objects as go
from statsmodels.tsa.arima.model import ARIMA
# Type alias for Plotly figures; used to annotate the chart object that
# forecast_metric_tool builds.
Plot = go.Figure
def forecast_metric_tool(
    file_path: str,
    date_col: str,
    value_col: str,
    periods: int = 3,
    output_dir: str = "/tmp",
) -> Union[Tuple[pd.DataFrame, str], str]:
    """
    Fit ARIMA(1,1,1) to a (date, value) series, forecast ahead, and save a chart.

    The file is loaded, the two columns are coerced to datetime / numeric
    (unparseable rows are dropped), duplicate timestamps are averaged, and the
    series is placed on a regular time grid before fitting. Every failure is
    reported as a returned error string rather than a raised exception.

    Parameters
    ----------
    file_path : str
        CSV or Excel path. ``.xls`` / ``.xlsx`` are read as Excel; anything
        else is read as CSV.
    date_col : str
        Column to treat as the date index.
    value_col : str
        Numeric column to forecast.
    periods : int
        Steps ahead to forecast.
    output_dir : str
        Directory to save the PNG in (created if it does not exist).

    Returns
    -------
    (forecast_df, png_path) on success, where ``forecast_df`` has a single
    ``"Forecast"`` column indexed by the forecast timestamps.
    error string (starting '❌') otherwise
    """
    # ── 1. Load file ──────────────────────────────────────────
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = (
            pd.read_excel(file_path)
            if ext in (".xls", ".xlsx")
            else pd.read_csv(file_path)
        )
    except Exception as exc:
        return f"❌ Failed to load file: {exc}"

    # ── 2. Column validation ─────────────────────────────────
    missing = [c for c in (date_col, value_col) if c not in df.columns]
    if missing:
        return f"❌ Missing column(s): {', '.join(missing)}"

    # ── 3. Parse & clean ─────────────────────────────────────
    # errors="coerce" turns unparseable cells into NaT / NaN so they can be
    # dropped together in one pass.
    df[date_col] = pd.to_datetime(df[date_col], errors="coerce")
    df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
    df = df.dropna(subset=[date_col, value_col])
    if df.empty:
        return f"❌ No valid data after cleaning '{date_col}' / '{value_col}'."

    # Aggregate duplicate timestamps → mean (this also makes the index unique)
    df = (
        df[[date_col, value_col]]
        .groupby(date_col, as_index=True)
        .mean()
        .sort_index()
    )

    # pd.infer_freq raises ValueError on fewer than 3 timestamps; report that
    # as a normal error string instead of letting it escape.
    n_dates = len(df.index)
    if n_dates < 3:
        return (
            f"❌ Need at least 3 distinct dates to forecast '{value_col}'; "
            f"found {n_dates}."
        )

    # Infer or default frequency
    freq = pd.infer_freq(df.index) or "D"
    try:
        # Reindex onto a regular grid; grid points with no observation
        # become NaN.
        df = df.asfreq(freq)
    except Exception as exc:
        return f"❌ Could not align data to frequency '{freq}': {exc}"

    # ── 4. Fit ARIMA(1,1,1) ──────────────────────────────────
    try:
        model = ARIMA(df[value_col], order=(1, 1, 1))
        fit = model.fit()
    except Exception as exc:
        return f"❌ ARIMA fitting failed: {exc}"

    # ── 5. Forecast ──────────────────────────────────────────
    try:
        pred = fit.get_forecast(steps=periods)
        forecast = pred.predicted_mean
    except Exception as exc:
        return f"❌ Forecast generation failed: {exc}"
    forecast_df = forecast.to_frame(name="Forecast")

    # ── 6. Plot history + forecast ───────────────────────────
    fig: Plot = go.Figure()
    fig.add_scatter(x=df.index, y=df[value_col], mode="lines", name="History")
    fig.add_scatter(
        x=forecast.index, y=forecast, mode="lines+markers", name="Forecast"
    )
    fig.update_layout(
        title=f"{value_col} Forecast",
        xaxis_title=date_col,
        yaxis_title=value_col,
        template="plotly_dark",
    )

    # ── 7. Save PNG ──────────────────────────────────────────
    # Reserve a unique file name in output_dir; the handle is closed right
    # away so the image writer can open the path itself.
    try:
        os.makedirs(output_dir, exist_ok=True)
        with tempfile.NamedTemporaryFile(
            prefix="forecast_", suffix=".png", dir=output_dir, delete=False
        ) as tmp_png:
            png_path = tmp_png.name
    except OSError as exc:
        return f"❌ Plot saving failed: {exc}"

    try:
        fig.write_image(png_path, scale=2)
    except Exception as exc:
        # Do not leave the empty placeholder file behind on failure.
        try:
            os.remove(png_path)
        except OSError:
            pass
        return f"❌ Plot saving failed: {exc}"

    return forecast_df, png_path