Najm_NSR / app.py
XPMaster's picture
Update app.py
939b379
raw
history blame
No virus
4.73 kB
import streamlit as st
import pandas as pd
from io import BytesIO
from itertools import product
from statsmodels.tsa.statespace.sarimax import SARIMAX
import plotly.express as px
# Render the app with Streamlit's wide page layout.
st.set_page_config(layout="wide")
# Function to run the SARIMAX Model
def run_sarimax(city_data, order, seasonal_order, steps=6):
    """Fit a SARIMAX model to one series and forecast ahead.

    Args:
        city_data: Observed series handed to ``SARIMAX`` as the endogenous data.
        order: Non-seasonal ``(p, d, q)`` order.
        seasonal_order: Seasonal ``(P, D, Q, s)`` order.
        steps: Number of out-of-sample periods to forecast. Defaults to 6,
            the horizon that used to be hard-coded here.

    Returns:
        ``(forecast, aic)`` on success. If constructing, fitting or
        forecasting raises, the error is reported with ``st.error`` and
        ``(None, None)`` is returned instead.
    """
    try:
        model = SARIMAX(
            city_data,
            order=order,
            seasonal_order=seasonal_order,
            enforce_stationarity=False,
            enforce_invertibility=False,
        )
        model_fit = model.fit(disp=False)
        return model_fit.forecast(steps=steps), model_fit.aic
    except Exception as e:
        # Deliberately broad: any failure for one parameter set is surfaced in
        # the UI and signalled to the caller rather than aborting the app.
        st.error(f"An error occurred during model fitting: {e}")
        return None, None
def create_data():
    """Load the accident counts and aggregate them per city and month.

    Reads ``accident_count.csv`` from the working directory with its first
    column as the index, and interprets that index as ``%Y%m`` values. The
    code below looks the index up again as 'Accident Month Bracket' after
    ``reset_index()``, so the first CSV column is presumably headed with that
    name (confirm against the data file). A 'City' column is required for the
    grouping.

    Returns:
        DataFrame indexed by 'YYYY-MM' strings, holding the per-city monthly
        sums produced by the resample.
    """
    raw = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
    raw.index = pd.to_datetime(raw.index, format='%Y%m')

    # Sum within each city per calendar month, then flatten the
    # (City, month) index back into ordinary columns.
    monthly = raw.groupby('City').resample('M').sum().reset_index()

    # Move the month column into the index and render it as 'YYYY-MM' text.
    monthly = monthly.set_index('Accident Month Bracket')
    monthly.index = monthly.index.strftime('%Y-%m')
    return monthly
def to_excel(df):
    """Serialize a DataFrame to an in-memory Excel workbook.

    Args:
        df: DataFrame to write. It is written, index included, to a sheet
            named 'Sheet1' using the ``xlsxwriter`` engine.

    Returns:
        bytes: The contents of the generated workbook.
    """
    output = BytesIO()
    # ``ExcelWriter.save()`` was deprecated in pandas 1.5 and removed in 2.0.
    # Leaving the ``with`` block closes the writer, which flushes the workbook
    # into ``output`` on old and new pandas alike.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
    return output.getvalue()
# Initialize session state for best parameters
if 'best_params' not in st.session_state:
    st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}


def _build_forecast_frame(city_data, forecast):
    """Return the forecast as a rounded one-column frame labelled by month.

    The labels are 'YYYY-MM' strings continuing directly after the last
    observation in ``city_data`` (whose index holds 'YYYY-MM' strings).
    """
    # Work in monthly periods: this avoids the month-end offset alias ('M' vs
    # 'ME') that changed between pandas versions.
    last_month = pd.Period(city_data.index[-1], freq='M')
    months = pd.period_range(start=last_month + 1, periods=len(forecast), freq='M')
    # Take the raw values. Passing a named Series together with a different
    # ``columns=`` label makes pandas reindex to an all-NaN column, and
    # statsmodels typically names the forecast Series itself.
    values = pd.Series(forecast).to_numpy()
    # String labels keep the index serializable for both st.table and plotly.
    frame = pd.DataFrame({'Forecast': values}, index=months.strftime('%Y-%m'))
    return frame.round(0)


def _grid_search(city_data, seasonal_period=12):
    """Try every p, d, q, P, D, Q in 0..2 and return ``(best_params, best_aic)``.

    ``best_params`` is ``(order, seasonal_order)`` for the lowest AIC, or
    ``None`` (with an AIC of ``inf``) when no combination could be fitted.
    """
    best_aic = float('inf')
    best_params = None
    for order in product(range(3), repeat=3):
        # Exactly three seasonal terms: with the period appended this forms the
        # 4-tuple (P, D, Q, s) that SARIMAX expects.
        for seasonal in product(range(3), repeat=3):
            seasonal_order = seasonal + (seasonal_period,)
            _, candidate_aic = run_sarimax(city_data, order, seasonal_order)
            # Compare against None explicitly so an AIC of 0.0 is not skipped;
            # a NaN AIC never compares lower and is therefore ignored.
            if candidate_aic is not None and candidate_aic < best_aic:
                best_aic = candidate_aic
                best_params = (order, seasonal_order)
    return best_params, best_aic


st.title("SARIMAX Forecasting")

# Data preparation
data = create_data()
unique_cities = data['City'].unique()

# Creating tabs for each city (st.tabs rejects an empty label list)
if len(unique_cities) == 0:
    st.warning("No cities found in the data.")
    tabs = []
else:
    tabs = st.tabs([str(city) for city in unique_cities])

for tab, city in zip(tabs, unique_cities):
    with tab:
        # SARIMAX specific sliders, seeded from the last grid-search result
        order_defaults = st.session_state.best_params['order']
        seasonal_defaults = st.session_state.best_params['seasonal_order']
        p = st.slider('AR Order (p)', 0, 5, value=order_defaults[0], key=f'{city}p')
        d = st.slider('Differencing Order (d)', 0, 2, value=order_defaults[1], key=f'{city}d')
        q = st.slider('MA Order (q)', 0, 5, value=order_defaults[2], key=f'{city}q')
        P = st.slider('Seasonal AR Order (P)', 0, 5, value=seasonal_defaults[0], key=f'{city}P')
        D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=seasonal_defaults[1], key=f'{city}D')
        Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=seasonal_defaults[2], key=f'{city}Q')
        S = st.slider('Seasonal Period (S)', 1, 24, value=seasonal_defaults[3], key=f'{city}S')

        city_data = data[data['City'] == city]['Accident Count']
        forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))

        # Reset for every city so a failed fit can never leave the previous
        # tab's forecast behind for the export button.
        forecast_df = None
        if forecast is not None:
            st.write(f"Best Parameters with AIC: {aic}")
            st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")
            forecast_df = _build_forecast_frame(city_data, forecast)
            st.table(forecast_df)
            fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
            st.plotly_chart(fig)

        # Grid search button
        if st.button(f'Run Grid Search for {city}'):
            best_params, best_aic = _grid_search(city_data)
            if best_params is None:
                st.warning(f"Grid search found no valid model for {city}.")
            else:
                # Updating session state with the best parameters
                st.session_state.best_params = {
                    'order': best_params[0],
                    'seasonal_order': best_params[1]
                }
                st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")

        # Export to Excel button
        if st.button(f'Export {city} to Excel'):
            if forecast_df is None:
                st.warning(f"No forecast available to export for {city}.")
            else:
                excel_data = to_excel(forecast_df)
                st.download_button(
                    label='📥 Download Excel',
                    data=excel_data,
                    file_name=f'{city}_forecast.xlsx',
                    # MIME type for .xlsx workbooks (the old value was for .xls)
                    mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
                )
# Rest of your code