File size: 5,247 Bytes
d6165ca fde69ff 6d819a2 939b379 f27b577 939b379 24d2e79 939b379 fde69ff 939b379 fde69ff 8c7c72e fde69ff 939b379 f33264f 939b379 f33264f fde69ff d6165ca f33264f 939b379 fde69ff 939b379 fde69ff d6165ca fde69ff d6165ca 0b7e8df 5f049bf d6165ca 939b379 d6165ca 939b379 d6165ca 939b379 d6165ca 89e8e44 d6165ca 939b379 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import streamlit as st
import pandas as pd
from io import BytesIO
from itertools import product
from statsmodels.tsa.statespace.sarimax import SARIMAX
import plotly.express as px
# Configure the Streamlit page to use the "wide" layout for the whole app.
st.set_page_config(layout="wide")
def run_sarimax(city_data, order, seasonal_order, steps=6):
    """Fit a SARIMAX model to one city's series and forecast ahead.

    Stationarity and invertibility are not enforced so that a wider range of
    orders can be fitted (e.g. during a grid search).

    Args:
        city_data: Series of observed values to model.
        order: Non-seasonal ``(p, d, q)`` order.
        seasonal_order: Seasonal ``(P, D, Q, s)`` order.
        steps: Number of periods to forecast; defaults to 6.

    Returns:
        ``(forecast, aic)`` on success. ``(None, None)`` if building, fitting
        or forecasting fails; the error is reported in the app via
        ``st.error`` instead of being raised.
    """
    try:
        model = SARIMAX(
            city_data,
            order=order,
            seasonal_order=seasonal_order,
            enforce_stationarity=False,
            enforce_invertibility=False,
        )
        model_fit = model.fit(disp=False)
        forecast = model_fit.forecast(steps=steps)
        aic = model_fit.aic
    except Exception as e:
        # Many parameter combinations are simply unfittable; report and let the
        # caller skip this model rather than crash the app.
        st.error(f"An error occurred during model fitting: {e}")
        return None, None
    return forecast, aic
def create_data(path='accident_count.csv'):
    """Load accident records and aggregate them into monthly totals per city.

    The CSV's first column becomes the index and must hold months in
    ``YYYYMM`` form (e.g. ``202301``); a ``City`` column is required. Every
    other column is summed per city and calendar month. Months with no rows
    inside a city's observed range are filled with 0, so each city's series
    is evenly spaced.

    Args:
        path: CSV file to read; defaults to ``accident_count.csv``.

    Returns:
        DataFrame indexed by ``'YYYY-MM'`` strings (keeping the CSV's index
        name), with ``City`` as the first column followed by the summed
        columns, ordered by city and then month. An empty CSV yields an
        empty frame.
    """
    raw = pd.read_csv(path, index_col=0)
    if raw.empty:
        # Nothing to aggregate.
        return raw
    # Parse explicitly from strings so integer months such as 202301 match
    # the '%Y%m' format.
    raw.index = pd.to_datetime(raw.index.astype(str), format='%Y%m')
    value_cols = [col for col in raw.columns if col != 'City']
    # Select only the value columns so the grouping column is not summed
    # itself; 'MS' bins are calendar months, like the deprecated 'M' alias.
    monthly = raw.groupby('City')[value_cols].resample('MS').sum()
    # Move City back into a column; the remaining index level is the month.
    monthly = monthly.reset_index(level='City')
    monthly.index = monthly.index.strftime('%Y-%m')
    return monthly
def to_excel(df):
    """Serialize *df* to an in-memory ``.xlsx`` workbook and return its bytes.

    The frame, including its index, is written to a sheet named ``Sheet1``
    using the ``xlsxwriter`` engine, which must be installed.
    """
    output = BytesIO()
    # Closing the writer (done by the context manager) flushes the workbook
    # into ``output``; ``ExcelWriter.save()`` no longer exists in pandas 2.x.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
    return output.getvalue()
# Session-state key suffixes of the seven SARIMAX sliders, in
# (p, d, q, P, D, Q, S) order; a tab's slider keys are f'{city}{suffix}'.
_PARAM_SUFFIXES = ('p', 'd', 'q', 'P', 'D', 'Q', 'S')


def _grid_search(city_data, seasonal_period=12, max_order=2):
    """Try every (p, d, q)(P, D, Q, seasonal_period) with orders 0..max_order.

    Returns:
        ``(best_params, best_aic)`` where ``best_params`` is
        ``(order, seasonal_order)`` with the lowest AIC, or ``(None, inf)``
        if no combination could be fitted.
    """
    best_aic = float('inf')
    best_params = None
    for params in product(range(max_order + 1), repeat=6):
        order = params[:3]
        seasonal_order = params[3:] + (seasonal_period,)
        _, aic = run_sarimax(city_data, order, seasonal_order)
        # run_sarimax returns (None, None) for a failed fit; comparing None to
        # a float would raise TypeError, so skip failed fits explicitly.
        if aic is not None and aic < best_aic:
            best_aic = aic
            best_params = (order, seasonal_order)
    return best_params, best_aic


def _apply_grid_search(city, city_data):
    """``on_click`` callback: grid-search *city* and load the winner into its sliders.

    Runs before the script reruns, which is when Streamlit allows a widget's
    session-state key to be overwritten; assigning the keys after the sliders
    have been created in the same run is rejected.
    """
    best_params, best_aic = _grid_search(city_data)
    if best_params is not None:
        order, seasonal_order = best_params
        st.session_state.best_params = {'order': order, 'seasonal_order': seasonal_order}
        for suffix, value in zip(_PARAM_SUFFIXES, order + seasonal_order):
            st.session_state[f'{city}{suffix}'] = value
    # Shown (and cleared) by the city's tab on the rerun that follows the click.
    st.session_state[f'{city}_grid_result'] = (best_params, best_aic)


def _forecast_frame(forecast, last_month):
    """Tabulate *forecast* under the months following *last_month*.

    Args:
        forecast: Sequence of forecast values, one per month.
        last_month: Last observed month as a ``'YYYY-MM'`` string.

    Returns:
        DataFrame with a rounded ``Forecast`` column indexed by ``'YYYY-MM'``.
    """
    first = pd.Period(last_month, freq='M') + 1
    months = pd.period_range(start=first, periods=len(forecast), freq='M')
    # Build from plain values: pd.DataFrame(series, columns=['Forecast'])
    # selects the column by the Series' name and yields an empty frame
    # whenever that name differs.
    frame = pd.DataFrame({'Forecast': list(forecast)}, index=months.strftime('%Y-%m'))
    frame.index.name = 'Month'
    return frame.round(0)


# Initialize session state for best parameters
if 'best_params' not in st.session_state:
    st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}
st.title("SARIMAX Forecasting")
# Data preparation
data = create_data()
unique_cities = data['City'].unique()
if len(unique_cities) == 0:
    # st.tabs rejects an empty label list.
    st.warning("No accident data to display.")
else:
    # Creating tabs for each city (tab labels must be strings)
    tabs = st.tabs([str(city) for city in unique_cities])
    for tab, city in zip(tabs, unique_cities):
        with tab:
            # Seed each slider's session-state key once from the shared best
            # parameters; the sliders then read their value from that key, so
            # the grid-search callback can update them.
            defaults = (tuple(st.session_state.best_params['order'])
                        + tuple(st.session_state.best_params['seasonal_order']))
            for suffix, value in zip(_PARAM_SUFFIXES, defaults):
                slider_key = f'{city}{suffix}'
                if slider_key not in st.session_state:
                    st.session_state[slider_key] = value
            # SARIMAX specific sliders
            p = st.slider('AR Order (p)', 0, 5, key=f'{city}p')
            d = st.slider('Differencing Order (d)', 0, 2, key=f'{city}d')
            q = st.slider('MA Order (q)', 0, 5, key=f'{city}q')
            P = st.slider('Seasonal AR Order (P)', 0, 5, key=f'{city}P')
            D = st.slider('Seasonal Differencing Order (D)', 0, 2, key=f'{city}D')
            Q = st.slider('Seasonal MA Order (Q)', 0, 5, key=f'{city}Q')
            S = st.slider('Seasonal Period (S)', 1, 24, key=f'{city}S')

            city_data = data.loc[data['City'] == city, 'Accident Count']
            # Reset per city so a failed fit never exports another city's forecast.
            forecast_df = None
            forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))
            if forecast is not None:
                st.write(f"Model AIC: {aic}")
                st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")
                forecast_df = _forecast_frame(forecast, city_data.index[-1])
                st.table(forecast_df)
                fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
                st.plotly_chart(fig)

            # Grid search button: the search runs in the callback, before the rerun
            st.button(
                f'Run Grid Search for {city}',
                key=f'{city}_grid_search',
                on_click=_apply_grid_search,
                args=(city, city_data),
            )
            result_key = f'{city}_grid_result'
            if result_key in st.session_state:
                best_params, best_aic = st.session_state[result_key]
                del st.session_state[result_key]
                if best_params is None:
                    st.warning(f"Grid search for {city} found no model that could be fitted.")
                else:
                    st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")

            # Export to Excel button (only when this city has a forecast)
            if forecast_df is not None and st.button(f'Export {city} to Excel'):
                st.download_button(
                    label='📥 Download Excel',
                    data=to_excel(forecast_df),
                    file_name=f'{city}_forecast.xlsx',
                    mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
                    key=f'{city}_download',
                )
# Rest of your code
|