mgbam committed on
Commit
eec9db3
·
verified ·
1 Parent(s): e1d8bc9

Update tools/forecaster.py

Browse files
Files changed (1) hide show
  1. tools/forecaster.py +76 -40
tools/forecaster.py CHANGED
@@ -1,75 +1,111 @@
1
- # tools/forecaster.py
2
-
3
  import pandas as pd
4
  from statsmodels.tsa.arima.model import ARIMA
5
  import plotly.graph_objects as go
6
 
7
- def forecast_metric_tool(file_path: str, date_col: str, value_col: str):
8
- """
9
- Forecast the next 3 periods for any numeric metric.
10
- - Saves a date‐indexed Plotly PNG under /tmp via the safe write monkey‐patch.
11
- - Returns a text table of the forecast.
 
 
 
12
  """
 
 
13
 
14
- # 0) Read full CSV
15
- df = pd.read_csv(file_path)
 
16
 
17
- # 1) Check that both columns actually exist
18
- if date_col not in df.columns:
19
- return f"❌ Date column '{date_col}' not found in your data."
20
- if value_col not in df.columns:
21
- return f"❌ Metric column '{value_col}' not found in your data."
 
 
 
 
 
 
22
 
23
- # 2) Parse dates
 
 
 
 
 
24
  try:
25
  df[date_col] = pd.to_datetime(df[date_col])
26
  except Exception:
27
  return f"❌ Could not parse '{date_col}' as dates."
28
 
29
- # 3) Coerce metric to numeric & drop invalid rows
30
- df[value_col] = pd.to_numeric(df[value_col], errors="coerce")
31
  df = df.dropna(subset=[date_col, value_col])
32
  if df.empty:
33
- return f"❌ After coercion, no valid data remains for '{value_col}'."
34
 
35
- # 4) Sort & index by date, collapse duplicates
36
- df = df.sort_values(date_col).set_index(date_col)
37
- df = df[[value_col]].groupby(level=0).mean()
 
 
 
 
38
 
39
- # 5) Infer a frequency and re‐index
40
  freq = pd.infer_freq(df.index)
41
  if freq is None:
42
- freq = "D" # fallback to daily
43
- df = df.asfreq(freq)
 
 
 
 
44
 
45
- # 6) Fit ARIMA (1,1,1)
46
  try:
47
- model = ARIMA(df[value_col], order=(1, 1, 1))
48
- model_fit = model.fit()
49
  except Exception as e:
50
  return f"❌ ARIMA fitting failed: {e}"
51
 
52
- # 7) Produce a proper date‐indexed forecast
53
- fc_res = model_fit.get_forecast(steps=3)
54
  forecast = fc_res.predicted_mean
 
55
 
56
- # 8) Plot history + forecast
57
  fig = go.Figure()
58
- fig.add_scatter(
59
- x=df.index, y=df[value_col],
60
- mode="lines", name=value_col
 
 
61
  )
62
- fig.add_scatter(
63
- x=forecast.index, y=forecast,
64
- mode="lines+markers", name="Forecast"
 
 
65
  )
66
  fig.update_layout(
67
  title=f"{value_col} Forecast",
68
  xaxis_title=date_col,
69
  yaxis_title=value_col,
70
- template="plotly_dark",
 
 
 
 
 
 
71
  )
72
- fig.write_image("forecast_plot.png") # lands in /tmp via our monkey‐patch
 
 
73
 
74
- # 9) Return the forecast as a text table
75
- return forecast.to_frame(name="Forecast").to_string()
 
1
+ import os
2
+ import tempfile
3
  import pandas as pd
4
  from statsmodels.tsa.arima.model import ARIMA
5
  import plotly.graph_objects as go
6
 
7
+
8
def forecast_metric_tool(
    file_path: str,
    date_col: str,
    value_col: str,
    periods: int = 3,
    output_dir: str = "/tmp",
) -> "tuple[pd.DataFrame, str] | str":
    """
    Forecast a time-series metric with ARIMA(1,1,1) and save a history+forecast plot.

    The file is loaded with ``pd.read_excel`` when its extension is ``.xls`` or
    ``.xlsx`` and with ``pd.read_csv`` otherwise. Rows whose date or value is
    missing / non-numeric are dropped, duplicate dates are averaged, and the
    series is re-indexed to its inferred frequency (daily when none can be
    inferred) before fitting.

    Args:
        file_path: Path to the CSV or Excel file.
        date_col: Name of the column holding the dates.
        value_col: Name of the column holding the metric to forecast.
        periods: Number of future steps to forecast (must be >= 1).
        output_dir: Directory in which the PNG plot is created; it is created
            if it does not exist.

    Returns:
        On success, a tuple ``(forecast_df, plot_path)`` where ``forecast_df``
        is a one-column DataFrame named ``'Forecast'`` and ``plot_path`` is the
        full path of the saved PNG.

        On any failure, a string starting with '❌' describing the problem.
        Callers must check ``isinstance(result, str)`` before unpacking.
    """
    # Reject a non-positive horizon up front rather than letting the
    # forecasting step fail later with a less helpful message.
    if isinstance(periods, int) and periods < 1:
        return f"❌ 'periods' must be a positive integer, got {periods}."

    # 0) Load data (CSV or Excel)
    ext = os.path.splitext(file_path)[1].lower()
    try:
        if ext in ('.xls', '.xlsx'):
            df = pd.read_excel(file_path)
        else:
            df = pd.read_csv(file_path)
    except Exception as e:
        return f"❌ Failed to load file: {e}"

    # 1) Validate columns
    for col in (date_col, value_col):
        if col not in df.columns:
            return f"❌ Column '{col}' not found."

    # 2) Parse dates and numeric values
    try:
        df[date_col] = pd.to_datetime(df[date_col])
    except Exception:
        return f"❌ Could not parse '{date_col}' as dates."

    # Non-numeric values become NaN here and are dropped just below.
    df[value_col] = pd.to_numeric(df[value_col], errors='coerce')
    df = df.dropna(subset=[date_col, value_col])
    if df.empty:
        return f"❌ No valid rows after dropping NaNs in '{date_col}'/'{value_col}'."

    # 3) Aggregate duplicate dates (mean) and index by date.
    #    After the groupby the index is unique, so asfreq() below cannot fail
    #    on duplicate labels.
    df = (
        df[[date_col, value_col]]
        .groupby(date_col, as_index=True)
        .mean()
        .sort_index()
    )

    # pd.infer_freq raises ValueError when given fewer than 3 dates; report
    # that as a regular error string instead of crashing.
    if len(df) < 3:
        return f"❌ Need at least 3 distinct dates to forecast; found {len(df)}."

    # 4) Infer frequency and put the series on a regular grid
    try:
        freq = pd.infer_freq(df.index)
    except ValueError:
        freq = None
    if freq is None:
        freq = 'D'  # fallback: treat irregular series as daily
    # Dates missing from the regular grid become NaN rows.
    df = df.asfreq(freq)

    # 5) Fit ARIMA
    try:
        model = ARIMA(df[value_col], order=(1, 1, 1))
        fit = model.fit()
    except Exception as e:
        return f"❌ ARIMA fitting failed: {e}"

    # 6) Forecast future
    try:
        fc_res = fit.get_forecast(steps=periods)
        forecast = fc_res.predicted_mean
    except Exception as e:
        return f"❌ Forecasting failed: {e}"
    forecast_df = forecast.to_frame(name='Forecast')

    # 7) Plot history + forecast
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=df.index, y=df[value_col],
            mode='lines+markers', name=value_col,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=forecast.index, y=forecast,
            mode='lines+markers', name='Forecast',
        )
    )
    fig.update_layout(
        title=f"{value_col} Forecast",
        xaxis_title=date_col,
        yaxis_title=value_col,
        template='plotly_dark',
    )

    # 8) Save to a uniquely named PNG inside output_dir.
    #    mkstemp reserves the name; the descriptor is closed immediately so
    #    write_image can open the path itself.
    plot_path = None
    try:
        os.makedirs(output_dir, exist_ok=True)
        fd, plot_path = tempfile.mkstemp(
            suffix='.png', prefix='forecast_', dir=output_dir
        )
        os.close(fd)
        fig.write_image(plot_path)
    except Exception as e:
        # Do not leave an empty placeholder file behind when export fails.
        if plot_path is not None:
            try:
                os.remove(plot_path)
            except OSError:
                pass
        return f"❌ Failed to save plot: {e}"

    return forecast_df, plot_path