# forecasting / app.py
# Author: filipeclduarte
# Last commit: 29b3cdb ("Update app.py")
import streamlit as st
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
from pmdarima import auto_arima
from utils import calc_seasonality
@st.cache
def convert_df(df):
    """Serialize ``df`` to CSV and return it as UTF-8 encoded bytes.

    The bytes are handed to ``st.download_button``; the ``st.cache`` decorator
    lets Streamlit reuse the result across reruns for the same input.
    """
    csv_text = df.to_csv()
    return csv_text.encode('utf-8')
NULL_FREQ = 'NULL'  # selectbox choice for data without a calendar frequency


def _parse_horizons(raw):
    """Return ``raw`` parsed as a positive int, or None if it is not one."""
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return None
    return value if value > 0 else None


def _time_axes(date, n_obs, horizons, freq):
    """Build x-axis values for the observed history and for the forecast horizon.

    Returns ``(train_ds, forecast_ds)``. For ``NULL_FREQ`` both are integer
    positions; otherwise both are timestamps at ``freq``, the history starting
    at ``date`` and the forecast starting one period after the last observation.

    Raises:
        ValueError: if pandas cannot parse ``date`` or ``freq``.
    """
    if freq == NULL_FREQ:
        # 'NULL' is not a pandas frequency alias, so index by position instead.
        return np.arange(n_obs), np.arange(n_obs, n_obs + horizons)
    train_ds = pd.date_range(start=date, periods=n_obs, freq=freq)
    # date_range includes its start (the last observed date); generate one extra
    # period and drop it so the forecast does not overlap the history.
    forecast_ds = pd.date_range(start=train_ds[-1], periods=horizons + 1, freq=freq)[1:]
    return train_ds, forecast_ds


def _fit_auto_arima(y, freq, seasonality):
    """Select and fit a (S)ARIMA model for ``y`` with pmdarima's ``auto_arima``.

    ``auto_arima`` returns an already-fitted model, so no extra ``fit`` call is
    made. Seasonal terms are controlled by ``seasonal`` and the season length ``m``.
    """
    if freq == NULL_FREQ:
        return auto_arima(y, seasonal=False, n_jobs=1, max_p=12, max_q=12)
    # NOTE(review): assumes calc_seasonality returns an integer season length
    # (e.g. 12 for monthly data) — confirm in utils. Large values make the
    # seasonal search slow.
    m = int(seasonality)
    return auto_arima(y, seasonal=m > 1, m=max(m, 1), n_jobs=1, max_p=12, max_q=12)


def _interval_band(name, ds, lower, upper, rgba, opacity):
    """Return a filled Scatter trace shading the area between ``lower`` and ``upper``."""
    ds = np.asarray(ds)
    return go.Scatter(
        name=name,
        x=np.concatenate([ds, ds[::-1]]),  # x, then x reversed
        y=np.concatenate([np.asarray(upper), np.asarray(lower)[::-1]]),  # upper, then lower reversed
        fill='toself',
        fillcolor=rgba,
        line=dict(color=rgba),
        opacity=opacity,
        hoverinfo="skip",
        showlegend=True,
    )


def main():
    """Render the app: upload a CSV, pick a column, fit auto_arima, plot and export."""
    st.title('Forecasting time series')
    uploaded_file = st.file_uploader("Choose a CSV file", accept_multiple_files=False)
    if uploaded_file is None:
        return

    dataframe = pd.read_csv(uploaded_file)
    st.write(dataframe)
    series = st.text_input("Write the name of the variable you want to forecast")
    date = st.text_input('Write the first date')
    freq = st.selectbox('Select the frequency',
                        ('Y', 'Q', 'M', 'W', 'D', NULL_FREQ))

    # text_input yields '' (not None) until the user types, so test truthiness;
    # the start date only matters when there is a calendar frequency.
    if not series or (freq != NULL_FREQ and not date):
        return
    if dataframe.empty:
        st.error('The uploaded file has no rows.')
        return
    if series not in dataframe.columns:
        st.error(f'Column {series!r} not found in the uploaded file.')
        return

    horizons = _parse_horizons(st.text_input("Write the number of horizons", value="10"))
    if horizons is None:
        st.error('The number of horizons must be a positive integer.')
        return

    seasonality = calc_seasonality(freq)
    st.write(f'Seasonality: {seasonality}')

    n_obs = dataframe.shape[0]
    try:
        train_ds, forecast_ds = _time_axes(date, n_obs, horizons, freq)
    except ValueError as err:
        st.error(f'Could not build a date range from {date!r}: {err}')
        return

    series_train = pd.DataFrame(
        {'ds': train_ds, 'y': dataframe[series].values},
        # Every row shares unique_id 0: a single series.
        index=pd.Index([0] * n_obs, name='unique_id'),
    )
    fig = px.line(series_train, x='ds', y='y', title=f'{series}')
    st.plotly_chart(fig)

    with st.spinner('Wait for it...'):
        model = _fit_auto_arima(series_train['y'], freq, seasonality)
        forecasts, ci_05 = model.predict(horizons, return_conf_int=True, alpha=0.05)
        _, ci_10 = model.predict(horizons, return_conf_int=True, alpha=0.1)
    st.success('Done!')

    ci_05, ci_10 = np.asarray(ci_05), np.asarray(ci_10)
    forecasts_df = pd.DataFrame({
        'ds': forecast_ds,
        # np.asarray drops any index attached to the prediction so rows align by position.
        'mean': np.asarray(forecasts),
        'low_ci_05': ci_05[:, 0], 'low_ci_10': ci_10[:, 0],
        'hi_ci_05': ci_05[:, 1], 'hi_ci_10': ci_10[:, 1],
    })
    st.write(forecasts_df)

    # Export the assembled forecast table (dates, mean and both intervals).
    st.download_button(
        label="Download data as CSV",
        data=convert_df(forecasts_df),
        file_name='forecast.csv',
        mime='text/csv',
    )

    fig_forecast = go.Figure([
        go.Scatter(
            name=series,
            x=series_train['ds'],
            y=series_train['y'],
            mode='lines',
            line=dict(width=1, color='blue'),
            showlegend=True,
        ),
        _interval_band('PI-95%', forecasts_df['ds'], forecasts_df['low_ci_05'],
                       forecasts_df['hi_ci_05'], 'rgba(63, 232, 39, 0.52)', 0.3),
        _interval_band('PI-90%', forecasts_df['ds'], forecasts_df['low_ci_10'],
                       forecasts_df['hi_ci_10'], 'rgba(63, 232, 39, 0.26)', 0.5),
        go.Scatter(
            name=f'auto_arima_season_length-{seasonality}_mean',
            x=forecasts_df['ds'],
            y=forecasts_df['mean'],
            mode='lines',
            line=dict(width=1, color='green'),
            showlegend=True,
        ),
    ])
    fig_forecast.update_layout(
        yaxis_title=series,
        title=f'Forecasting {series}'
    )
    st.plotly_chart(fig_forecast)


main()