Najm_NSR / app.py
XPMaster's picture
Update app.py
6d819a2
raw
history blame
No virus
5.02 kB
import pandas as pd
from io import BytesIO
import streamlit as st
from statsmodels.tsa.holtwinters import ExponentialSmoothing
from statsmodels.tsa.statespace.sarimax import SARIMAX
from itertools import product
import plotly.express as px
def run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period):
    """Fit an Exponential Smoothing model and forecast six steps ahead.

    Args:
        city_data: Observations handed straight to ``ExponentialSmoothing``.
        trend: Trend component passed as ``trend``.
        damped_trend: Flag passed as ``damped_trend``.
        seasonal: Seasonal component passed as ``seasonal``.
        seasonal_period: Value passed as ``seasonal_periods``.

    Returns:
        ``(forecast, aic)`` on success. If anything raises while building,
        fitting or forecasting, the error is shown with ``st.error`` and
        ``(None, None)`` is returned instead.
    """
    try:
        fitted = ExponentialSmoothing(
            city_data,
            trend=trend,
            damped_trend=damped_trend,
            seasonal=seasonal,
            seasonal_periods=seasonal_period,
        ).fit(optimized=True)
        outcome = (fitted.forecast(steps=6), fitted.aic)
    except Exception as e:
        # UI boundary: report the failure on the page rather than crash the app.
        st.error(f"An error occurred during model fitting: {e}")
        outcome = (None, None)
    return outcome
def run_sarimax(city_data, order, seasonal_order):
    """Fit a SARIMAX model and forecast six steps ahead.

    Args:
        city_data: Observations handed straight to ``SARIMAX``.
        order: Value passed as ``order``.
        seasonal_order: Value passed as ``seasonal_order``.

    Returns:
        ``(forecast, aic)`` on success. If anything raises while building,
        fitting or forecasting, the error is shown with ``st.error`` and
        ``(None, None)`` is returned instead.
    """
    try:
        fitted = SARIMAX(
            city_data,
            order=order,
            seasonal_order=seasonal_order,
        ).fit(disp=False)
        outcome = (fitted.forecast(steps=6), fitted.aic)
    except Exception as e:
        # UI boundary: report the failure on the page rather than crash the app.
        st.error(f"An error occurred during SARIMAX model fitting: {e}")
        outcome = (None, None)
    return outcome
def create_data():
    """Load ``accident_count.csv`` and aggregate it to monthly totals per city.

    The first CSV column is used as the index and parsed with the ``%Y%m``
    format. Numeric columns are summed per ``City`` and per calendar month;
    months without rows inside a city's date range are produced by the
    resample with a sum of 0.

    Returns:
        A DataFrame with a ``City`` column plus the summed numeric columns,
        indexed by month label strings formatted as ``%Y-%m``.
    """
    data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
    data.index = pd.to_datetime(data.index, format='%Y%m')

    # Aggregate only the numeric value columns. Resampling the whole grouped
    # frame also aggregates the grouping column on pandas 2.x, after which
    # reset_index() fails because 'City' would exist twice.
    value_columns = [
        column
        for column in data.select_dtypes(include='number').columns
        if column != 'City'
    ]

    # 'MS' (month start) bins by calendar month exactly like month-end does;
    # only the bin label differs, and it is reduced to '%Y-%m' below. It avoids
    # the 'M' alias that newer pandas releases deprecate.
    monthly = (
        data.groupby('City')[value_columns]
        .resample('MS')
        .sum()
        .reset_index()
    )

    monthly.index = monthly['Accident Month Bracket']
    monthly = monthly.drop(['Accident Month Bracket'], axis=1)
    monthly.index = monthly.index.strftime('%Y-%m')
    return monthly
def to_excel(df):
    """Serialize *df* to an in-memory ``.xlsx`` workbook.

    Args:
        df: DataFrame written to a sheet named ``Sheet1``.

    Returns:
        The workbook content as ``bytes``.
    """
    output = BytesIO()
    # The context manager finalizes and closes the workbook on exit.
    # ExcelWriter.save() no longer exists in pandas 2.x.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
    return output.getvalue()
# ---------------------------------------------------------------------------
# Streamlit page script: runs top to bottom on every user interaction.
# ---------------------------------------------------------------------------

# Widget defaults. The grid search below refreshes them, so the stored value
# must always stay a dict with these keys.
_DEFAULT_BEST_PARAMS = {
    'trend': None,
    'damped_trend': False,
    'seasonal': None,
    'seasonal_period': 12,
    'model_type': 'ExpSmoothing',
    'order': (0, 1, 0),
    'seasonal_order': (0, 1, 0, 12),
}

# Also re-initialise when a previous run left a non-dict value behind.
if not isinstance(st.session_state.get('best_params'), dict):
    st.session_state.best_params = dict(_DEFAULT_BEST_PARAMS)
best_state = st.session_state.best_params

st.title("Exponential Smoothing and SARIMAX Forecasting")

data = create_data()
unique_cities = data['City'].unique()
selected_city = st.selectbox('Select a City', unique_cities)
model_type = st.selectbox('Select Model Type', ['ExpSmoothing', 'SARIMAX'])

# Model parameter widgets, pre-set from the stored best parameters.
if model_type == 'ExpSmoothing':
    trend = st.select_slider('Select Trend', options=['add', 'mul', None], value=best_state.get('trend'))
    damped_trend = st.checkbox('Damped Trend', value=best_state.get('damped_trend', False))
    seasonal = st.select_slider('Select Seasonal', options=['add', 'mul', None], value=best_state.get('seasonal'))
    seasonal_period = st.slider('Seasonal Period', 1, 24, value=best_state.get('seasonal_period', 12))
elif model_type == 'SARIMAX':
    p_default, d_default, q_default = best_state.get('order', (0, 1, 0))
    P_default, D_default, Q_default, S_default = best_state.get('seasonal_order', (0, 1, 0, 12))
    p = st.slider('AR Order (p)', 0, 5, p_default)
    d = st.slider('Differencing (d)', 0, 2, d_default)
    q = st.slider('MA Order (q)', 0, 5, q_default)
    P = st.slider('Seasonal AR Order (P)', 0, 2, P_default)
    D = st.slider('Seasonal Differencing (D)', 0, 2, D_default)
    Q = st.slider('Seasonal MA Order (Q)', 0, 2, Q_default)
    S = st.slider('Seasonal Period (S)', 1, 24, S_default)

city_data = data[data['City'] == selected_city]['Accident Count']

# Fit the selected model with the parameters chosen above.
if model_type == 'ExpSmoothing':
    forecast, aic = run_exp_smoothing(city_data, trend, damped_trend, seasonal, seasonal_period)
elif model_type == 'SARIMAX':
    order = (p, d, q)
    seasonal_order = (P, D, Q, S)
    forecast, aic = run_sarimax(city_data, order, seasonal_order)

# Stays None when fitting failed so the export step can detect that.
forecast_df = None
if forecast is not None:
    st.write(f"Best Parameters with AIC: {aic}")
    st.write("Forecast:")
    forecast_values = list(forecast)
    # Label the forecast rows with the months that follow the last observed
    # month (the data index holds '%Y-%m' strings).
    last_period = pd.Period(city_data.index[-1], freq='M')
    forecast_index = pd.period_range(
        start=last_period + 1, periods=len(forecast_values), freq='M'
    ).strftime('%Y-%m')
    # Build from plain values: passing a Series together with `columns=`
    # selects by name and can produce an empty column.
    forecast_df = pd.DataFrame({'Forecast': forecast_values}, index=forecast_index)
    forecast_df = forecast_df.round(0)
    st.table(forecast_df)
    fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
    st.plotly_chart(fig)

# Grid Search Logic for Both Models
if st.button('Run Grid Search'):
    best_aic = float('inf')
    best_params = None
    if model_type == 'ExpSmoothing':
        for param_set in product(['add', 'mul', None], [False], ['add', 'mul', None], [12]):
            _, temp_aic = run_exp_smoothing(city_data, *param_set)
            # `is not None` so that an AIC of exactly 0 is not mistaken for a failed fit.
            if temp_aic is not None and temp_aic < best_aic:
                best_aic = temp_aic
                best_params = param_set
    elif model_type == 'SARIMAX':
        for param_set in product(range(3), range(2), range(3), range(2), range(2), range(2), [12]):
            _, temp_aic = run_sarimax(city_data, param_set[:3], param_set[3:])
            if temp_aic is not None and temp_aic < best_aic:
                best_aic = temp_aic
                best_params = param_set

    if best_params is None:
        # Keep the previous defaults; there is nothing better to store.
        st.warning('Grid search could not fit any parameter combination.')
    else:
        # Store the result in the same dict shape the widgets read, so the
        # next rerun does not fail when indexing a tuple by key.
        if model_type == 'ExpSmoothing':
            best_trend, best_damped, best_seasonal, best_period = best_params
            found = {
                'trend': best_trend,
                'damped_trend': best_damped,
                'seasonal': best_seasonal,
                'seasonal_period': best_period,
            }
        else:
            found = {
                'order': tuple(best_params[:3]),
                'seasonal_order': tuple(best_params[3:]),
            }
        st.session_state.best_params = {**best_state, **found, 'model_type': model_type}
        st.write(f"Best Parameters: {best_params} with AIC: {best_aic}")

if st.button('Export to Excel'):
    if forecast_df is None:
        st.warning('No forecast is available to export.')
    else:
        excel_data = to_excel(forecast_df)
        st.download_button(
            label='📥 Download Excel',
            data=excel_data,
            file_name='forecast.xlsx',
            mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
        )