# Najm_NSR / app.py
# Author: XPMaster; last commit 8e0f382 ("Update app.py"); 7.64 kB
import streamlit as st
import pandas as pd
from io import BytesIO
from itertools import product
from statsmodels.tsa.statespace.sarimax import SARIMAX
import plotly.express as px
# Configure the Streamlit page to use the wide layout.
st.set_page_config(layout="wide")
# Function to run the SARIMAX Model
def run_sarimax(city_data, order, seasonal_order, steps=6):
    """Fit a SARIMAX model to *city_data* and forecast *steps* periods ahead.

    Args:
        city_data: pandas Series of observations to model (one city's
            accident counts in this app).
        order: Non-seasonal ``(p, d, q)`` order.
        seasonal_order: Seasonal ``(P, D, Q, S)`` order.
        steps: Number of periods to forecast. Defaults to 6, the horizon
            this function previously hard-coded.

    Returns:
        ``(forecast, aic)`` on success. Returns ``(None, None)`` when the data
        is empty, the forecast comes back empty, or fitting raises. Failures
        are reported in the UI via ``st.error`` instead of being raised, so
        callers must check for ``None``.
    """
    try:
        # Check if the data is non-empty and in the correct format
        if city_data.empty:
            st.error("No data available for modeling.")
            return None, None
        # Stationarity/invertibility are not enforced so that a wide range of
        # user-chosen orders can still be fitted.
        model = SARIMAX(city_data, order=order, seasonal_order=seasonal_order, enforce_stationarity=False, enforce_invertibility=False)
        model_fit = model.fit(disp=False)
        forecast = model_fit.forecast(steps=steps)
        # Check if the forecast is valid
        if forecast is None or forecast.empty:
            st.error("Forecast failed, the model returned an empty forecast.")
            return None, None
        return forecast, model_fit.aic
    except Exception as e:
        # UI boundary: surface any fitting error to the user instead of crashing the app.
        st.error(f"An error occurred during model fitting: {e}")
        return None, None
def create_data():
    """Load 'accident_count.csv' and aggregate it per city per month.

    Returns:
        A DataFrame with a 'City' column and the summed data columns (the app
        later reads 'Accident Count' from it). It is indexed by 'YYYY-MM'
        month strings.
    """
    # Assuming you have a CSV file named 'accident_count.csv' in the working
    # directory, whose first column holds the month (used as the index) and
    # which also has 'City' and 'Accident Count' columns.
    data = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
    # Index values are parsed as year+month with no separator (e.g. 202301).
    data.index = pd.to_datetime(data.index, format='%Y%m')
    # Sum into monthly bins per city. reset_index turns 'City' and the date
    # level back into regular columns.
    # NOTE(review): newer pandas deprecates the 'M' alias in favour of 'ME';
    # confirm against the pinned pandas version.
    data = data.groupby('City').resample('M').sum().reset_index()
    # The date level comes back as a column named after the CSV's first header.
    # This code expects that header to be 'Accident Month Bracket'.
    data.index = data['Accident Month Bracket']
    data = data.drop(['Accident Month Bracket'], axis=1)
    # Keep the month as a 'YYYY-MM' string label.
    data.index = data.index.strftime('%Y-%m')
    return data
def to_excel(df):
    """Serialize *df* to an in-memory .xlsx workbook and return its bytes.

    The frame, including its index, is written to a sheet named 'Sheet1'
    using the xlsxwriter engine.
    """
    output = BytesIO()
    # Closing the writer (done by the context manager) finalizes the workbook
    # into ``output``. ``ExcelWriter.save()`` was deprecated in pandas 1.5 and
    # removed in 2.0, so the explicit save() call no longer works. The ``with``
    # block also releases the writer if ``to_excel`` raises.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
    return output.getvalue()
# Initialize session state for best parameters. The same defaults are shared
# by every city tab and are overwritten by a successful grid search.
if 'best_params' not in st.session_state:
    st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}

st.title("SARIMAX Forecasting")

# Data preparation
data = create_data()
unique_cities = data['City'].unique()

# Creating tabs for each city
tabs = st.tabs(list(unique_cities))
for tab, city in zip(tabs, unique_cities):
    with tab:
        defaults = st.session_state.best_params
        # SARIMAX specific sliders. Keys are prefixed with the city name so each
        # tab gets its own independent widgets.
        p = st.slider('AR Order (p)', 0, 5, value=defaults['order'][0], key=city+'p')
        d = st.slider('Differencing Order (d)', 0, 2, value=defaults['order'][1], key=city+'d')
        q = st.slider('MA Order (q)', 0, 5, value=defaults['order'][2], key=city+'q')
        P = st.slider('Seasonal AR Order (P)', 0, 5, value=defaults['seasonal_order'][0], key=city+'P')
        D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=defaults['seasonal_order'][1], key=city+'D')
        Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=defaults['seasonal_order'][2], key=city+'Q')
        S = st.slider('Seasonal Period (S)', 1, 24, value=defaults['seasonal_order'][3], key=city+'S')

        city_data = data[data['City'] == city]['Accident Count']
        forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))

        # Reset on every tab so another city's table can never be exported here.
        forecast_df = None
        if forecast is not None:
            st.write(f"Best Parameters with AIC: {aic}")
            st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")
            # Label the forecast with the months that follow the last observed
            # month. city_data's index holds 'YYYY-MM' strings (set in
            # create_data), so the month-end range starts in that month and its
            # first element is dropped.
            forecast_index = pd.date_range(start=city_data.index[-1], periods=len(forecast) + 1, freq='M')[1:]
            forecast_df = pd.DataFrame(
                {'predicted_mean': forecast.to_numpy()},
                index=forecast_index.strftime('%Y-%m'),
            )
            forecast_df = forecast_df.round(0)
            st.table(forecast_df)
            fig = px.line(forecast_df, x=forecast_df.index, y="predicted_mean")
            st.plotly_chart(fig)

        # Grid search button
        if st.button(f'Run Grid Search for {city}'):
            best_aic = float('inf')
            best_params = None
            # Define the range for each parameter
            p_range = d_range = q_range = range(3)
            P_range = D_range = Q_range = range(3)
            seasonal_period = 12  # Assuming a fixed seasonal period, adjust as needed
            # Perform the grid search (3**6 = 729 model fits, so this can be slow).
            with st.spinner(f'Running grid search for {city}...'):
                for params in product(p_range, d_range, q_range, P_range, D_range, Q_range):
                    order = params[:3]
                    seasonal_order = params[3:] + (seasonal_period,)
                    # run_sarimax reports its own failures and returns (None, None).
                    # Skip those candidates rather than comparing None with a float.
                    _, temp_aic = run_sarimax(city_data, order, seasonal_order)
                    if temp_aic is not None and temp_aic < best_aic:
                        best_aic = temp_aic
                        best_params = (order, seasonal_order)
            # Update the session state with the best parameters
            if best_params is not None:
                # NOTE(review): the sliders above were already rendered in this run,
                # so these defaults only matter on a later rerun. Streamlit may also
                # keep a keyed slider's existing value; confirm the intended UX.
                st.session_state.best_params = {
                    'order': best_params[0],
                    'seasonal_order': best_params[1]
                }
                st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")
            else:
                st.warning(f"Grid search could not fit any model for {city}.")

        # Export to Excel button. It is only offered when this tab has a forecast.
        if forecast_df is not None and st.button(f'Export {city} to Excel'):
            excel_data = to_excel(forecast_df)
            st.download_button(
                label='📥 Download Excel',
                data=excel_data,
                file_name=f'{city}_forecast.xlsx',
                mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
            )
# import streamlit as st
# import numpy as np
# import matplotlib.pyplot as plt
# # Sample data for multiple plots
# data = [(np.linspace(0, 10, 10), np.sin(np.linspace(0, 10, 10))) for _ in range(3)]
# # Initialize session state
# if 'current_plot_index' not in st.session_state:
# st.session_state['current_plot_index'] = 0
# def update_plot(index, x, y):
# plt.figure()
# plt.plot(x, y, marker='o')
# plt.title(f"Plot {index+1}")
# st.pyplot(plt)
# def next_plot():
# st.session_state['current_plot_index'] = (st.session_state['current_plot_index'] + 1) % len(data)
# # Display the plot
# index = st.session_state['current_plot_index']
# x, y = data[index]
# update_plot(index, x, y)
# # Select and update point
# point_index = st.number_input('Point Index', min_value=0, max_value=len(x)-1, step=1)
# new_value = st.number_input('New Y Value', value=y[point_index])
# if st.button('Update Point'):
# y[point_index] = new_value
# update_plot(index, x, y)
# # Next plot button
# if st.button('Next Plot'):
# next_plot()
# import plotly.express as px
# import streamlit as st
# from streamlit_plotly_events import plotly_events
# # Sample data
# df = px.data.gapminder().query("country=='Canada'")
# fig = px.line(df, x="year", y="lifeExp", title='Life Expectancy in Canada Over Years')
# # Capture the selected points
# selected_points = plotly_events(fig, click_event=True)
# # Handle the click event
# if selected_points:
# st.write("You clicked on:", selected_points)
# point_index = selected_points[0]['pointIndex']
# new_value = st.number_input('Enter new value for life expectancy', value=df.iloc[point_index]['lifeExp'])
# if st.button('Update Data'):
# df.at[point_index, 'lifeExp'] = new_value
# fig = px.line(df, x="year", y="lifeExp", title='Life Expectancy in Canada Over Years')
# st.plotly_chart(fig)
# else:
# st.plotly_chart(fig)