# Najm_NSR / app.py
# XPMaster's picture
# Update app.py
# 89e8e44
# raw
# history blame
# No virus
# 5.25 kB
import streamlit as st
import pandas as pd
from io import BytesIO
from itertools import product
from statsmodels.tsa.statespace.sarimax import SARIMAX
import plotly.express as px
# Page-level configuration: request the "wide" layout.
# NOTE(review): Streamlit has historically required set_page_config to be the
# first Streamlit command executed in a script — keep it ahead of every other
# st.* call (confirm against the deployed Streamlit version).
st.set_page_config(layout="wide")
# Function to run the SARIMAX Model
def run_sarimax(city_data, order, seasonal_order, steps=6):
    """Fit a SARIMAX model to one series and forecast ahead.

    Args:
        city_data: Observed values handed to ``SARIMAX`` as the endogenous
            series.
        order: Non-seasonal ``(p, d, q)`` order.
        seasonal_order: Seasonal ``(P, D, Q, s)`` order.
        steps: Number of periods to forecast. Defaults to 6, the horizon
            this function previously hard-coded.

    Returns:
        ``(forecast, aic)`` on success, where ``forecast`` is whatever
        ``model_fit.forecast`` returns and ``aic`` is the fitted model's AIC.
        ``(None, None)`` if building, fitting or forecasting raises; the
        error is reported in the UI through ``st.error``.
    """
    # The broad ``except`` is deliberate: this is the UI boundary, and a bad
    # parameter combination must show a message rather than abort the page.
    try:
        model = SARIMAX(
            city_data,
            order=order,
            seasonal_order=seasonal_order,
            enforce_stationarity=False,
            enforce_invertibility=False,
        )
        model_fit = model.fit(disp=False)
        return model_fit.forecast(steps=steps), model_fit.aic
    except Exception as e:
        st.error(f"An error occurred during model fitting: {e}")
        return None, None
def create_data(csv_path='accident_count.csv'):
    """Load accident counts and aggregate them per city and calendar month.

    The CSV must provide:
      * a first column named 'Accident Month Bracket' holding YYYYMM values
        (it is used as the index),
      * a 'City' column,
      * an 'Accident Count' column.

    Args:
        csv_path: Location of the CSV file. Defaults to 'accident_count.csv'
            in the working directory, the path this function always used.

    Returns:
        A DataFrame with columns 'City' and 'Accident Count', one row per
        city and month, indexed by 'YYYY-MM' strings.
    """
    data = pd.read_csv(csv_path, parse_dates=True, index_col=0)
    data.index = pd.to_datetime(data.index, format='%Y%m')

    # Select the count column before resampling: summing the whole frame can
    # carry the 'City' grouping column through the aggregation (pandas >= 2.0
    # sums non-numeric columns by default), and reset_index() then fails
    # because 'City' would exist both as a column and as an index level.
    # An offset object is used instead of the 'M' alias, which newer pandas
    # deprecates for month-end frequency.
    monthly = (
        data.groupby('City')['Accident Count']
        .resample(pd.offsets.MonthEnd())
        .sum()
        .reset_index()
    )

    # Re-index by month and render the month labels as 'YYYY-MM' text.
    monthly = monthly.set_index('Accident Month Bracket')
    monthly.index = monthly.index.strftime('%Y-%m')
    return monthly
def to_excel(df):
    """Serialize *df* to an in-memory .xlsx workbook.

    Args:
        df: DataFrame to write; it goes to a single sheet named 'Sheet1'.

    Returns:
        The workbook content as ``bytes``.
    """
    output = BytesIO()
    # Leaving the ``with`` block closes the writer, which flushes the workbook
    # into ``output`` (also on error). ``ExcelWriter.save()`` is not used
    # because it was removed in pandas 2.0.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
    return output.getvalue()
# Initialize session state for best parameters.
# Streamlit re-executes this script on every interaction, so the defaults are
# only written when the key is missing; later values survive reruns.
if 'best_params' not in st.session_state:
    st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}

st.title("SARIMAX Forecasting")

# Data preparation: one tab is built per distinct value of the 'City' column.
data = create_data()
unique_cities = data['City'].unique()
# Creating tabs for each city.
# Tab labels must be strings, so cast in case the 'City' column is not textual.
tabs = st.tabs([str(city) for city in unique_cities])
for tab, city in zip(tabs, unique_cities):
    with tab:
        # SARIMAX specific sliders. Widget keys are prefixed with the city so
        # every tab keeps its own slider state; the initial positions come
        # from the parameters stored in session state.
        defaults = st.session_state.best_params
        p = st.slider('AR Order (p)', 0, 5, value=defaults['order'][0], key=f'{city}p')
        d = st.slider('Differencing Order (d)', 0, 2, value=defaults['order'][1], key=f'{city}d')
        q = st.slider('MA Order (q)', 0, 5, value=defaults['order'][2], key=f'{city}q')
        P = st.slider('Seasonal AR Order (P)', 0, 5, value=defaults['seasonal_order'][0], key=f'{city}P')
        D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=defaults['seasonal_order'][1], key=f'{city}D')
        Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=defaults['seasonal_order'][2], key=f'{city}Q')
        S = st.slider('Seasonal Period (S)', 1, 24, value=defaults['seasonal_order'][3], key=f'{city}S')

        city_data = data[data['City'] == city]['Accident Count']
        forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))

        # Reset per city so the export below can never pick up the forecast
        # of a previous tab (or an undefined name) when this fit failed.
        forecast_df = None
        if forecast is not None:
            st.write(f"Best Parameters with AIC: {aic}")
            st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")

            # Label the forecast rows with the months that follow the last
            # observed month. The data index holds 'YYYY-MM' strings, and the
            # labels are kept as strings so both st.table and plotly can
            # render them.
            last_month = pd.Period(city_data.index[-1], freq='M')
            forecast_index = pd.period_range(
                start=last_month + 1, periods=len(forecast), freq='M'
            ).strftime('%Y-%m')

            # Build the frame from plain values: passing a named Series
            # together with columns=['Forecast'] makes pandas look that
            # column name up in the Series and yields an empty frame when the
            # names differ.
            forecast_df = pd.DataFrame({'Forecast': list(forecast)}, index=forecast_index)
            forecast_df = forecast_df.round(0)
            st.table(forecast_df)
            fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
            st.plotly_chart(fig)

        # Grid search button
        if st.button(f'Run Grid Search for {city}'):
            best_aic = float('inf')
            best_params = None
            # Define the range for each parameter
            p_range = d_range = q_range = range(3)
            P_range = D_range = Q_range = range(3)
            # Fixed seasonal period for the search, adjust as needed. Named
            # separately so it does not overwrite the slider value S.
            grid_seasonal_period = 12
            # Perform the grid search
            for params in product(p_range, d_range, q_range, P_range, D_range, Q_range):
                order = params[:3]
                seasonal_order = params[3:] + (grid_seasonal_period,)
                _, temp_aic = run_sarimax(city_data, order, seasonal_order)
                # run_sarimax reports its own failures and returns None for
                # the AIC; skip those instead of comparing None to a float.
                if temp_aic is None:
                    continue
                if temp_aic < best_aic:
                    best_aic = temp_aic
                    best_params = (order, seasonal_order)
            # Update the session state with the best parameters
            if best_params is not None:
                st.session_state.best_params = {
                    'order': best_params[0],
                    'seasonal_order': best_params[1]
                }
                st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")

        # Export to Excel button — only offered when this city has a forecast.
        if forecast_df is not None and st.button(f'Export {city} to Excel'):
            excel_data = to_excel(forecast_df)
            st.download_button(
                label='📥 Download Excel',
                data=excel_data,
                file_name=f'{city}_forecast.xlsx',
                # Official MIME type for .xlsx workbooks.
                mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
            )
# Rest of your code