File size: 5,015 Bytes
fde69ff 6d819a2 f33264f 6d819a2 f27b577 fde69ff 8c7c72e fde69ff 6d819a2 fde69ff f33264f fde69ff f33264f 6d819a2 fde69ff 6d819a2 fde69ff 6d819a2 fde69ff 6d819a2 fde69ff 6d819a2 fde69ff 6d819a2 f33264f 4b89b9b 339bffb f27b577 1d8445c 6d819a2 f33264f 6d819a2 f33264f 0338839 6d819a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import pandas as pd
from io import BytesIO
import streamlit as st
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.statespace.sarimax import SARIMAX
from itertools import product
import plotly.express as px
# Function to run the Exponential Smoothing Model
def run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period):
    """Fit a Holt-Winters model to ``city_data`` and forecast six steps ahead.

    Args:
        city_data: Observations passed straight to ``ExponentialSmoothing``.
        trend: Trend component handed to the model (``'add'``, ``'mul'`` or
            ``None`` are the values the page offers).
        damped_trend: Whether the trend component is damped.
        seasonal: Seasonal component handed to the model.
        seasonal_period: Forwarded as ``seasonal_periods``.

    Returns:
        ``(forecast, aic)`` on success. If building, fitting or forecasting
        raises, the error is reported with ``st.error`` and ``(None, None)``
        is returned instead.
    """
    try:
        fitted = ExponentialSmoothing(
            city_data,
            trend=trend,
            damped_trend=damped_trend,
            seasonal=seasonal,
            seasonal_periods=seasonal_period,
        ).fit(optimized=True)
        outcome = (fitted.forecast(steps=6), fitted.aic)
    except Exception as e:
        # UI boundary: surface the failure on the page instead of crashing.
        st.error(f"An error occurred during model fitting: {e}")
        outcome = (None, None)
    return outcome
# Function to run SARIMAX Model
def run_sarimax(city_data, order, seasonal_order):
    """Fit a SARIMAX model to ``city_data`` and forecast six steps ahead.

    Args:
        city_data: Observations passed straight to ``SARIMAX``.
        order: Forwarded as the model's ``order`` argument.
        seasonal_order: Forwarded as the model's ``seasonal_order`` argument.

    Returns:
        ``(forecast, aic)`` on success. If building, fitting or forecasting
        raises, the error is reported with ``st.error`` and ``(None, None)``
        is returned instead.
    """
    try:
        fitted = SARIMAX(
            city_data,
            order=order,
            seasonal_order=seasonal_order,
        ).fit(disp=False)
        outcome = (fitted.forecast(steps=6), fitted.aic)
    except Exception as e:
        # UI boundary: surface the failure on the page instead of crashing.
        st.error(f"An error occurred during SARIMAX model fitting: {e}")
        outcome = (None, None)
    return outcome
def create_data():
    """Load ``accident_count.csv`` and aggregate it to monthly sums per city.

    The first CSV column is used as the index and parsed with the ``%Y%m``
    format. Rows are grouped by ``'City'``, resampled to month-end bins and
    summed; months with no rows inside a city's range therefore sum to 0.

    Returns:
        A DataFrame indexed by ``'YYYY-MM'`` strings, holding the ``'City'``
        column plus the summed numeric columns. The date column
        ``'Accident Month Bracket'`` is moved into the index.
    """
    month_col = 'Accident Month Bracket'

    data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
    data.index = pd.to_datetime(data.index, format='%Y%m')

    monthly = (
        data.groupby('City')
        # An explicit offset object avoids the 'M' alias, which pandas 2.2
        # deprecated in favour of 'ME'; the object works on old and new
        # versions alike.
        .resample(pd.offsets.MonthEnd())
        # Sum numeric columns only: depending on the pandas version the string
        # grouping column can otherwise be carried into the result and then
        # clash with the 'City' index level in reset_index().
        .sum(numeric_only=True)
        .reset_index()
    )

    monthly.index = monthly[month_col]
    monthly = monthly.drop(columns=[month_col])
    monthly.index = monthly.index.strftime('%Y-%m')
    return monthly
def to_excel(df):
    """Serialise ``df`` to an in-memory ``.xlsx`` workbook.

    Args:
        df: DataFrame written to a single sheet named ``'Sheet1'``.

    Returns:
        The workbook contents as ``bytes``.
    """
    output = BytesIO()
    # The context manager finalises the workbook on exit. The previous
    # ``writer.save()`` call no longer exists in pandas 2.0+, and the ``with``
    # form also closes the writer if ``to_excel`` raises.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
    return output.getvalue()
# --- Streamlit page body: executed top to bottom on every rerun ------------

# Order of the values in one Exponential Smoothing grid-search candidate; used
# to turn the winning tuple back into the dict the widgets read from.
exp_param_names = ('trend', 'damped_trend', 'seasonal', 'seasonal_period')

# ``best_params`` must always stay a dict with these keys: the Exponential
# Smoothing widgets below index it by name on every rerun.
if 'best_params' not in st.session_state:
    st.session_state.best_params = {'trend': None, 'damped_trend': False, 'seasonal': None, 'seasonal_period': 12, 'model_type': 'ExpSmoothing'}
# SARIMAX slider defaults as (p, d, q, P, D, Q, S), kept apart from
# ``best_params`` so a SARIMAX grid search cannot clobber the dict above.
if 'best_sarimax_params' not in st.session_state:
    st.session_state.best_sarimax_params = (0, 1, 0, 0, 1, 0, 12)

st.title("Exponential Smoothing and SARIMAX Forecasting")

data = create_data()
unique_cities = data['City'].unique()
selected_city = st.selectbox('Select a City', unique_cities)
model_type = st.selectbox('Select Model Type', ['ExpSmoothing', 'SARIMAX'])

if model_type == 'ExpSmoothing':
    trend = st.select_slider('Select Trend', options=['add', 'mul', None], value=st.session_state.best_params['trend'])
    damped_trend = st.checkbox('Damped Trend', value=st.session_state.best_params['damped_trend'])
    seasonal = st.select_slider('Select Seasonal', options=['add', 'mul', None], value=st.session_state.best_params['seasonal'])
    seasonal_period = st.slider('Seasonal Period', 1, 24, value=st.session_state.best_params['seasonal_period'])
elif model_type == 'SARIMAX':
    sarimax_defaults = st.session_state.best_sarimax_params
    p = st.slider('AR Order (p)', 0, 5, sarimax_defaults[0])
    d = st.slider('Differencing (d)', 0, 2, sarimax_defaults[1])
    q = st.slider('MA Order (q)', 0, 5, sarimax_defaults[2])
    P = st.slider('Seasonal AR Order (P)', 0, 2, sarimax_defaults[3])
    D = st.slider('Seasonal Differencing (D)', 0, 2, sarimax_defaults[4])
    Q = st.slider('Seasonal MA Order (Q)', 0, 2, sarimax_defaults[5])
    S = st.slider('Seasonal Period (S)', 1, 24, sarimax_defaults[6])

city_data = data[data['City'] == selected_city]['Accident Count']

if model_type == 'ExpSmoothing':
    forecast, aic = run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
elif model_type == 'SARIMAX':
    order = (p, d, q)
    seasonal_order = (P, D, Q, S)
    forecast, aic = run_sarimax(city_data, order, seasonal_order)

# Stays None when fitting failed, so the export section can tell.
forecast_df = None
if forecast is not None:
    st.write(f"Best Parameters with AIC: {aic}")
    st.write("Forecast:")
    # Label the six forecast steps with the months that follow the last
    # observed month ('YYYY-MM', same format as the data index).
    last_month = pd.Period(city_data.index[-1], freq='M')
    forecast_index = pd.period_range(start=last_month + 1, periods=6, freq='M').strftime('%Y-%m')
    # Build the frame from the raw values: passing a *named* Series together
    # with ``columns=['Forecast']`` selects a column that does not exist and
    # yields NaN only.
    forecast_df = pd.DataFrame({'Forecast': list(forecast)}, index=forecast_index)
    forecast_df = forecast_df.round(0)
    st.table(forecast_df)
    fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
    st.plotly_chart(fig)

# Grid Search Logic for Both Models
if st.button('Run Grid Search'):
    best_aic = float('inf')
    best_params = None
    if model_type == 'ExpSmoothing':
        for param_set in product(['add', 'mul', None], [False], ['add', 'mul', None], [12]):
            _, temp_aic = run_exp_smoothing(city_data, *param_set)
            # ``is not None`` rather than truthiness: an AIC of 0.0 is valid.
            if temp_aic is not None and temp_aic < best_aic:
                best_aic = temp_aic
                best_params = param_set
    elif model_type == 'SARIMAX':
        for param_set in product(range(3), range(2), range(3), range(2), range(2), range(2), [12]):
            _, temp_aic = run_sarimax(city_data, param_set[:3], param_set[3:])
            if temp_aic is not None and temp_aic < best_aic:
                best_aic = temp_aic
                best_params = param_set
    # Only remember a result that exists, and keep each model's state in the
    # shape its widgets expect.
    if best_params is not None:
        if model_type == 'ExpSmoothing':
            st.session_state.best_params = dict(zip(exp_param_names, best_params), model_type='ExpSmoothing')
        elif model_type == 'SARIMAX':
            st.session_state.best_sarimax_params = best_params
    st.write(f"Best Parameters: {best_params} with AIC: {best_aic}")

# The export button is only offered when there is a forecast to export.
if forecast_df is not None and st.button('Export to Excel'):
    df_to_export = forecast_df
    excel_data = to_excel(df_to_export)
    st.download_button(label='📥 Download Excel', data=excel_data, file_name='forecast.xlsx', mime='application/vnd.ms-excel')
|