azizalto's picture
init forecaster
674f526
import streamlit as st
import pandas as pd
from prophet import Prophet
@st.experimental_singleton
class ProphetModel:
    """Fit a Prophet model and render its forecast in the Streamlit page."""

    @staticmethod
    def predict(df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Train Prophet on ``df``, forecast ahead and display the results.

        Args:
            df: Training data handed to ``Prophet.fit``. It is not modified;
                when a cap is supplied the column is added to a copy.
            **kwargs: Optional settings:
                growth: Prophet ``growth`` argument (default ``"linear"``).
                interval_width: Prophet ``interval_width`` (default ``0.95``).
                cap: Value written to a ``cap`` column of both the training
                    and the future frame. ``None`` is treated as "no cap".
                period: Number of periods to forecast (default ``7``).

        Returns:
            The training frame: ``df`` itself when no cap is given, otherwise
            a copy of ``df`` with the ``cap`` column added.
            NOTE(review): the forecast itself is only rendered and offered as
            a download, it is not returned - confirm this is what callers
            expect.
        """
        # Function-scope imports kept from the original layout
        # (presumably to avoid an import cycle with forecast.utils - TODO confirm).
        from prophet.plot import plot_components_plotly, plot_plotly

        from forecast.utils import get_table_download_link

        # Columns are created first so the charts appear above the later output.
        st1, st2 = st.columns(2)

        params = {
            "growth": kwargs.get("growth", "linear"),
            "interval_width": kwargs.get("interval_width", 0.95),
        }
        st.write(kwargs)

        # A missing cap and an explicit ``cap=None`` both mean "no cap";
        # float(None) would otherwise raise TypeError.
        cap = kwargs.get("cap")
        if cap is not None:
            cap = float(cap)
            # Work on a copy so the caller's (possibly cached) frame is untouched.
            df = df.copy()
            df["cap"] = cap

        period = kwargs.get("period", 7)

        # -- train model
        m = Prophet(**params)
        m.fit(df)

        future = m.make_future_dataframe(periods=period)
        if cap is not None:
            future["cap"] = cap
        forecast = m.predict(future)

        # -- display output
        cols = ["ds", "yhat", "yhat_lower", "yhat_upper"]
        table = forecast[cols].copy()
        table["ds"] = table["ds"].dt.strftime("%Y-%m-%d")
        st.write(f"future={period}days")
        st.write(table)

        # Only the Plotly figures are rendered; no matplotlib figures are
        # created because pyplot keeps every figure alive until it is closed.
        st1.markdown("> forecasts")
        st1.plotly_chart(
            plot_plotly(m, forecast, trend=True), use_container_width=True
        )
        st2.markdown("> forecast components")
        st2.plotly_chart(
            plot_components_plotly(m, forecast), use_container_width=True
        )

        # -- download results
        st.markdown(get_table_download_link(forecast), unsafe_allow_html=True)
        st.success("Forecast completed ✨")
        return df