XPMaster committed on
Commit
d6165ca
•
1 Parent(s): 6d819a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -78
app.py CHANGED
@@ -1,9 +1,8 @@
 
1
  import pandas as pd
2
  from io import BytesIO
3
- import streamlit as st
4
- from statsmodels.tsa.holtwinters import ExponentialSmoothing
5
- from statsmodels.tsa.statespace.sarimax import SARIMAX
6
  from itertools import product
 
7
  import plotly.express as px
8
 
9
  # Function to run the Exponential Smoothing Model
@@ -16,16 +15,6 @@ def run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
16
  st.error(f"An error occurred during model fitting: {e}")
17
  return None, None
18
 
19
- # Function to run SARIMAX Model
20
- def run_sarimax(city_data, order, seasonal_order):
21
- try:
22
- model = SARIMAX(city_data, order=order, seasonal_order=seasonal_order)
23
- model_fit = model.fit(disp=False)
24
- return model_fit.forecast(steps=6), model_fit.aic
25
- except Exception as e:
26
- st.error(f"An error occurred during SARIMAX model fitting: {e}")
27
- return None, None
28
-
29
  def create_data():
30
  data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
31
  data.index = pd.to_datetime(data.index, format='%Y%m')
@@ -35,6 +24,7 @@ def create_data():
35
  data.index = data.index.strftime('%Y-%m')
36
  return data
37
 
 
38
  def to_excel(df):
39
  output = BytesIO()
40
  writer = pd.ExcelWriter(output, engine='xlsxwriter')
@@ -43,74 +33,62 @@ def to_excel(df):
43
  processed_data = output.getvalue()
44
  return processed_data
45
 
 
46
  if 'best_params' not in st.session_state:
47
- st.session_state.best_params = {'trend': None, 'damped_trend': False, 'seasonal': None, 'seasonal_period': 12, 'model_type': 'ExpSmoothing'}
48
 
49
- st.title("Exponential Smoothing and SARIMAX Forecasting")
50
 
 
51
  data = create_data()
52
  unique_cities = data['City'].unique()
53
 
54
- selected_city = st.selectbox('Select a City', unique_cities)
55
-
56
- model_type = st.selectbox('Select Model Type', ['ExpSmoothing', 'SARIMAX'])
57
-
58
- if model_type == 'ExpSmoothing':
59
- trend = st.select_slider('Select Trend', options=['add', 'mul', None], value=st.session_state.best_params['trend'])
60
- damped_trend = st.checkbox('Damped Trend', value=st.session_state.best_params['damped_trend'])
61
- seasonal = st.select_slider('Select Seasonal', options=['add', 'mul', None], value=st.session_state.best_params['seasonal'])
62
- seasonal_period = st.slider('Seasonal Period', 1, 24, value=st.session_state.best_params['seasonal_period'])
63
- elif model_type == 'SARIMAX':
64
- p = st.slider('AR Order (p)', 0, 5, 0)
65
- d = st.slider('Differencing (d)', 0, 2, 1)
66
- q = st.slider('MA Order (q)', 0, 5, 0)
67
- P = st.slider('Seasonal AR Order (P)', 0, 2, 0)
68
- D = st.slider('Seasonal Differencing (D)', 0, 2, 1)
69
- Q = st.slider('Seasonal MA Order (Q)', 0, 2, 0)
70
- S = st.slider('Seasonal Period (S)', 1, 24, 12)
71
-
72
- city_data = data[data['City'] == selected_city]['Accident Count']
73
-
74
- if model_type == 'ExpSmoothing':
75
- forecast, aic = run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
76
- elif model_type == 'SARIMAX':
77
- order = (p, d, q)
78
- seasonal_order = (P, D, Q, S)
79
- forecast, aic = run_sarimax(city_data, order, seasonal_order)
80
-
81
- if forecast is not None:
82
- st.write(f"Best Parameters with AIC: {aic}")
83
- st.write(f"Forecast:")
84
- forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
85
- forecast_index = forecast_index.to_period('M')
86
- forecast_df = pd.DataFrame(forecast, columns=['Forecast'])
87
- forecast_df = forecast_df.round(0)
88
- st.table(forecast_df)
89
- fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
90
- st.plotly_chart(fig)
91
-
92
- # Grid Search Logic for Both Models
93
- if st.button('Run Grid Search'):
94
- best_aic = float('inf')
95
- best_params = None
96
-
97
- if model_type == 'ExpSmoothing':
98
- for param_set in product(['add', 'mul', None], [False], ['add', 'mul', None], [12]):
99
- _, temp_aic = run_exp_smoothing(city_data, *param_set)
100
- if temp_aic and temp_aic < best_aic:
101
- best_aic = temp_aic
102
- best_params = param_set
103
- elif model_type == 'SARIMAX':
104
- for param_set in product(range(3), range(2), range(3), range(2), range(2), range(2), [12]):
105
- _, temp_aic = run_sarimax(city_data, param_set[:3], param_set[3:])
106
- if temp_aic and temp_aic < best_aic:
107
- best_aic = temp_aic
108
- best_params = param_set
109
-
110
- st.session_state.best_params = best_params
111
- st.write(f"Best Parameters: {best_params} with AIC: {best_aic}")
112
-
113
- if st.button('Export to Excel'):
114
- df_to_export = forecast_df
115
- excel_data = to_excel(df_to_export)
116
- st.download_button(label='📥 Download Excel', data=excel_data, file_name='forecast.xlsx', mime='application/vnd.ms-excel')
 
1
+ import streamlit as st
2
  import pandas as pd
3
  from io import BytesIO
 
 
 
4
  from itertools import product
5
+ from statsmodels.tsa.holtwinters import ExponentialSmoothing
6
  import plotly.express as px
7
 
8
  # Function to run the Exponential Smoothing Model
 
15
  st.error(f"An error occurred during model fitting: {e}")
16
  return None, None
17
 
 
 
 
 
 
 
 
 
 
 
18
  def create_data():
19
  data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
20
  data.index = pd.to_datetime(data.index, format='%Y%m')
 
24
  data.index = data.index.strftime('%Y-%m')
25
  return data
26
 
27
+ # Function to convert DataFrame to Excel
28
  def to_excel(df):
29
  output = BytesIO()
30
  writer = pd.ExcelWriter(output, engine='xlsxwriter')
 
33
  processed_data = output.getvalue()
34
  return processed_data
35
 
36
+ # Initialize session state for best parameters
37
  if 'best_params' not in st.session_state:
38
+ st.session_state.best_params = {'trend': None, 'damped_trend': False, 'seasonal': None, 'seasonal_period': 12}
39
 
40
+ st.title("Exponential Smoothing Forecasting")
41
 
42
+ # Data preparation
43
  data = create_data()
44
  unique_cities = data['City'].unique()
45
 
46
+ # Creating tabs for each city
47
+ tab_dict = {city: st.tab(city) for city in unique_cities}
48
+ for city, tab in tab_dict.items():
49
+ with tab:
50
+ # Sliders for parameter adjustment, using session state values as defaults
51
+ trend = st.select_slider('Select Trend', options=['add', 'mul', None], value=st.session_state.best_params['trend'])
52
+ damped_trend = False
53
+ seasonal = st.select_slider('Select Seasonal', options=['add', 'mul', None], value=st.session_state.best_params['seasonal'])
54
+ seasonal_period = st.slider('Seasonal Period', 1, 24, value=st.session_state.best_params['seasonal_period'])
55
+
56
+ city_data = data[data['City'] == city]['Accident Count']
57
+ forecast, aic = run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
58
+
59
+ if forecast is not None:
60
+ st.write(f"Best Parameters with AIC: {aic}")
61
+ st.write(f"Trend: {trend}, Damped Trend: {damped_trend}, Seasonal: {seasonal}, Seasonal Period: {seasonal_period}")
62
+ forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
63
+ forecast_index = forecast_index.to_period('M') # Convert to period index with monthly frequency
64
+ forecast_df = pd.DataFrame(forecast, columns=['Forecast'])
65
+ forecast_df = forecast_df.round(0)
66
+ st.table(forecast_df)
67
+ fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
68
+ st.plotly_chart(fig)
69
+
70
+ # Grid search button
71
+ if st.button(f'Run Grid Search for {city}'):
72
+ best_aic = float('inf')
73
+ best_params = None
74
+ for param_set in product(['add', 'mul', None], [False], ['add', 'mul', None], [12]):
75
+ _, temp_aic = run_exp_smoothing(city_data, *param_set)
76
+ if temp_aic and temp_aic < best_aic:
77
+ best_aic = temp_aic
78
+ best_params = param_set
79
+
80
+ # Updating session state with the best parameters
81
+ st.session_state.best_params = {
82
+ 'trend': best_params[0],
83
+ 'damped_trend': best_params[1],
84
+ 'seasonal': best_params[2],
85
+ 'seasonal_period': best_params[3]
86
+ }
87
+ st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")
88
+
89
+
90
+ # Export to Excel button
91
+ if st.button(f'Export {city} to Excel'):
92
+ df_to_export = forecast_df
93
+ excel_data = to_excel(df_to_export)
94
+ st.download_button(label='📥 Download Excel', data=excel_data, file_name=f'{city}_forecast.xlsx', mime='application/vnd.ms-excel')