mgbam commited on
Commit
1fadf44
·
verified ·
1 Parent(s): d121885

Update tools/forecaster.py

Browse files
Files changed (1) hide show
  1. tools/forecaster.py +37 -8
tools/forecaster.py CHANGED
@@ -1,16 +1,45 @@
1
-
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  from statsmodels.tsa.arima.model import ARIMA
5
 
6
- def forecast_tool(file_path: str) -> str:
 
 
 
 
 
 
 
7
  df = pd.read_csv(file_path)
8
- df['Month'] = pd.to_datetime(df['Month'])
9
- df.set_index('Month', inplace=True)
10
- model = ARIMA(df['Sales'], order=(1, 1, 1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  model_fit = model.fit()
12
  forecast = model_fit.forecast(steps=3)
13
- df_forecast = pd.DataFrame(forecast, columns=['Forecast'])
14
- df_forecast.plot(title="Sales Forecast", figsize=(10, 6))
 
15
  plt.savefig("forecast_plot.png")
16
- return "Generated forecast_plot.png"
 
 
 
1
  import pandas as pd
2
  import matplotlib.pyplot as plt
3
  from statsmodels.tsa.arima.model import ARIMA
4
 
5
+ def forecast_tool(file_path: str, date_col: str | None = None) -> str:
6
+ """
7
+ Forecast the next 3 periods of the 'Sales' column.
8
+ • If date_col is provided, use it.
9
+ • Otherwise auto‑detect the first column that can be parsed as dates.
10
+
11
+ Returns human‑readable summary and saves 'forecast_plot.png'.
12
+ """
13
  df = pd.read_csv(file_path)
14
+
15
+ # Auto‑detect date column if not specified
16
+ if date_col is None:
17
+ for col in df.columns:
18
+ try:
19
+ pd.to_datetime(df[col])
20
+ date_col = col
21
+ break
22
+ except Exception:
23
+ continue
24
+ if date_col is None:
25
+ return "❌ No parseable date column found."
26
+
27
+ # Parse the date column
28
+ try:
29
+ df[date_col] = pd.to_datetime(df[date_col])
30
+ except Exception:
31
+ return f"❌ Column '{date_col}' cannot be parsed as dates."
32
+
33
+ if "Sales" not in df.columns:
34
+ return "❌ CSV must contain a 'Sales' column."
35
+
36
+ df.set_index(date_col, inplace=True)
37
+ model = ARIMA(df["Sales"], order=(1, 1, 1))
38
  model_fit = model.fit()
39
  forecast = model_fit.forecast(steps=3)
40
+
41
+ forecast_df = pd.DataFrame(forecast, columns=["Forecast"])
42
+ forecast_df.plot(title="Sales Forecast", figsize=(10, 6))
43
  plt.savefig("forecast_plot.png")
44
+
45
+ return forecast_df.to_string()