mgbam committed on
Commit
afa778c
·
verified ·
1 Parent(s): ccdbd61

Update tools/forecaster.py

Browse files
Files changed (1) hide show
  1. tools/forecaster.py +66 -39
tools/forecaster.py CHANGED
@@ -1,10 +1,19 @@
1
  # tools/forecaster.py
 
 
 
 
 
2
  import os
3
  import tempfile
 
 
4
  import pandas as pd
5
- from statsmodels.tsa.arima.model import ARIMA
6
  import plotly.graph_objects as go
7
- from typing import Tuple, Union
 
 
 
8
 
9
 
10
  def forecast_metric_tool(
@@ -12,39 +21,51 @@ def forecast_metric_tool(
12
  date_col: str,
13
  value_col: str,
14
  periods: int = 3,
15
- output_dir: str = "/tmp"
16
  ) -> Union[Tuple[pd.DataFrame, str], str]:
17
  """
18
- Load CSV or Excel, parse a time series metric, fit ARIMA(1,1,1),
19
- forecast next `periods` steps, return DataFrame and PNG path.
 
 
 
 
 
 
 
 
 
 
20
 
21
- Returns:
22
- - (forecast_df, plot_path) on success
23
- - error string starting with '❌' on failure
 
24
  """
25
- # Load data
26
  ext = os.path.splitext(file_path)[1].lower()
27
  try:
28
- df = pd.read_excel(file_path) if ext in ('.xls', '.xlsx') else pd.read_csv(file_path)
 
 
 
 
29
  except Exception as exc:
30
  return f"❌ Failed to load file: {exc}"
31
 
32
- # Validate columns
33
  missing = [c for c in (date_col, value_col) if c not in df.columns]
34
  if missing:
35
  return f"❌ Missing column(s): {', '.join(missing)}"
36
 
37
- # Parse and clean
38
- try:
39
- df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
40
- except Exception:
41
- return f"❌ Could not parse '{date_col}' as dates."
42
- df[value_col] = pd.to_numeric(df[value_col], errors='coerce')
43
  df = df.dropna(subset=[date_col, value_col])
44
  if df.empty:
45
- return f"❌ No valid data after cleaning '{date_col}'/'{value_col}'"
46
 
47
- # Aggregate duplicates and sort
48
  df = (
49
  df[[date_col, value_col]]
50
  .groupby(date_col, as_index=True)
@@ -52,50 +73,56 @@ def forecast_metric_tool(
52
  .sort_index()
53
  )
54
 
55
- # Infer frequency
56
- freq = pd.infer_freq(df.index) or 'D'
57
  try:
58
  df = df.asfreq(freq)
59
  except Exception:
60
- df = df[~df.index.duplicated(keep='first')].asfreq(freq)
 
 
 
 
61
 
62
- # Fit ARIMA
63
  try:
64
  model = ARIMA(df[value_col], order=(1, 1, 1))
65
  fit = model.fit()
66
  except Exception as exc:
67
  return f"❌ ARIMA fitting failed: {exc}"
68
 
69
- # Forecast
70
  try:
71
  pred = fit.get_forecast(steps=periods)
72
  forecast = pred.predicted_mean
73
  except Exception as exc:
74
  return f"❌ Forecast generation failed: {exc}"
75
- forecast_df = forecast.to_frame(name='Forecast')
76
-
77
- # Plot history + forecast
78
- fig = go.Figure(
79
- data=[
80
- go.Scatter(x=df.index, y=df[value_col], mode='lines', name='History'),
81
- go.Scatter(x=forecast.index, y=forecast, mode='lines+markers', name='Forecast')
82
- ]
83
  )
84
  fig.update_layout(
85
  title=f"{value_col} Forecast",
86
  xaxis_title=date_col,
87
  yaxis_title=value_col,
88
- template='plotly_dark'
89
  )
90
 
91
- # Save PNG
92
  os.makedirs(output_dir, exist_ok=True)
93
- tmp = tempfile.NamedTemporaryFile(suffix='.png', prefix='forecast_', dir=output_dir, delete=False)
94
- plot_path = tmp.name
95
- tmp.close()
 
 
96
  try:
97
- fig.write_image(plot_path, scale=2)
98
  except Exception as exc:
99
  return f"❌ Plot saving failed: {exc}"
100
 
101
- return forecast_df, plot_path
 
1
  # tools/forecaster.py
2
+ # ------------------------------------------------------------
3
+ # Fits an ARIMA(1,1,1) model to any (date, value) series,
4
+ # forecasts the next `periods` steps, plots history + forecast,
5
+ # and saves a hi‑res PNG copy to /tmp (or custom output_dir).
6
+
7
  import os
8
  import tempfile
9
+ from typing import Tuple, Union
10
+
11
  import pandas as pd
 
12
  import plotly.graph_objects as go
13
+ from statsmodels.tsa.arima.model import ARIMA
14
+
15
+ # Typing alias
16
+ Plot = go.Figure
17
 
18
 
19
  def forecast_metric_tool(
 
21
  date_col: str,
22
  value_col: str,
23
  periods: int = 3,
24
+ output_dir: str = "/tmp",
25
  ) -> Union[Tuple[pd.DataFrame, str], str]:
26
  """
27
+ Parameters
28
+ ----------
29
+ file_path : str
30
+ CSV or Excel path.
31
+ date_col : str
32
+ Column to treat as the date index.
33
+ value_col : str
34
+ Numeric column to forecast.
35
+ periods : int
36
+ Steps ahead to forecast.
37
+ output_dir: str
38
+ Directory to save PNG.
39
 
40
+ Returns
41
+ -------
42
+ (forecast_df, png_path) on success
43
+ error string (starting '❌') otherwise
44
  """
45
+ # ── 1. Load file ──────────────────────────────────────────
46
  ext = os.path.splitext(file_path)[1].lower()
47
  try:
48
+ df = (
49
+ pd.read_excel(file_path)
50
+ if ext in (".xls", ".xlsx")
51
+ else pd.read_csv(file_path)
52
+ )
53
  except Exception as exc:
54
  return f"❌ Failed to load file: {exc}"
55
 
56
+ # ── 2. Column validation ─────────────────────────────────
57
  missing = [c for c in (date_col, value_col) if c not in df.columns]
58
  if missing:
59
  return f"❌ Missing column(s): {', '.join(missing)}"
60
 
61
+ # ── 3. Parse & clean ─────────────────────────────────────
62
+ df[date_col] = pd.to_datetime(df[date_col], errors="coerce")
63
+ df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
 
 
 
64
  df = df.dropna(subset=[date_col, value_col])
65
  if df.empty:
66
+ return f"❌ No valid data after cleaning '{date_col}' / '{value_col}'."
67
 
68
+ # Aggregate duplicate timestamps → mean
69
  df = (
70
  df[[date_col, value_col]]
71
  .groupby(date_col, as_index=True)
 
73
  .sort_index()
74
  )
75
 
76
+ # Infer or default frequency
77
+ freq = pd.infer_freq(df.index) or "D"
78
  try:
79
  df = df.asfreq(freq)
80
  except Exception:
81
+ # fallback if duplicates still exist
82
+ df = (
83
+ df[~df.index.duplicated(keep="first")]
84
+ .asfreq(freq)
85
+ )
86
 
87
+ # ── 4. Fit ARIMA(1,1,1) ──────────────────────────────────
88
  try:
89
  model = ARIMA(df[value_col], order=(1, 1, 1))
90
  fit = model.fit()
91
  except Exception as exc:
92
  return f"❌ ARIMA fitting failed: {exc}"
93
 
94
+ # ── 5. Forecast ──────────────────────────────────────────
95
  try:
96
  pred = fit.get_forecast(steps=periods)
97
  forecast = pred.predicted_mean
98
  except Exception as exc:
99
  return f"❌ Forecast generation failed: {exc}"
100
+
101
+ forecast_df = forecast.to_frame(name="Forecast")
102
+
103
+ # ── 6. Plot history + forecast ───────────────────────────
104
+ fig: Plot = go.Figure()
105
+ fig.add_scatter(x=df.index, y=df[value_col], mode="lines", name="History")
106
+ fig.add_scatter(
107
+ x=forecast.index, y=forecast, mode="lines+markers", name="Forecast"
108
  )
109
  fig.update_layout(
110
  title=f"{value_col} Forecast",
111
  xaxis_title=date_col,
112
  yaxis_title=value_col,
113
+ template="plotly_dark",
114
  )
115
 
116
+ # ── 7. Save PNG ──────────────────────────────────────────
117
  os.makedirs(output_dir, exist_ok=True)
118
+ tmp_png = tempfile.NamedTemporaryFile(
119
+ prefix="forecast_", suffix=".png", dir=output_dir, delete=False
120
+ )
121
+ png_path = tmp_png.name
122
+ tmp_png.close()
123
  try:
124
+ fig.write_image(png_path, scale=2)
125
  except Exception as exc:
126
  return f"❌ Plot saving failed: {exc}"
127
 
128
+ return forecast_df, png_path