XPMaster committed on
Commit
8e0f382
•
1 Parent(s): 640cbf4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -135
app.py CHANGED
@@ -1,125 +1,125 @@
1
- # import streamlit as st
2
- # import pandas as pd
3
- # from io import BytesIO
4
- # from itertools import product
5
- # from statsmodels.tsa.statespace.sarimax import SARIMAX
6
- # import plotly.express as px
7
 
8
- # st.set_page_config(layout="wide")
9
 
10
- # # Function to run the SARIMAX Model
11
- # def run_sarimax(city_data, order, seasonal_order):
12
- # try:
13
- # # Check if the data is non-empty and in the correct format
14
- # if city_data.empty:
15
- # st.error(f"No data available for modeling.")
16
- # return None, None
17
 
18
- # model = SARIMAX(city_data, order=order, seasonal_order=seasonal_order, enforce_stationarity=False, enforce_invertibility=False)
19
- # model_fit = model.fit(disp=False)
20
- # forecast = model_fit.forecast(steps=6)
21
 
22
- # # Check if the forecast is valid
23
- # if forecast is None or forecast.empty:
24
- # st.error(f"Forecast failed, the model returned an empty forecast.")
25
- # return None, None
26
 
27
- # return forecast, model_fit.aic
28
- # except Exception as e:
29
- # st.error(f"An error occurred during model fitting: {e}")
30
- # return None, None
31
-
32
- # def create_data():
33
- # # Assuming you have a CSV file named 'accident_count.csv' with 'City' and 'Accident Count' columns
34
- # data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
35
- # data.index = pd.to_datetime(data.index, format='%Y%m')
36
- # data = data.groupby('City').resample('M').sum().reset_index()
37
- # data.index = data['Accident Month Bracket']
38
- # data = data.drop(['Accident Month Bracket'], axis=1)
39
- # data.index = data.index.strftime('%Y-%m')
40
- # return data
41
-
42
- # def to_excel(df):
43
- # output = BytesIO()
44
- # writer = pd.ExcelWriter(output, engine='xlsxwriter')
45
- # df.to_excel(writer, sheet_name='Sheet1')
46
- # writer.save()
47
- # processed_data = output.getvalue()
48
- # return processed_data
49
-
50
- # # Initialize session state for best parameters
51
- # if 'best_params' not in st.session_state:
52
- # st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}
53
-
54
- # st.title("SARIMAX Forecasting")
55
-
56
- # # Data preparation
57
- # data = create_data()
58
- # unique_cities = data['City'].unique()
59
-
60
- # # Creating tabs for each city
61
- # tabs = st.tabs([city for city in unique_cities])
62
-
63
- # for tab, city in zip(tabs, unique_cities):
64
- # with tab:
65
- # # SARIMAX specific sliders
66
- # p = st.slider('AR Order (p)', 0, 5, value=st.session_state.best_params['order'][0], key=city+'p')
67
- # d = st.slider('Differencing Order (d)', 0, 2, value=st.session_state.best_params['order'][1], key=city+'d')
68
- # q = st.slider('MA Order (q)', 0, 5, value=st.session_state.best_params['order'][2], key=city+'q')
69
- # P = st.slider('Seasonal AR Order (P)', 0, 5, value=st.session_state.best_params['seasonal_order'][0], key=city+'P')
70
- # D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=st.session_state.best_params['seasonal_order'][1], key=city+'D')
71
- # Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=st.session_state.best_params['seasonal_order'][2], key=city+'Q')
72
- # S = st.slider('Seasonal Period (S)', 1, 24, value=st.session_state.best_params['seasonal_order'][3], key=city+'S')
73
-
74
- # city_data = data[data['City'] == city]['Accident Count']
75
- # forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))
76
-
77
- # if forecast is not None:
78
- # st.write(f"Best Parameters with AIC: {aic}")
79
- # st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")
80
- # forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
81
- # forecast_index = forecast_index.to_period('M') # Convert to period index with monthly frequency
82
- # forecast_df = pd.DataFrame(forecast, columns=['predicted_mean'])
83
- # forecast_df = forecast_df.round(0)
84
- # st.table(forecast_df)
85
- # fig = px.line(forecast_df, x=forecast_df.index, y="predicted_mean")
86
- # st.plotly_chart(fig)
87
-
88
- # # Grid search button
89
- # if st.button(f'Run Grid Search for {city}'):
90
- # best_aic = float('inf')
91
- # best_params = None
92
- # # Define the range for each parameter
93
- # p_range = d_range = q_range = range(3)
94
- # P_range = D_range = Q_range = range(3)
95
- # S = 12 # Assuming a fixed seasonal period, adjust as needed
96
 
97
- # # Perform the grid search
98
- # for params in product(p_range, d_range, q_range, P_range, D_range, Q_range):
99
- # order = params[:3]
100
- # seasonal_order = params[3:] + (S,)
101
- # try:
102
- # _, temp_aic = run_sarimax(city_data, order, seasonal_order)
103
- # if temp_aic < best_aic:
104
- # best_aic = temp_aic
105
- # best_params = (order, seasonal_order)
106
- # except Exception as e:
107
- # st.error(f"An error occurred for parameters {params}: {e}")
108
 
109
- # # Update the session state with the best parameters
110
- # if best_params is not None:
111
- # st.session_state.best_params = {
112
- # 'order': best_params[0],
113
- # 'seasonal_order': best_params[1]
114
- # }
115
- # st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")
116
 
117
 
118
- # # Export to Excel button
119
- # if st.button(f'Export {city} to Excel'):
120
- # df_to_export = forecast_df
121
- # excel_data = to_excel(df_to_export)
122
- # st.download_button(label='📥 Download Excel', data=excel_data, file_name=f'{city}_forecast.xlsx', mime='application/vnd.ms-excel')
123
 
124
 
125
 
@@ -190,26 +190,26 @@
190
 
191
 
192
 
193
- import plotly.express as px
194
- import streamlit as st
195
- from streamlit_plotly_events import plotly_events
196
-
197
- # Sample data
198
- df = px.data.gapminder().query("country=='Canada'")
199
- fig = px.line(df, x="year", y="lifeExp", title='Life Expectancy in Canada Over Years')
200
-
201
- # Capture the selected points
202
- selected_points = plotly_events(fig, click_event=True)
203
-
204
- # Handle the click event
205
- if selected_points:
206
- st.write("You clicked on:", selected_points)
207
- point_index = selected_points[0]['pointIndex']
208
- new_value = st.number_input('Enter new value for life expectancy', value=df.iloc[point_index]['lifeExp'])
209
- if st.button('Update Data'):
210
- df.at[point_index, 'lifeExp'] = new_value
211
- fig = px.line(df, x="year", y="lifeExp", title='Life Expectancy in Canada Over Years')
212
- st.plotly_chart(fig)
213
- else:
214
- st.plotly_chart(fig)
215
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from io import BytesIO
4
+ from itertools import product
5
+ from statsmodels.tsa.statespace.sarimax import SARIMAX
6
+ import plotly.express as px
7
 
8
+ st.set_page_config(layout="wide")
9
 
10
+ # Function to run the SARIMAX Model
11
+ def run_sarimax(city_data, order, seasonal_order):
12
+ try:
13
+ # Check if the data is non-empty and in the correct format
14
+ if city_data.empty:
15
+ st.error(f"No data available for modeling.")
16
+ return None, None
17
 
18
+ model = SARIMAX(city_data, order=order, seasonal_order=seasonal_order, enforce_stationarity=False, enforce_invertibility=False)
19
+ model_fit = model.fit(disp=False)
20
+ forecast = model_fit.forecast(steps=6)
21
 
22
+ # Check if the forecast is valid
23
+ if forecast is None or forecast.empty:
24
+ st.error(f"Forecast failed, the model returned an empty forecast.")
25
+ return None, None
26
 
27
+ return forecast, model_fit.aic
28
+ except Exception as e:
29
+ st.error(f"An error occurred during model fitting: {e}")
30
+ return None, None
31
+
32
+ def create_data():
33
+ # Assuming you have a CSV file named 'accident_count.csv' with 'City' and 'Accident Count' columns
34
+ data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
35
+ data.index = pd.to_datetime(data.index, format='%Y%m')
36
+ data = data.groupby('City').resample('M').sum().reset_index()
37
+ data.index = data['Accident Month Bracket']
38
+ data = data.drop(['Accident Month Bracket'], axis=1)
39
+ data.index = data.index.strftime('%Y-%m')
40
+ return data
41
+
42
+ def to_excel(df):
43
+ output = BytesIO()
44
+ writer = pd.ExcelWriter(output, engine='xlsxwriter')
45
+ df.to_excel(writer, sheet_name='Sheet1')
46
+ writer.save()
47
+ processed_data = output.getvalue()
48
+ return processed_data
49
+
50
+ # Initialize session state for best parameters
51
+ if 'best_params' not in st.session_state:
52
+ st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}
53
+
54
+ st.title("SARIMAX Forecasting")
55
+
56
+ # Data preparation
57
+ data = create_data()
58
+ unique_cities = data['City'].unique()
59
+
60
+ # Creating tabs for each city
61
+ tabs = st.tabs([city for city in unique_cities])
62
+
63
+ for tab, city in zip(tabs, unique_cities):
64
+ with tab:
65
+ # SARIMAX specific sliders
66
+ p = st.slider('AR Order (p)', 0, 5, value=st.session_state.best_params['order'][0], key=city+'p')
67
+ d = st.slider('Differencing Order (d)', 0, 2, value=st.session_state.best_params['order'][1], key=city+'d')
68
+ q = st.slider('MA Order (q)', 0, 5, value=st.session_state.best_params['order'][2], key=city+'q')
69
+ P = st.slider('Seasonal AR Order (P)', 0, 5, value=st.session_state.best_params['seasonal_order'][0], key=city+'P')
70
+ D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=st.session_state.best_params['seasonal_order'][1], key=city+'D')
71
+ Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=st.session_state.best_params['seasonal_order'][2], key=city+'Q')
72
+ S = st.slider('Seasonal Period (S)', 1, 24, value=st.session_state.best_params['seasonal_order'][3], key=city+'S')
73
+
74
+ city_data = data[data['City'] == city]['Accident Count']
75
+ forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))
76
+
77
+ if forecast is not None:
78
+ st.write(f"Best Parameters with AIC: {aic}")
79
+ st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")
80
+ forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
81
+ forecast_index = forecast_index.to_period('M') # Convert to period index with monthly frequency
82
+ forecast_df = pd.DataFrame(forecast, columns=['predicted_mean'])
83
+ forecast_df = forecast_df.round(0)
84
+ st.table(forecast_df)
85
+ fig = px.line(forecast_df, x=forecast_df.index, y="predicted_mean")
86
+ st.plotly_chart(fig)
87
+
88
+ # Grid search button
89
+ if st.button(f'Run Grid Search for {city}'):
90
+ best_aic = float('inf')
91
+ best_params = None
92
+ # Define the range for each parameter
93
+ p_range = d_range = q_range = range(3)
94
+ P_range = D_range = Q_range = range(3)
95
+ S = 12 # Assuming a fixed seasonal period, adjust as needed
96
 
97
+ # Perform the grid search
98
+ for params in product(p_range, d_range, q_range, P_range, D_range, Q_range):
99
+ order = params[:3]
100
+ seasonal_order = params[3:] + (S,)
101
+ try:
102
+ _, temp_aic = run_sarimax(city_data, order, seasonal_order)
103
+ if temp_aic < best_aic:
104
+ best_aic = temp_aic
105
+ best_params = (order, seasonal_order)
106
+ except Exception as e:
107
+ st.error(f"An error occurred for parameters {params}: {e}")
108
 
109
+ # Update the session state with the best parameters
110
+ if best_params is not None:
111
+ st.session_state.best_params = {
112
+ 'order': best_params[0],
113
+ 'seasonal_order': best_params[1]
114
+ }
115
+ st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")
116
 
117
 
118
+ # Export to Excel button
119
+ if st.button(f'Export {city} to Excel'):
120
+ df_to_export = forecast_df
121
+ excel_data = to_excel(df_to_export)
122
+ st.download_button(label='📥 Download Excel', data=excel_data, file_name=f'{city}_forecast.xlsx', mime='application/vnd.ms-excel')
123
 
124
 
125
 
 
190
 
191
 
192
 
193
+ # import plotly.express as px
194
+ # import streamlit as st
195
+ # from streamlit_plotly_events import plotly_events
196
+
197
+ # # Sample data
198
+ # df = px.data.gapminder().query("country=='Canada'")
199
+ # fig = px.line(df, x="year", y="lifeExp", title='Life Expectancy in Canada Over Years')
200
+
201
+ # # Capture the selected points
202
+ # selected_points = plotly_events(fig, click_event=True)
203
+
204
+ # # Handle the click event
205
+ # if selected_points:
206
+ # st.write("You clicked on:", selected_points)
207
+ # point_index = selected_points[0]['pointIndex']
208
+ # new_value = st.number_input('Enter new value for life expectancy', value=df.iloc[point_index]['lifeExp'])
209
+ # if st.button('Update Data'):
210
+ # df.at[point_index, 'lifeExp'] = new_value
211
+ # fig = px.line(df, x="year", y="lifeExp", title='Life Expectancy in Canada Over Years')
212
+ # st.plotly_chart(fig)
213
+ # else:
214
+ # st.plotly_chart(fig)
215