XPMaster committed on
Commit
c141f0a
1 Parent(s): 8e0f382

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -116
app.py CHANGED
@@ -1,125 +1,125 @@
1
- import streamlit as st
2
- import pandas as pd
3
- from io import BytesIO
4
- from itertools import product
5
- from statsmodels.tsa.statespace.sarimax import SARIMAX
6
- import plotly.express as px
7
-
8
- st.set_page_config(layout="wide")
9
-
10
- # Function to run the SARIMAX Model
11
- def run_sarimax(city_data, order, seasonal_order):
12
- try:
13
- # Check if the data is non-empty and in the correct format
14
- if city_data.empty:
15
- st.error(f"No data available for modeling.")
16
- return None, None
17
-
18
- model = SARIMAX(city_data, order=order, seasonal_order=seasonal_order, enforce_stationarity=False, enforce_invertibility=False)
19
- model_fit = model.fit(disp=False)
20
- forecast = model_fit.forecast(steps=6)
21
 
22
- # Check if the forecast is valid
23
- if forecast is None or forecast.empty:
24
- st.error(f"Forecast failed, the model returned an empty forecast.")
25
- return None, None
26
 
27
- return forecast, model_fit.aic
28
- except Exception as e:
29
- st.error(f"An error occurred during model fitting: {e}")
30
- return None, None
31
-
32
- def create_data():
33
- # Assuming you have a CSV file named 'accident_count.csv' with 'City' and 'Accident Count' columns
34
- data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
35
- data.index = pd.to_datetime(data.index, format='%Y%m')
36
- data = data.groupby('City').resample('M').sum().reset_index()
37
- data.index = data['Accident Month Bracket']
38
- data = data.drop(['Accident Month Bracket'], axis=1)
39
- data.index = data.index.strftime('%Y-%m')
40
- return data
41
-
42
- def to_excel(df):
43
- output = BytesIO()
44
- writer = pd.ExcelWriter(output, engine='xlsxwriter')
45
- df.to_excel(writer, sheet_name='Sheet1')
46
- writer.save()
47
- processed_data = output.getvalue()
48
- return processed_data
49
-
50
- # Initialize session state for best parameters
51
- if 'best_params' not in st.session_state:
52
- st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}
53
-
54
- st.title("SARIMAX Forecasting")
55
-
56
- # Data preparation
57
- data = create_data()
58
- unique_cities = data['City'].unique()
59
-
60
- # Creating tabs for each city
61
- tabs = st.tabs([city for city in unique_cities])
62
-
63
- for tab, city in zip(tabs, unique_cities):
64
- with tab:
65
- # SARIMAX specific sliders
66
- p = st.slider('AR Order (p)', 0, 5, value=st.session_state.best_params['order'][0], key=city+'p')
67
- d = st.slider('Differencing Order (d)', 0, 2, value=st.session_state.best_params['order'][1], key=city+'d')
68
- q = st.slider('MA Order (q)', 0, 5, value=st.session_state.best_params['order'][2], key=city+'q')
69
- P = st.slider('Seasonal AR Order (P)', 0, 5, value=st.session_state.best_params['seasonal_order'][0], key=city+'P')
70
- D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=st.session_state.best_params['seasonal_order'][1], key=city+'D')
71
- Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=st.session_state.best_params['seasonal_order'][2], key=city+'Q')
72
- S = st.slider('Seasonal Period (S)', 1, 24, value=st.session_state.best_params['seasonal_order'][3], key=city+'S')
73
-
74
- city_data = data[data['City'] == city]['Accident Count']
75
- forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))
76
-
77
- if forecast is not None:
78
- st.write(f"Best Parameters with AIC: {aic}")
79
- st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")
80
- forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
81
- forecast_index = forecast_index.to_period('M') # Convert to period index with monthly frequency
82
- forecast_df = pd.DataFrame(forecast, columns=['predicted_mean'])
83
- forecast_df = forecast_df.round(0)
84
- st.table(forecast_df)
85
- fig = px.line(forecast_df, x=forecast_df.index, y="predicted_mean")
86
- st.plotly_chart(fig)
87
-
88
- # Grid search button
89
- if st.button(f'Run Grid Search for {city}'):
90
- best_aic = float('inf')
91
- best_params = None
92
- # Define the range for each parameter
93
- p_range = d_range = q_range = range(3)
94
- P_range = D_range = Q_range = range(3)
95
- S = 12 # Assuming a fixed seasonal period, adjust as needed
96
 
97
- # Perform the grid search
98
- for params in product(p_range, d_range, q_range, P_range, D_range, Q_range):
99
- order = params[:3]
100
- seasonal_order = params[3:] + (S,)
101
- try:
102
- _, temp_aic = run_sarimax(city_data, order, seasonal_order)
103
- if temp_aic < best_aic:
104
- best_aic = temp_aic
105
- best_params = (order, seasonal_order)
106
- except Exception as e:
107
- st.error(f"An error occurred for parameters {params}: {e}")
108
 
109
- # Update the session state with the best parameters
110
- if best_params is not None:
111
- st.session_state.best_params = {
112
- 'order': best_params[0],
113
- 'seasonal_order': best_params[1]
114
- }
115
- st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")
116
 
117
 
118
- # Export to Excel button
119
- if st.button(f'Export {city} to Excel'):
120
- df_to_export = forecast_df
121
- excel_data = to_excel(df_to_export)
122
- st.download_button(label='📥 Download Excel', data=excel_data, file_name=f'{city}_forecast.xlsx', mime='application/vnd.ms-excel')
123
 
124
 
125
 
@@ -133,6 +133,33 @@ for tab, city in zip(tabs, unique_cities):
133
 
134
 
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
 
138
 
 
1
+ # import streamlit as st
2
+ # import pandas as pd
3
+ # from io import BytesIO
4
+ # from itertools import product
5
+ # from statsmodels.tsa.statespace.sarimax import SARIMAX
6
+ # import plotly.express as px
7
+
8
+ # st.set_page_config(layout="wide")
9
+
10
+ # # Function to run the SARIMAX Model
11
+ # def run_sarimax(city_data, order, seasonal_order):
12
+ # try:
13
+ # # Check if the data is non-empty and in the correct format
14
+ # if city_data.empty:
15
+ # st.error(f"No data available for modeling.")
16
+ # return None, None
17
+
18
+ # model = SARIMAX(city_data, order=order, seasonal_order=seasonal_order, enforce_stationarity=False, enforce_invertibility=False)
19
+ # model_fit = model.fit(disp=False)
20
+ # forecast = model_fit.forecast(steps=6)
21
 
22
+ # # Check if the forecast is valid
23
+ # if forecast is None or forecast.empty:
24
+ # st.error(f"Forecast failed, the model returned an empty forecast.")
25
+ # return None, None
26
 
27
+ # return forecast, model_fit.aic
28
+ # except Exception as e:
29
+ # st.error(f"An error occurred during model fitting: {e}")
30
+ # return None, None
31
+
32
+ # def create_data():
33
+ # # Assuming you have a CSV file named 'accident_count.csv' with 'City' and 'Accident Count' columns
34
+ # data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
35
+ # data.index = pd.to_datetime(data.index, format='%Y%m')
36
+ # data = data.groupby('City').resample('M').sum().reset_index()
37
+ # data.index = data['Accident Month Bracket']
38
+ # data = data.drop(['Accident Month Bracket'], axis=1)
39
+ # data.index = data.index.strftime('%Y-%m')
40
+ # return data
41
+
42
+ # def to_excel(df):
43
+ # output = BytesIO()
44
+ # writer = pd.ExcelWriter(output, engine='xlsxwriter')
45
+ # df.to_excel(writer, sheet_name='Sheet1')
46
+ # writer.save()
47
+ # processed_data = output.getvalue()
48
+ # return processed_data
49
+
50
+ # # Initialize session state for best parameters
51
+ # if 'best_params' not in st.session_state:
52
+ # st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}
53
+
54
+ # st.title("SARIMAX Forecasting")
55
+
56
+ # # Data preparation
57
+ # data = create_data()
58
+ # unique_cities = data['City'].unique()
59
+
60
+ # # Creating tabs for each city
61
+ # tabs = st.tabs([city for city in unique_cities])
62
+
63
+ # for tab, city in zip(tabs, unique_cities):
64
+ # with tab:
65
+ # # SARIMAX specific sliders
66
+ # p = st.slider('AR Order (p)', 0, 5, value=st.session_state.best_params['order'][0], key=city+'p')
67
+ # d = st.slider('Differencing Order (d)', 0, 2, value=st.session_state.best_params['order'][1], key=city+'d')
68
+ # q = st.slider('MA Order (q)', 0, 5, value=st.session_state.best_params['order'][2], key=city+'q')
69
+ # P = st.slider('Seasonal AR Order (P)', 0, 5, value=st.session_state.best_params['seasonal_order'][0], key=city+'P')
70
+ # D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=st.session_state.best_params['seasonal_order'][1], key=city+'D')
71
+ # Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=st.session_state.best_params['seasonal_order'][2], key=city+'Q')
72
+ # S = st.slider('Seasonal Period (S)', 1, 24, value=st.session_state.best_params['seasonal_order'][3], key=city+'S')
73
+
74
+ # city_data = data[data['City'] == city]['Accident Count']
75
+ # forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))
76
+
77
+ # if forecast is not None:
78
+ # st.write(f"Best Parameters with AIC: {aic}")
79
+ # st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")
80
+ # forecast_index = pd.date_range(start=city_data.index[-1], periods=7, freq='M')[1:]
81
+ # forecast_index = forecast_index.to_period('M') # Convert to period index with monthly frequency
82
+ # forecast_df = pd.DataFrame(forecast, columns=['predicted_mean'])
83
+ # forecast_df = forecast_df.round(0)
84
+ # st.table(forecast_df)
85
+ # fig = px.line(forecast_df, x=forecast_df.index, y="predicted_mean")
86
+ # st.plotly_chart(fig)
87
+
88
+ # # Grid search button
89
+ # if st.button(f'Run Grid Search for {city}'):
90
+ # best_aic = float('inf')
91
+ # best_params = None
92
+ # # Define the range for each parameter
93
+ # p_range = d_range = q_range = range(3)
94
+ # P_range = D_range = Q_range = range(3)
95
+ # S = 12 # Assuming a fixed seasonal period, adjust as needed
96
 
97
+ # # Perform the grid search
98
+ # for params in product(p_range, d_range, q_range, P_range, D_range, Q_range):
99
+ # order = params[:3]
100
+ # seasonal_order = params[3:] + (S,)
101
+ # try:
102
+ # _, temp_aic = run_sarimax(city_data, order, seasonal_order)
103
+ # if temp_aic < best_aic:
104
+ # best_aic = temp_aic
105
+ # best_params = (order, seasonal_order)
106
+ # except Exception as e:
107
+ # st.error(f"An error occurred for parameters {params}: {e}")
108
 
109
+ # # Update the session state with the best parameters
110
+ # if best_params is not None:
111
+ # st.session_state.best_params = {
112
+ # 'order': best_params[0],
113
+ # 'seasonal_order': best_params[1]
114
+ # }
115
+ # st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")
116
 
117
 
118
+ # # Export to Excel button
119
+ # if st.button(f'Export {city} to Excel'):
120
+ # df_to_export = forecast_df
121
+ # excel_data = to_excel(df_to_export)
122
+ # st.download_button(label='📥 Download Excel', data=excel_data, file_name=f'{city}_forecast.xlsx', mime='application/vnd.ms-excel')
123
 
124
 
125
 
 
133
 
134
 
135
 
136
+ import streamlit as st
137
+
138
def deletefile(filename, dfname):
    """Placeholder for the file-deletion logic; currently does nothing.

    Both arguments are accepted and ignored, and the function returns
    ``None``.
    """
    # Your delete file logic here
    return None
141
+
142
def spawnbutton(filename, dfname):
    """Render a one-shot "Delete file" button for a session-state entry.

    The button is drawn only while ``st.session_state[f"{filename}_clicked"]``
    is missing or falsy. When the button is pressed, ``deletefile`` is called,
    ``statecheck`` is called for ``dfname`` and then ``filename``, and the
    clicked flag is stored so that later runs skip the button.

    Args:
        filename: Session-state key of the file entry. The stored value may
            expose a ``.name`` attribute or be a plain value such as a string.
        dfname: Session-state key of the associated dataframe entry; passed
            through to ``deletefile`` and ``statecheck``.

    Returns:
        None.
    """
    clicked_key = f"{filename}_clicked"

    # Guard clause: once the flag is set there is nothing left to render.
    if st.session_state.get(clicked_key, False):
        return

    entry = st.session_state[filename]
    # The example usage stores a plain string under this key, and ``str`` has
    # no ``.name``; fall back to the value itself instead of raising
    # AttributeError while building the label.
    display_name = getattr(entry, "name", entry)

    pressed = st.button(
        f"Delete file ({display_name})",
        use_container_width=True,
        key=f"{filename}_deleter",
    )
    if pressed:
        deletefile(filename, dfname)
        # NOTE(review): ``statecheck`` is not defined in the visible part of
        # this file -- confirm it is in scope before relying on this path.
        statecheck(dfname)
        statecheck(filename)
        # Record the click only after the delete/statecheck calls succeeded.
        st.session_state[clicked_key] = True
155
+
156
# Example usage: seed placeholder entries once per session, then draw the
# delete button for them.
for _state_key, _placeholder in (
    ('example_file', "file.txt"),
    ('example_df', "dataframe"),
):
    # Only seed keys that are not already present in session state.
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _placeholder

spawnbutton('example_file', 'example_df')
163
 
164
 
165