XPMaster commited on
Commit
6d819a2
β€’
1 Parent(s): 36a8df4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -33
app.py CHANGED
@@ -1,8 +1,9 @@
1
- import streamlit as st
2
  import pandas as pd
3
  from io import BytesIO
4
- from itertools import product
5
  from statsmodels.tsa.holtwinters import ExponentialSmoothing
 
 
6
  import plotly.express as px
7
 
8
  # Function to run the Exponential Smoothing Model
@@ -15,6 +16,16 @@ def run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
15
  st.error(f"An error occurred during model fitting: {e}")
16
  return None, None
17
 
 
 
 
 
 
 
 
 
 
 
18
  def create_data():
19
  data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
20
  data.index = pd.to_datetime(data.index, format='%Y%m')
@@ -24,7 +35,6 @@ def create_data():
24
  data.index = data.index.strftime('%Y-%m')
25
  return data
26
 
27
- # Function to convert DataFrame to Excel
28
  def to_excel(df):
29
  output = BytesIO()
30
  writer = pd.ExcelWriter(output, engine='xlsxwriter')
@@ -33,61 +43,74 @@ def to_excel(df):
33
  processed_data = output.getvalue()
34
  return processed_data
35
 
36
- # Initialize session state for best parameters
37
  if 'best_params' not in st.session_state:
38
- st.session_state.best_params = {'trend': None, 'damped_trend': False, 'seasonal': None, 'seasonal_period': 12}
39
 
40
- st.title("Exponential Smoothing Forecasting")
41
 
42
- # Data preparation
43
  data = create_data()
44
  unique_cities = data['City'].unique()
45
 
46
- # Select a city
47
  selected_city = st.selectbox('Select a City', unique_cities)
48
 
49
- # Sliders for parameter adjustment, using session state values as defaults
50
- trend = st.select_slider('Select Trend', options=['add', 'mul', None], value=st.session_state.best_params['trend'])
51
- damped_trend = False#st.checkbox('Damped Trend', value=st.session_state.best_params['damped_trend'])
52
- seasonal = st.select_slider('Select Seasonal', options=['add', 'mul', None], value=st.session_state.best_params['seasonal'])
53
- seasonal_period = st.slider('Seasonal Period', 1, 24, value=st.session_state.best_params['seasonal_period'])
 
 
 
 
 
 
 
 
 
 
54
 
55
  city_data = data[data['City'] == selected_city]['Accident Count']
56
- forecast, aic = run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
 
 
 
 
 
 
57
 
58
  if forecast is not None:
59
  st.write(f"Best Parameters with AIC: {aic}")
60
- st.write(f"Trend: {trend}, Damped Trend: {damped_trend}, Seasonal: {seasonal}, Seasonal Period: {seasonal_period}")
61
  forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
62
- forecast_index = forecast_index.to_period('M') # Convert to period index with monthly frequency
63
  forecast_df = pd.DataFrame(forecast, columns=['Forecast'])
64
  forecast_df = forecast_df.round(0)
65
  st.table(forecast_df)
66
  fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
67
  st.plotly_chart(fig)
68
 
69
- # Grid search button
70
  if st.button('Run Grid Search'):
71
  best_aic = float('inf')
72
  best_params = None
73
- for param_set in product(['add', 'mul', None], [False], ['add', 'mul', None], [12]):
74
- _, temp_aic = run_exp_smoothing(city_data, *param_set)
75
- if temp_aic and temp_aic < best_aic:
76
- best_aic = temp_aic
77
- best_params = param_set
78
-
79
- # Updating session state with the best parameters
80
- st.session_state.best_params = {
81
- 'trend': best_params[0],
82
- 'damped_trend': best_params[1],
83
- 'seasonal': best_params[2],
84
- 'seasonal_period': best_params[3]
85
- }
 
 
86
  st.write(f"Best Parameters: {best_params} with AIC: {best_aic}")
87
 
88
-
89
- # Export to Excel button
90
  if st.button('Export to Excel'):
91
  df_to_export = forecast_df
92
  excel_data = to_excel(df_to_export)
93
- st.download_button(label='πŸ“₯ Download Excel', data=excel_data, file_name='forecast.xlsx', mime='application/vnd.ms-excel')
 
 
1
  import pandas as pd
2
  from io import BytesIO
3
+ import streamlit as st
4
  from statsmodels.tsa.holtwinters import ExponentialSmoothing
5
+ from statsmodels.tsa.statespace.sarimax import SARIMAX
6
+ from itertools import product
7
  import plotly.express as px
8
 
9
  # Function to run the Exponential Smoothing Model
 
16
  st.error(f"An error occurred during model fitting: {e}")
17
  return None, None
18
 
19
+ # Function to run SARIMAX Model
20
+ def run_sarimax(city_data, order, seasonal_order):
21
+ try:
22
+ model = SARIMAX(city_data, order=order, seasonal_order=seasonal_order)
23
+ model_fit = model.fit(disp=False)
24
+ return model_fit.forecast(steps=6), model_fit.aic
25
+ except Exception as e:
26
+ st.error(f"An error occurred during SARIMAX model fitting: {e}")
27
+ return None, None
28
+
29
  def create_data():
30
  data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
31
  data.index = pd.to_datetime(data.index, format='%Y%m')
 
35
  data.index = data.index.strftime('%Y-%m')
36
  return data
37
 
 
38
  def to_excel(df):
39
  output = BytesIO()
40
  writer = pd.ExcelWriter(output, engine='xlsxwriter')
 
43
  processed_data = output.getvalue()
44
  return processed_data
45
 
 
46
  if 'best_params' not in st.session_state:
47
+ st.session_state.best_params = {'trend': None, 'damped_trend': False, 'seasonal': None, 'seasonal_period': 12, 'model_type': 'ExpSmoothing'}
48
 
49
+ st.title("Exponential Smoothing and SARIMAX Forecasting")
50
 
 
51
  data = create_data()
52
  unique_cities = data['City'].unique()
53
 
 
54
  selected_city = st.selectbox('Select a City', unique_cities)
55
 
56
+ model_type = st.selectbox('Select Model Type', ['ExpSmoothing', 'SARIMAX'])
57
+
58
+ if model_type == 'ExpSmoothing':
59
+ trend = st.select_slider('Select Trend', options=['add', 'mul', None], value=st.session_state.best_params['trend'])
60
+ damped_trend = st.checkbox('Damped Trend', value=st.session_state.best_params['damped_trend'])
61
+ seasonal = st.select_slider('Select Seasonal', options=['add', 'mul', None], value=st.session_state.best_params['seasonal'])
62
+ seasonal_period = st.slider('Seasonal Period', 1, 24, value=st.session_state.best_params['seasonal_period'])
63
+ elif model_type == 'SARIMAX':
64
+ p = st.slider('AR Order (p)', 0, 5, 0)
65
+ d = st.slider('Differencing (d)', 0, 2, 1)
66
+ q = st.slider('MA Order (q)', 0, 5, 0)
67
+ P = st.slider('Seasonal AR Order (P)', 0, 2, 0)
68
+ D = st.slider('Seasonal Differencing (D)', 0, 2, 1)
69
+ Q = st.slider('Seasonal MA Order (Q)', 0, 2, 0)
70
+ S = st.slider('Seasonal Period (S)', 1, 24, 12)
71
 
72
  city_data = data[data['City'] == selected_city]['Accident Count']
73
+
74
+ if model_type == 'ExpSmoothing':
75
+ forecast, aic = run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
76
+ elif model_type == 'SARIMAX':
77
+ order = (p, d, q)
78
+ seasonal_order = (P, D, Q, S)
79
+ forecast, aic = run_sarimax(city_data, order, seasonal_order)
80
 
81
  if forecast is not None:
82
  st.write(f"Best Parameters with AIC: {aic}")
83
+ st.write(f"Forecast:")
84
  forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
85
+ forecast_index = forecast_index.to_period('M')
86
  forecast_df = pd.DataFrame(forecast, columns=['Forecast'])
87
  forecast_df = forecast_df.round(0)
88
  st.table(forecast_df)
89
  fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
90
  st.plotly_chart(fig)
91
 
92
+ # Grid Search Logic for Both Models
93
  if st.button('Run Grid Search'):
94
  best_aic = float('inf')
95
  best_params = None
96
+
97
+ if model_type == 'ExpSmoothing':
98
+ for param_set in product(['add', 'mul', None], [False], ['add', 'mul', None], [12]):
99
+ _, temp_aic = run_exp_smoothing(city_data, *param_set)
100
+ if temp_aic and temp_aic < best_aic:
101
+ best_aic = temp_aic
102
+ best_params = param_set
103
+ elif model_type == 'SARIMAX':
104
+ for param_set in product(range(3), range(2), range(3), range(2), range(2), range(2), [12]):
105
+ _, temp_aic = run_sarimax(city_data, param_set[:3], param_set[3:])
106
+ if temp_aic and temp_aic < best_aic:
107
+ best_aic = temp_aic
108
+ best_params = param_set
109
+
110
+ st.session_state.best_params = best_params
111
  st.write(f"Best Parameters: {best_params} with AIC: {best_aic}")
112
 
 
 
113
  if st.button('Export to Excel'):
114
  df_to_export = forecast_df
115
  excel_data = to_excel(df_to_export)
116
+ st.download_button(label='πŸ“₯ Download Excel', data=excel_data, file_name='forecast.xlsx', mime='application/vnd.ms-excel')