XPMaster commited on
Commit
939b379
β€’
1 Parent(s): 24d2e79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -27
app.py CHANGED
@@ -2,29 +2,31 @@ import streamlit as st
2
  import pandas as pd
3
  from io import BytesIO
4
  from itertools import product
5
- from statsmodels.tsa.holtwinters import ExponentialSmoothing
6
  import plotly.express as px
 
7
  st.set_page_config(layout="wide")
8
- # Function to run the Exponential Smoothing Model
9
- def run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period):
 
10
  try:
11
- model = ExponentialSmoothing(city_data, trend=trend, damped_trend=damped_trend, seasonal=seasonal, seasonal_periods=seasonal_period)
12
- model_fit = model.fit(optimized=True)
13
  return model_fit.forecast(steps=6), model_fit.aic
14
  except Exception as e:
15
  st.error(f"An error occurred during model fitting: {e}")
16
  return None, None
17
 
18
  def create_data():
 
19
  data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
20
  data.index = pd.to_datetime(data.index, format='%Y%m')
21
  data = data.groupby('City').resample('M').sum().reset_index()
22
  data.index = data['Accident Month Bracket']
23
- data = data.drop(['Accident Month Bracket'],axis=1)
24
  data.index = data.index.strftime('%Y-%m')
25
  return data
26
 
27
- # Function to convert DataFrame to Excel
28
  def to_excel(df):
29
  output = BytesIO()
30
  writer = pd.ExcelWriter(output, engine='xlsxwriter')
@@ -35,9 +37,9 @@ def to_excel(df):
35
 
36
  # Initialize session state for best parameters
37
  if 'best_params' not in st.session_state:
38
- st.session_state.best_params = {'trend': None, 'damped_trend': False, 'seasonal': None, 'seasonal_period': 12}
39
 
40
- st.title("Exponential Smoothing Forecasting")
41
 
42
  # Data preparation
43
  data = create_data()
@@ -48,18 +50,21 @@ tabs = st.tabs([city for city in unique_cities])
48
 
49
  for tab, city in zip(tabs, unique_cities):
50
  with tab:
51
- # Sliders for parameter adjustment, using session state values as defaults
52
- trend = st.select_slider('Select Trend', options=['add', 'mul', None], value=st.session_state.best_params['trend'],key=city+'1')
53
- damped_trend = False
54
- seasonal = st.select_slider('Select Seasonal', options=['add', 'mul', None], value=st.session_state.best_params['seasonal'],key=city+'2')
55
- seasonal_period = st.slider('Seasonal Period', 1, 24, value=st.session_state.best_params['seasonal_period'],key=city+'3')
 
 
 
56
 
57
  city_data = data[data['City'] == city]['Accident Count']
58
- forecast, aic = run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
59
 
60
  if forecast is not None:
61
  st.write(f"Best Parameters with AIC: {aic}")
62
- st.write(f"Trend: {trend}, Damped Trend: {damped_trend}, Seasonal: {seasonal}, Seasonal Period: {seasonal_period}")
63
  forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
64
  forecast_index = forecast_index.to_period('M') # Convert to period index with monthly frequency
65
  forecast_df = pd.DataFrame(forecast, columns=['Forecast'])
@@ -72,24 +77,24 @@ for tab, city in zip(tabs, unique_cities):
72
  if st.button(f'Run Grid Search for {city}'):
73
  best_aic = float('inf')
74
  best_params = None
75
- for param_set in product(['add', 'mul', None], [False], ['add', 'mul', None], [12]):
76
- _, temp_aic = run_exp_smoothing(city_data, *param_set)
77
- if temp_aic and temp_aic < best_aic:
78
- best_aic = temp_aic
79
- best_params = param_set
 
80
 
81
  # Updating session state with the best parameters
82
  st.session_state.best_params = {
83
- 'trend': best_params[0],
84
- 'damped_trend': best_params[1],
85
- 'seasonal': best_params[2],
86
- 'seasonal_period': best_params[3]
87
  }
88
  st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")
89
 
90
-
91
  # Export to Excel button
92
  if st.button(f'Export {city} to Excel'):
93
  df_to_export = forecast_df
94
  excel_data = to_excel(df_to_export)
95
- st.download_button(label='πŸ“₯ Download Excel', data=excel_data, file_name=f'{city}_forecast.xlsx', mime='application/vnd.ms-excel')
 
 
 
2
  import pandas as pd
3
  from io import BytesIO
4
  from itertools import product
5
+ from statsmodels.tsa.statespace.sarimax import SARIMAX
6
  import plotly.express as px
7
+
8
  st.set_page_config(layout="wide")
9
+
10
+ # Function to run the SARIMAX Model
11
+ def run_sarimax(city_data, order, seasonal_order):
12
  try:
13
+ model = SARIMAX(city_data, order=order, seasonal_order=seasonal_order, enforce_stationarity=False, enforce_invertibility=False)
14
+ model_fit = model.fit(disp=False)
15
  return model_fit.forecast(steps=6), model_fit.aic
16
  except Exception as e:
17
  st.error(f"An error occurred during model fitting: {e}")
18
  return None, None
19
 
20
  def create_data():
21
+ # Assuming you have a CSV file named 'accident_count.csv' with 'City' and 'Accident Count' columns
22
  data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
23
  data.index = pd.to_datetime(data.index, format='%Y%m')
24
  data = data.groupby('City').resample('M').sum().reset_index()
25
  data.index = data['Accident Month Bracket']
26
+ data = data.drop(['Accident Month Bracket'], axis=1)
27
  data.index = data.index.strftime('%Y-%m')
28
  return data
29
 
 
30
  def to_excel(df):
31
  output = BytesIO()
32
  writer = pd.ExcelWriter(output, engine='xlsxwriter')
 
37
 
38
  # Initialize session state for best parameters
39
  if 'best_params' not in st.session_state:
40
+ st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}
41
 
42
+ st.title("SARIMAX Forecasting")
43
 
44
  # Data preparation
45
  data = create_data()
 
50
 
51
  for tab, city in zip(tabs, unique_cities):
52
  with tab:
53
+ # SARIMAX specific sliders
54
+ p = st.slider('AR Order (p)', 0, 5, value=st.session_state.best_params['order'][0], key=city+'p')
55
+ d = st.slider('Differencing Order (d)', 0, 2, value=st.session_state.best_params['order'][1], key=city+'d')
56
+ q = st.slider('MA Order (q)', 0, 5, value=st.session_state.best_params['order'][2], key=city+'q')
57
+ P = st.slider('Seasonal AR Order (P)', 0, 5, value=st.session_state.best_params['seasonal_order'][0], key=city+'P')
58
+ D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=st.session_state.best_params['seasonal_order'][1], key=city+'D')
59
+ Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=st.session_state.best_params['seasonal_order'][2], key=city+'Q')
60
+ S = st.slider('Seasonal Period (S)', 1, 24, value=st.session_state.best_params['seasonal_order'][3], key=city+'S')
61
 
62
  city_data = data[data['City'] == city]['Accident Count']
63
+ forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))
64
 
65
  if forecast is not None:
66
  st.write(f"Best Parameters with AIC: {aic}")
67
+ st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")
68
  forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
69
  forecast_index = forecast_index.to_period('M') # Convert to period index with monthly frequency
70
  forecast_df = pd.DataFrame(forecast, columns=['Forecast'])
 
77
  if st.button(f'Run Grid Search for {city}'):
78
  best_aic = float('inf')
79
  best_params = None
80
+ for param_set in product(range(3), repeat=3): # Adjust the range and repeat parameters as needed
81
+ for seasonal_param_set in product(range(3), repeat=4): # Adjust for seasonal parameters
82
+ _, temp_aic = run_sarimax(city_data, param_set, seasonal_param_set+(12,))
83
+ if temp_aic and temp_aic < best_aic:
84
+ best_aic = temp_aic
85
+ best_params = (param_set, seasonal_param_set+(12,))
86
 
87
  # Updating session state with the best parameters
88
  st.session_state.best_params = {
89
+ 'order': best_params[0],
90
+ 'seasonal_order': best_params[1]
 
 
91
  }
92
  st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")
93
 
 
94
  # Export to Excel button
95
  if st.button(f'Export {city} to Excel'):
96
  df_to_export = forecast_df
97
  excel_data = to_excel(df_to_export)
98
+ st.download_button(label='πŸ“₯ Download Excel', data=excel_data, file_name=f'{city}_forecast.xlsx', mime='application/vnd.ms-excel')
99
+
100
+ # Rest of your code