File size: 5,247 Bytes
d6165ca
fde69ff
 
6d819a2
939b379
f27b577
939b379
24d2e79
939b379
 
 
fde69ff
939b379
 
fde69ff
 
8c7c72e
fde69ff
 
 
939b379
f33264f
 
 
 
939b379
f33264f
 
 
fde69ff
 
 
 
 
 
 
 
d6165ca
f33264f
939b379
fde69ff
939b379
fde69ff
d6165ca
fde69ff
 
 
d6165ca
0b7e8df
 
5f049bf
d6165ca
939b379
 
 
 
 
 
 
 
d6165ca
 
939b379
d6165ca
 
 
939b379
d6165ca
 
 
 
 
 
 
 
89e8e44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6165ca
 
 
 
 
 
939b379
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import streamlit as st
import pandas as pd
from io import BytesIO
from itertools import product
from statsmodels.tsa.statespace.sarimax import SARIMAX
import plotly.express as px

# Configure the Streamlit page to use the "wide" layout.
st.set_page_config(layout="wide")

# Function to run the SARIMAX Model
def run_sarimax(city_data, order, seasonal_order):
    """Fit a SARIMAX model on *city_data* and forecast six steps ahead.

    Args:
        city_data: Observations passed straight to ``SARIMAX`` as the
            endogenous series.
        order: Non-seasonal ``(p, d, q)`` order.
        seasonal_order: Seasonal ``(P, D, Q, S)`` order.

    Returns:
        A ``(forecast, aic)`` tuple on success. If building, fitting or
        forecasting raises, the error is shown with ``st.error`` and
        ``(None, None)`` is returned instead.
    """
    try:
        # Stationarity/invertibility enforcement is switched off, exactly as
        # the model is configured everywhere else in this script.
        fitted = SARIMAX(
            city_data,
            order=order,
            seasonal_order=seasonal_order,
            enforce_stationarity=False,
            enforce_invertibility=False,
        ).fit(disp=False)
        outcome = (fitted.forecast(steps=6), fitted.aic)
    except Exception as exc:
        st.error(f"An error occurred during model fitting: {exc}")
        outcome = (None, None)
    return outcome

def create_data():
    """Load ``accident_count.csv`` and aggregate it to monthly sums per city.

    The first CSV column is used as the index and parsed as dates, rows are
    grouped by ``'City'`` and resampled to month frequency with ``sum``, and
    the result is re-indexed by the ``'Accident Month Bracket'`` column
    (presumably the header of the first CSV column -- verify against the
    file) formatted as ``'YYYY-MM'`` strings.

    Returns:
        The aggregated DataFrame, indexed by ``'YYYY-MM'`` labels.
    """
    # The CSV is expected to provide 'City' and 'Accident Count' columns.
    raw = pd.read_csv('accident_count.csv', parse_dates=True, index_col=0)
    raw.index = pd.to_datetime(raw.index, format='%Y%m')

    per_city = raw.groupby('City').resample('M').sum()
    monthly = per_city.reset_index().set_index('Accident Month Bracket')

    # Replace the datetime index by plain year-month text labels.
    monthly.index = monthly.index.strftime('%Y-%m')
    return monthly

def to_excel(df):
    """Serialize *df* to an in-memory Excel workbook.

    Args:
        df: DataFrame to write; it goes to a single sheet named ``'Sheet1'``.

    Returns:
        The bytes of the generated ``.xlsx`` file.
    """
    output = BytesIO()
    # Use the writer as a context manager so the workbook is finalized on
    # exit. ``ExcelWriter.save()`` was removed in pandas 2.0, so calling it
    # raises AttributeError on current pandas.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, sheet_name='Sheet1')
    return output.getvalue()

# Initialize session state for best parameters. These values seed the slider
# defaults below and are overwritten when a grid search finds a better model.
if 'best_params' not in st.session_state:
    st.session_state.best_params = {'order': (1, 1, 1), 'seasonal_order': (1, 1, 1, 12)}

st.title("SARIMAX Forecasting")

# Data preparation
data = create_data()
unique_cities = data['City'].unique()

# Creating tabs for each city
tabs = st.tabs(list(unique_cities))

for tab, city in zip(tabs, unique_cities):
    with tab:
        # SARIMAX specific sliders, seeded from the stored best parameters.
        # NOTE(review): the sliders carry explicit keys, so Streamlit
        # presumably keeps their stored values across reruns and a new default
        # from a grid search may not show up in the widgets -- confirm.
        default_p, default_d, default_q = st.session_state.best_params['order']
        default_P, default_D, default_Q, default_S = st.session_state.best_params['seasonal_order']
        p = st.slider('AR Order (p)', 0, 5, value=default_p, key=city+'p')
        d = st.slider('Differencing Order (d)', 0, 2, value=default_d, key=city+'d')
        q = st.slider('MA Order (q)', 0, 5, value=default_q, key=city+'q')
        P = st.slider('Seasonal AR Order (P)', 0, 5, value=default_P, key=city+'P')
        D = st.slider('Seasonal Differencing Order (D)', 0, 2, value=default_D, key=city+'D')
        Q = st.slider('Seasonal MA Order (Q)', 0, 5, value=default_Q, key=city+'Q')
        S = st.slider('Seasonal Period (S)', 1, 24, value=default_S, key=city+'S')

        city_data = data[data['City'] == city]['Accident Count']
        forecast, aic = run_sarimax(city_data, (p, d, q), (P, D, Q, S))

        # Reset per city so the export below can never pick up an undefined
        # frame or the forecast left over from a previous tab.
        forecast_df = None

        if forecast is not None:
            st.write(f"Best Parameters with AIC: {aic}")
            st.write(f"Non-Seasonal Order: {(p, d, q)}, Seasonal Order: {(P, D, Q, S)}")

            # Label the forecast with the months following the last observed
            # one, in the same 'YYYY-MM' text form as the data index.
            next_month = pd.Period(city_data.index[-1], freq='M') + 1
            forecast_index = pd.period_range(
                start=next_month, periods=len(forecast), freq='M'
            ).strftime('%Y-%m')

            # Build the frame from raw values: passing a *named* Series
            # together with ``columns=['Forecast']`` selects a non-existent
            # column and yields only NaN.
            forecast_df = pd.DataFrame(
                {'Forecast': list(forecast)}, index=forecast_index
            ).round(0)
            st.table(forecast_df)
            fig = px.line(forecast_df, x=forecast_df.index, y="Forecast")
            st.plotly_chart(fig)

            # Grid search button
            if st.button(f'Run Grid Search for {city}'):
                best_aic = float('inf')
                best_params = None
                # Every order component is searched over 0..2.
                search_range = range(3)
                # Fixed seasonal period for the search; kept in its own name so
                # the slider value ``S`` is not overwritten. Adjust as needed.
                grid_period = 12

                # Perform the grid search
                for params in product(search_range, repeat=6):
                    order = params[:3]
                    seasonal_order = params[3:] + (grid_period,)
                    _, temp_aic = run_sarimax(city_data, order, seasonal_order)
                    # A failed fit returns None (run_sarimax has already
                    # reported it); comparing None with a float would raise.
                    if temp_aic is None:
                        continue
                    if temp_aic < best_aic:
                        best_aic = temp_aic
                        best_params = (order, seasonal_order)

                # Update the session state with the best parameters
                if best_params is not None:
                    st.session_state.best_params = {
                        'order': best_params[0],
                        'seasonal_order': best_params[1]
                    }
                    st.write(f"Best Parameters for {city}: {best_params} with AIC: {best_aic}")

        # Export to Excel button
        if st.button(f'Export {city} to Excel'):
            if forecast_df is None:
                st.warning(f"No forecast available for {city}; nothing to export.")
            else:
                excel_data = to_excel(forecast_df)
                st.download_button(
                    label='📥 Download Excel',
                    data=excel_data,
                    file_name=f'{city}_forecast.xlsx',
                    # MIME type for .xlsx workbooks (the previous value was
                    # the legacy .xls type).
                    mime='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
                )

# Rest of your code