File size: 4,384 Bytes
92cca14
afa778c
 
 
 
 
eec9db3
 
afa778c
 
010071f
3651f7b
afa778c
 
 
 
010071f
eec9db3
 
 
 
 
 
afa778c
92cca14
1fadf44
afa778c
 
 
 
 
 
 
 
 
 
 
 
5405a02
afa778c
 
 
 
eec9db3
afa778c
eec9db3
 
afa778c
 
 
 
 
92cca14
 
5405a02
afa778c
92cca14
 
 
eec9db3
afa778c
 
 
3dbe4ef
 
afa778c
e9cc996
afa778c
eec9db3
 
 
 
 
 
e9cc996
afa778c
 
eec9db3
 
92cca14
afa778c
 
 
 
 
e9cc996
afa778c
9db6dea
eec9db3
 
92cca14
 
e9cc996
afa778c
92cca14
 
 
 
 
afa778c
 
 
 
 
 
 
 
5405a02
3dbe4ef
 
e9cc996
 
afa778c
eec9db3
 
afa778c
eec9db3
afa778c
 
 
 
 
92cca14
afa778c
92cca14
 
e9cc996
afa778c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# tools/forecaster.py
# ------------------------------------------------------------
# Fits an ARIMA(1,1,1) model to any (date, value) series,
# forecasts the next `periods` steps, plots history + forecast,
# and saves a hi‑res PNG copy to /tmp (or custom output_dir).

import os
import tempfile
from typing import Tuple, Union

import pandas as pd
import plotly.graph_objects as go
from statsmodels.tsa.arima.model import ARIMA

# Typing alias: `Plot` is shorthand for a Plotly figure (`go.Figure`),
# used to annotate the figure built by the forecasting tool below.
Plot = go.Figure


def forecast_metric_tool(
    file_path: str,
    date_col: str,
    value_col: str,
    periods: int = 3,
    output_dir: str = "/tmp",
) -> Union[Tuple[pd.DataFrame, str], str]:
    """
    Fit ARIMA(1,1,1) to a (date, value) series, forecast ahead, save a plot.

    The file is loaded, the two columns are coerced to datetime / numeric,
    unparseable rows are dropped, duplicate timestamps are averaged and the
    series is placed on a regular frequency (inferred, else daily) before
    the model is fitted.  Every failure is reported as an error string
    rather than an exception.

    Parameters
    ----------
    file_path : str
        CSV or Excel path (".xls" / ".xlsx" are read as Excel, anything
        else as CSV).
    date_col  : str
        Column to treat as the date index.
    value_col : str
        Numeric column to forecast.
    periods   : int
        Steps ahead to forecast.
    output_dir: str
        Directory to save PNG (created if it does not exist).

    Returns
    -------
    (forecast_df, png_path)   on success; ``forecast_df`` has a single
                              "Forecast" column indexed by future dates
    error string (starting '❌') otherwise
    """
    # ── 1. Load file ──────────────────────────────────────────
    ext = os.path.splitext(file_path)[1].lower()
    try:
        df = (
            pd.read_excel(file_path)
            if ext in (".xls", ".xlsx")
            else pd.read_csv(file_path)
        )
    except Exception as exc:
        return f"❌ Failed to load file: {exc}"

    # ── 2. Column validation ─────────────────────────────────
    missing = [c for c in (date_col, value_col) if c not in df.columns]
    if missing:
        return f"❌ Missing column(s): {', '.join(missing)}"

    # ── 3. Parse & clean ─────────────────────────────────────
    # Unparseable cells become NaT / NaN and are dropped just below.
    df[date_col] = pd.to_datetime(df[date_col], errors="coerce")
    df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
    df = df.dropna(subset=[date_col, value_col])
    if df.empty:
        return f"❌ No valid data after cleaning '{date_col}' / '{value_col}'."

    # Aggregate duplicate timestamps → mean.  The resulting index is
    # unique and sorted, so no further de-duplication is needed.
    df = (
        df[[date_col, value_col]]
        .groupby(date_col, as_index=True)
        .mean()
        .sort_index()
    )

    # Infer or default frequency.  `pd.infer_freq` raises (rather than
    # returning None) when fewer than 3 dates remain or the index is not
    # datetime-like, so convert that into the documented error string.
    try:
        freq = pd.infer_freq(df.index) or "D"
        df = df.asfreq(freq)
    except (TypeError, ValueError) as exc:
        return f"❌ Could not determine a regular frequency: {exc}"

    # ── 4. Fit ARIMA(1,1,1) ──────────────────────────────────
    try:
        model = ARIMA(df[value_col], order=(1, 1, 1))
        fit = model.fit()
    except Exception as exc:
        return f"❌ ARIMA fitting failed: {exc}"

    # ── 5. Forecast ──────────────────────────────────────────
    try:
        pred = fit.get_forecast(steps=periods)
        forecast = pred.predicted_mean
    except Exception as exc:
        return f"❌ Forecast generation failed: {exc}"

    forecast_df = forecast.to_frame(name="Forecast")

    # ── 6. Plot history + forecast ───────────────────────────
    fig: Plot = go.Figure()
    fig.add_scatter(x=df.index, y=df[value_col], mode="lines", name="History")
    fig.add_scatter(
        x=forecast.index, y=forecast, mode="lines+markers", name="Forecast"
    )
    fig.update_layout(
        title=f"{value_col} Forecast",
        xaxis_title=date_col,
        yaxis_title=value_col,
        template="plotly_dark",
    )

    # ── 7. Save PNG ──────────────────────────────────────────
    # Reserve a unique file name first; delete=False keeps the file on
    # disk after the handle is closed so the image can be written to it.
    try:
        os.makedirs(output_dir, exist_ok=True)
        with tempfile.NamedTemporaryFile(
            prefix="forecast_", suffix=".png", dir=output_dir, delete=False
        ) as tmp_png:
            png_path = tmp_png.name
    except OSError as exc:
        return f"❌ Plot saving failed: {exc}"

    try:
        fig.write_image(png_path, scale=2)
    except Exception as exc:
        # Do not leave the empty placeholder file behind on failure.
        try:
            os.remove(png_path)
        except OSError:
            pass
        return f"❌ Plot saving failed: {exc}"

    return forecast_df, png_path