|
import streamlit as st |
|
import pandas as pd |
|
from io import BytesIO |
|
from itertools import product |
|
from statsmodels.tsa.statespace.sarimax import SARIMAX |
|
import plotly.express as px |
|
|
|
# Use Streamlit's wide page layout so the per-city tabs, tables and charts get
# the full browser width. NOTE(review): kept ahead of every other st.* call,
# which Streamlit's docs ask of set_page_config — confirm for the installed version.
st.set_page_config(layout="wide")
|
|
|
|
|
def run_sarimax(city_data, order, seasonal_order, steps=6):
    """Fit a SARIMAX model to *city_data* and forecast *steps* periods ahead.

    Problems are reported in the Streamlit page with ``st.error`` instead of
    being raised, so one bad parameter combination does not stop the app.

    Args:
        city_data: Observed series to model (in this app, one city's
            'Accident Count' column). pandas objects and NumPy arrays work.
        order: Non-seasonal ``(p, d, q)`` order passed to SARIMAX.
        seasonal_order: Seasonal ``(P, D, Q, S)`` order passed to SARIMAX.
        steps: Number of periods to forecast. Defaults to 6.

    Returns:
        ``(forecast, aic)`` on success. ``(None, None)`` when the input is
        empty, the forecast comes back empty, or fitting/forecasting raises.
    """
    try:
        # len() rather than `.empty`: NumPy arrays have no `.empty`
        # attribute, and statsmodels returns an ndarray forecast for them.
        if len(city_data) == 0:
            st.error("No data available for modeling.")
            return None, None

        # Stationarity/invertibility constraints are switched off, presumably
        # so arbitrary slider / grid-search orders can still be fitted.
        model = SARIMAX(
            city_data,
            order=order,
            seasonal_order=seasonal_order,
            enforce_stationarity=False,
            enforce_invertibility=False,
        )
        model_fit = model.fit(disp=False)
        forecast = model_fit.forecast(steps=steps)

        if forecast is None or len(forecast) == 0:
            st.error("Forecast failed, the model returned an empty forecast.")
            return None, None

        return forecast, model_fit.aic
    except Exception as e:
        # UI boundary: show the failure and let callers treat it like any
        # other unsuccessful fit via the (None, None) sentinel.
        st.error(f"An error occurred during model fitting: {e}")
        return None, None
|
|
|
def _month_end_freq():
    """Return the month-end frequency alias understood by the installed pandas.

    pandas 2.2 renamed the month-end alias from 'M' to 'ME' and deprecated
    'M'. Older versions only accept 'M'.
    """
    try:
        pd.date_range('2000-01-31', periods=1, freq='ME')
    except ValueError:
        return 'M'
    return 'ME'


def create_data(path='accident_count.csv'):
    """Load accident counts and aggregate them into monthly totals per city.

    The CSV's first column becomes the index and is parsed as ``YYYYMM``
    month stamps. It must be named 'Accident Month Bracket'. The file must
    also contain a 'City' column. Every other column is summed per city and
    calendar month. Months inside a city's observed range that have no rows
    sum to 0.

    Args:
        path: CSV file to read. Defaults to 'accident_count.csv' in the
            working directory.

    Returns:
        DataFrame with a 'City' column plus the summed columns, indexed by
        'YYYY-MM' strings. There is one row per city per month, so index
        labels repeat across cities.
    """
    data = pd.read_csv(path, parse_dates=True, index_col=0)
    data.index = pd.to_datetime(data.index, format='%Y%m')

    # Select the value columns explicitly. On pandas 2.x, groupby().resample()
    # otherwise also aggregates the 'City' grouping column (concatenating the
    # names), and reset_index() then fails on the duplicate 'City' column.
    value_columns = [col for col in data.columns if col != 'City']
    monthly = (
        data.groupby('City')[value_columns]
        .resample(_month_end_freq())
        .sum()
        .reset_index()
        .set_index('Accident Month Bracket')
    )
    monthly.index = monthly.index.strftime('%Y-%m')
    return monthly
|
|
|
def to_excel(df, sheet_name='Sheet1'):
    """Serialize *df* (index included) to an in-memory .xlsx workbook.

    Args:
        df: DataFrame to write.
        sheet_name: Worksheet name. Defaults to 'Sheet1'.

    Returns:
        The workbook file contents as bytes, ready for a download button.
    """
    output = BytesIO()
    # ExcelWriter.save() was removed in pandas 2.0. Leaving the `with` block
    # closes the writer, which flushes the workbook into `output`.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name=sheet_name)
    return output.getvalue()
|
|
|
|
|
# Orders used to seed a city's sliders the first time they are drawn. A grid
# search overwrites this with its winning orders.
if 'best_params' not in st.session_state:
    st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}

# (session-state key suffix, label, min, max) for each SARIMAX order slider,
# in the same order as best_params['order'] + best_params['seasonal_order'].
_ORDER_SLIDERS = (
    ('p', 'AR Order (p)', 0, 5),
    ('d', 'Differencing Order (d)', 0, 2),
    ('q', 'MA Order (q)', 0, 5),
    ('P', 'Seasonal AR Order (P)', 0, 5),
    ('D', 'Seasonal Differencing Order (D)', 0, 2),
    ('Q', 'Seasonal MA Order (Q)', 0, 5),
    ('S', 'Seasonal Period (S)', 1, 24),
)

# MIME type for .xlsx workbooks ('application/vnd.ms-excel' is the legacy
# .xls type).
_XLSX_MIME = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'


def _render_order_sliders(city):
    """Draw the seven order sliders for *city*; return (order, seasonal_order)."""
    best = st.session_state.best_params
    defaults = tuple(best['order']) + tuple(best['seasonal_order'])
    values = []
    for (suffix, label, low, high), default in zip(_ORDER_SLIDERS, defaults):
        key = city + suffix
        # Seed the widget's session state once instead of passing value=.
        # The grid-search callback can then move the slider by writing the
        # same key.
        if key not in st.session_state:
            st.session_state[key] = default
        values.append(st.slider(label, low, high, key=key))
    return tuple(values[:3]), tuple(values[3:])


def _show_forecast(city_data, order, seasonal_order):
    """Fit the chosen orders and render AIC, table and chart.

    Returns the forecast DataFrame, or None when no forecast was produced.
    """
    forecast, aic = run_sarimax(city_data, order, seasonal_order)
    if forecast is None:
        return None

    st.write(f"Model AIC: {aic}")
    st.write(f"Non-Seasonal Order: {order}, Seasonal Order: {seasonal_order}")

    # city_data is indexed by 'YYYY-MM' labels (see create_data). Label each
    # forecast step with the month following the last observation.
    months = pd.period_range(start=city_data.index[-1], periods=len(forecast) + 1, freq='M')[1:]
    # list() drops the forecast's own index so values are not realigned.
    forecast_df = pd.DataFrame(
        {'predicted_mean': list(forecast)},
        index=months.strftime('%Y-%m'),
    ).round(0)

    st.table(forecast_df)
    fig = px.line(forecast_df, x=forecast_df.index, y="predicted_mean")
    st.plotly_chart(fig)
    return forecast_df


def _run_grid_search(city, city_data):
    """Button callback: fit every order in {0, 1, 2}^6 with S=12, keep the lowest AIC.

    Callbacks run before the next script pass, which is when a widget's
    session-state value may be overwritten. So this city's sliders move
    straight to the winning orders.
    """
    best_aic = float('inf')
    best_params = None
    S = 12

    for params in product(range(3), repeat=6):
        order = params[:3]
        seasonal_order = params[3:] + (S,)
        _, temp_aic = run_sarimax(city_data, order, seasonal_order)
        # run_sarimax returns (None, None) on failure, and comparing None
        # with a float would raise TypeError.
        if temp_aic is not None and temp_aic < best_aic:
            best_aic = temp_aic
            best_params = (order, seasonal_order)

    st.session_state[f'{city}_grid_result'] = (best_params, best_aic)
    if best_params is None:
        return

    st.session_state.best_params = {
        'order': best_params[0],
        'seasonal_order': best_params[1],
    }
    for (suffix, _label, _low, _high), value in zip(_ORDER_SLIDERS, best_params[0] + best_params[1]):
        st.session_state[city + suffix] = value


st.title("SARIMAX Forecasting")

data = create_data()
unique_cities = data['City'].unique()

if len(unique_cities) == 0:
    # st.tabs needs at least one label, so stop with a message instead.
    st.warning("No city data found to forecast.")
    st.stop()

tabs = st.tabs(list(unique_cities))

for tab, city in zip(tabs, unique_cities):
    with tab:
        order, seasonal_order = _render_order_sliders(city)
        city_data = data[data['City'] == city]['Accident Count']
        forecast_df = _show_forecast(city_data, order, seasonal_order)

        st.button(
            f'Run Grid Search for {city}',
            on_click=_run_grid_search,
            args=(city, city_data),
        )
        grid_result = st.session_state.get(f'{city}_grid_result')
        if grid_result is not None:
            best_params, best_aic = grid_result
            if best_params is None:
                st.warning(f"Grid search found no model that could be fitted for {city}.")
            else:
                st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")

        # Only offer an export of this city's own, successfully computed forecast.
        if forecast_df is not None and st.button(f'Export {city} to Excel'):
            excel_data = to_excel(forecast_df)
            st.download_button(
                label='π₯ Download Excel',
                data=excel_data,
                file_name=f'{city}_forecast.xlsx',
                mime=_XLSX_MIME,
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|