Spaces:
Running
Running
# tools/forecaster.py
# ------------------------------------------------------------
# Fits an ARIMA(1,1,1) model to any (date, value) series,
# forecasts the next `periods` steps, plots history + forecast,
# and saves a hi-res PNG copy to /tmp (or custom output_dir).
import os | |
import tempfile | |
from typing import Tuple, Union | |
import pandas as pd | |
import plotly.graph_objects as go | |
from statsmodels.tsa.arima.model import ARIMA | |
# Type alias for Plotly figures, used to annotate the chart built below.
Plot = go.Figure
def forecast_metric_tool(
    file_path: str,
    date_col: str,
    value_col: str,
    periods: int = 3,
    output_dir: str = "/tmp",
) -> Union[Tuple[pd.DataFrame, str], str]:
    """Forecast a metric with ARIMA(1,1,1) and save a chart of the result.

    The file is loaded, ``date_col`` / ``value_col`` are coerced to
    datetime / numeric (rows that fail coercion are dropped), duplicate
    dates are averaged, and the series is put on a regular frequency
    before the model is fitted.

    Parameters
    ----------
    file_path : str
        CSV or Excel path. ``.xls`` / ``.xlsx`` are read with
        ``pd.read_excel``; any other extension is read as CSV.
    date_col : str
        Column to treat as the date index.
    value_col : str
        Numeric column to forecast.
    periods : int
        Steps ahead to forecast; must be at least 1.
    output_dir : str
        Directory to save the PNG in (created if it does not exist).

    Returns
    -------
    tuple of (pandas.DataFrame, str)
        On success: a single-column ``"Forecast"`` frame and the PNG path.
    str
        On failure: an error message starting with 'β'. Load, validation,
        fitting, forecasting and saving failures are reported this way
        rather than raised.
    """
    # -- 0. Validate the horizon before doing any expensive work --------
    try:
        steps = int(periods)
    except (TypeError, ValueError):
        steps = 0
    if steps < 1:
        return f"β 'periods' must be a positive integer, got {periods!r}."

    # -- 1. Load file ---------------------------------------------------
    ext = os.path.splitext(file_path)[1].lower()
    reader = pd.read_excel if ext in (".xls", ".xlsx") else pd.read_csv
    try:
        df = reader(file_path)
    except Exception as exc:
        return f"β Failed to load file: {exc}"

    # -- 2. Column validation -------------------------------------------
    missing = [c for c in (date_col, value_col) if c not in df.columns]
    if missing:
        return f"β Missing column(s): {', '.join(missing)}"

    # -- 3. Parse & clean -----------------------------------------------
    df = df[[date_col, value_col]].copy()
    df[date_col] = pd.to_datetime(df[date_col], errors="coerce")
    df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
    df = df.dropna(subset=[date_col, value_col])
    if df.empty:
        return f"β No valid data after cleaning '{date_col}' / '{value_col}'."

    # Aggregate duplicate timestamps by their mean; the resulting index is
    # unique and sorted, so no further de-duplication is needed.
    df = df.groupby(date_col, as_index=True).mean().sort_index()

    # pd.infer_freq raises ValueError with fewer than 3 dates; report that
    # as a normal error string instead of letting it escape.
    if len(df) < 3:
        return (
            "β Not enough data points to forecast: "
            f"need at least 3 distinct dates, got {len(df)}."
        )

    # Infer the frequency, defaulting to daily when none can be inferred.
    # Gaps introduced by the regular grid become NaN.
    try:
        freq = pd.infer_freq(df.index) or "D"
        df = df.asfreq(freq)
    except Exception as exc:
        return f"β Could not put the series on a regular frequency: {exc}"

    # -- 4. Fit ARIMA(1,1,1) --------------------------------------------
    try:
        model = ARIMA(df[value_col], order=(1, 1, 1))
        fit = model.fit()
    except Exception as exc:
        return f"β ARIMA fitting failed: {exc}"

    # -- 5. Forecast ----------------------------------------------------
    try:
        pred = fit.get_forecast(steps=steps)
        forecast = pred.predicted_mean
    except Exception as exc:
        return f"β Forecast generation failed: {exc}"
    forecast_df = forecast.to_frame(name="Forecast")

    # -- 6. Plot history + forecast -------------------------------------
    fig: Plot = go.Figure()
    fig.add_scatter(x=df.index, y=df[value_col], mode="lines", name="History")
    fig.add_scatter(
        x=forecast.index, y=forecast, mode="lines+markers", name="Forecast"
    )
    fig.update_layout(
        title=f"{value_col} Forecast",
        xaxis_title=date_col,
        yaxis_title=value_col,
        template="plotly_dark",
    )

    # -- 7. Save PNG ----------------------------------------------------
    # Reserve a unique file name; delete=False keeps the file after the
    # handle is closed so the image writer can write to that path.
    try:
        os.makedirs(output_dir, exist_ok=True)
        with tempfile.NamedTemporaryFile(
            prefix="forecast_", suffix=".png", dir=output_dir, delete=False
        ) as tmp_png:
            png_path = tmp_png.name
    except OSError as exc:
        return f"β Plot saving failed: {exc}"

    try:
        fig.write_image(png_path, scale=2)
    except Exception as exc:
        # Do not leave the empty placeholder file behind. Cleanup is
        # best-effort: the write failure is the error worth reporting.
        try:
            os.remove(png_path)
        except OSError:
            pass
        return f"β Plot saving failed: {exc}"

    return forecast_df, png_path