# simple_forecast / app.py
# Repository page residue kept for provenance: "azizalto's picture"
# Latest commit message: "Re-align input layout"
# Commit: b0174c1
from forecast.page_config import APP_PAGE_HEADER
import streamlit as st
import pandas as pd
# Invoke the page-header callable imported from forecast.page_config at module
# load, i.e. before any of the widgets created further down in this file.
# NOTE(review): presumably this configures the Streamlit page (title/layout);
# confirm against forecast.page_config.
APP_PAGE_HEADER()
class InputData:
    """Collect the input dataset through the Streamlit UI and normalise it.

    The class is a pure namespace: every method is a class/static method and
    no instance state is kept.
    """

    @classmethod
    def get_data(cls) -> pd.DataFrame:
        """Return a bundled sample dataset or the user's uploaded dataset.

        Datasets sources:
            avg_daily_air_temp_celsius_helsinki: http://shorturl.at/gBR06

        Returns:
            The frame read from ``data/<sample>.csv`` when a sample is picked,
            otherwise the two-column frame produced by :meth:`read_file`.
        """
        st1, st2 = st.columns(2)
        sample = st1.selectbox(
            "Sample datasets",
            options=["", "sample1", "sample2", "avg_daily_air_temp_celsius_helsinki"],
        )
        if sample:
            # A date format typed for an earlier upload lives in session state;
            # drop it so it is not applied to the bundled sample.
            if "date_formatter" in st.session_state:
                del st.session_state["date_formatter"]
            return pd.read_csv(f"data/{sample}.csv")
        # read_file() either returns a frame or halts the script run via
        # st.stop(), so its result can be returned directly.
        return cls.read_file(st2)

    @classmethod
    def read_file(cls, st_):
        """Let the user upload a CSV and pick its date and target columns.

        Args:
            st_: Streamlit container (e.g. a column) that hosts the uploader.

        Returns:
            A frame holding only the chosen ``[date, target]`` columns, in
            that order. Until a file is uploaded and both columns are chosen
            the script run is halted with ``st.stop()`` instead of returning.
        """
        file_ = st_.file_uploader("Upload your dataset (csv file)")
        if not file_:
            st.stop()

        sep = st.selectbox("column sep", options=[",", ";", "|"])
        try:
            # NOTE(review): rewind in case an earlier rerun already consumed
            # the upload buffer; harmless when the buffer is fresh.
            file_.seek(0)
            df = pd.read_csv(file_, sep=sep)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError):
            st.error("Could not read the uploaded file as csv")
            st.stop()
        cols = df.columns.tolist()

        # -- choose date/target columns
        st1, st3, st2 = st.columns(3)
        date_col = st1.selectbox("Date column (x-axis / index)", options=[""] + cols)
        st.session_state.date_formatter = st2.text_input(
            "Optional: Date format e.g. %Y-%m-%d"
        )
        target = st3.selectbox(
            "Target column (y-axis / target variable)", options=[""] + cols
        )

        # -- display
        if date_col and target:
            if date_col == target:
                st.warning("Date and target columns must be different")
                st.stop()
            return df[[date_col, target]]
        # Nothing usable selected yet: preview the raw frame and wait.
        st.write(df)
        st.stop()

    @staticmethod
    def _parse_dates(values: pd.Series, date_format=None) -> pd.Series:
        """Convert *values* to ``datetime.date`` objects.

        Args:
            values: Raw date column.
            date_format: Optional ``strftime``-style format; when empty or
                ``None`` pandas infers the format.

        Returns:
            A Series of ``datetime.date`` in both cases, so downstream date
            comparisons behave the same with and without an explicit format.

        Raises:
            ValueError: (from pandas) when a value cannot be parsed.
        """
        if date_format:
            parsed = pd.to_datetime(values, format=date_format)
        else:
            parsed = pd.to_datetime(values)
        return parsed.dt.date

    @staticmethod
    def _to_float(values: pd.Series) -> pd.Series:
        """Convert *values* to floats, ignoring thousands separators (",")."""
        return values.apply(lambda x: float(str(x).replace(",", ""))).astype(float)

    @classmethod
    def preprocess_data(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Rename the two input columns to ``ds``/``y`` and coerce their types.

        Args:
            df: Two-column frame: date column first, target column second.

        Returns:
            A new frame with ``ds`` as ``datetime.date`` values and ``y`` as
            floats. The caller's frame is left untouched. On invalid input an
            error is shown and the script run is halted with ``st.stop()``.
        """
        if df.shape[1] != 2:
            st.error("Expected exactly two columns: date and target")
            st.write(df)
            st.stop()

        df = df.copy()
        df.columns = ["ds", "y"]

        # -- date column
        date_formatter = (
            st.session_state.date_formatter
            if "date_formatter" in st.session_state
            else None
        )
        try:
            df["ds"] = cls._parse_dates(df["ds"], date_formatter)
        except Exception:
            # Any parsing failure is reported to the user. Deliberately not a
            # bare ``except`` so BaseException-derived control flow propagates.
            st.error("Date column is not in correct format")
            st.write(df["ds"])
            st.stop()

        # -- target column
        try:
            df["y"] = cls._to_float(df["y"])
        except (ValueError, TypeError):
            st.error("Target column is not numeric")
            st.write(df["y"])
            st.stop()
        return df
class PredictionApp:
    """Streamlit front-end that collects forecast settings and runs the model.

    The class is a pure namespace: every method is a class/static method.
    """

    @staticmethod
    def run_prediction(df: pd.DataFrame) -> "pd.DataFrame | None":
        """Show the data, gather user settings and run the forecast on demand.

        Args:
            df: Frame with a ``ds`` (date) column and a ``y`` column.

        Returns:
            Whatever ``ProphetModel.predict`` returns once "Run" is pressed
            (presumably the forecast frame; confirm in
            forecast.fbprophet.model), or ``None`` while the button has not
            been pressed.
        """
        # -- prepare data and user input
        st1, st2 = st.columns(2)
        st1.write(df)
        st2.line_chart(df.set_index("ds"))
        segmented_df = PredictionApp.split_df_by_date(df)

        # -- future date picker
        low, high, default = PredictionApp._horizon_bounds(segmented_df.shape[0])
        future = st.slider("Number of days to predict", low, high, value=default)
        params = PredictionApp.user_input_model_params()
        run = st.button("Run")
        PredictionApp.display_prophet_docs()
        if not run:
            return None

        # -- run prediction
        with st.spinner("running prediction engine .."):
            # Imported lazily so the model module is only loaded when needed.
            from forecast.fbprophet.model import ProphetModel

            model = ProphetModel()
            pred = model.predict(segmented_df, period=future, **params)
        return pred

    @staticmethod
    def _horizon_bounds(n_rows: int) -> tuple:
        """Return ``(min, max, default)`` for the forecast-horizon slider.

        Keeps the original 7 .. 2n range with default n // 2, but clamps the
        values so that ``min < max`` and ``min <= default <= max`` also hold
        for short series, where the unclamped values are invalid for a slider.
        """
        low = 7
        high = max(n_rows * 2, low + 1)
        default = min(max(n_rows // 2, low), high)
        return low, high, default

    @staticmethod
    def _period_bounds(dates: pd.Series, start, end):
        """Return positions of the first and last row inside ``[start, end]``.

        Args:
            dates: Date column; rows are assumed to be in chronological order.
            start: First day of the period (inclusive).
            end: Last day of the period (inclusive).

        Returns:
            ``(first_pos, last_pos)`` as ints, or ``None`` when no row falls
            inside the period. The chosen days need not be present in the
            data, so series with gaps are handled.
        """
        inside = ((dates >= start) & (dates <= end)).to_numpy()
        hits = inside.nonzero()[0]
        if hits.size == 0:
            return None
        return int(hits[0]), int(hits[-1])

    @staticmethod
    def split_df_by_date(df: pd.DataFrame) -> pd.DataFrame:
        """Let the user choose the fitting period and return that slice.

        Args:
            df: Frame with a ``ds`` (date) column and a ``y`` column.

        Returns:
            The rows between the chosen dates (inclusive), with ``ds`` as a
            regular column again. Halts the script run when the period holds
            no rows.
        """
        # -- split dataframe by date
        st.caption("Choose the target fitting period")
        st1, st2 = st.columns(2)
        first_day, last_day = df.ds.min(), df.ds.max()
        from_ = st1.date_input(
            "from", min_value=first_day, max_value=last_day, value=first_day
        )
        to_ = st2.date_input(
            "to", min_value=first_day, max_value=last_day, value=last_day
        )

        bounds = PredictionApp._period_bounds(df.ds, from_, to_)
        if bounds is None:
            st.warning("No data in the selected period")
            st.stop()
        ix1, ix2 = bounds

        new_df = PredictionApp._displayed_segmented_dataframe(df, from_=ix1, to_=ix2)
        # Not in place: the slice may be a view of the caller's frame.
        new_df = new_df.reset_index()
        st1.write(f"{new_df['ds'].min()}")
        st2.write(f"{new_df['ds'].max()}")
        return new_df

    @classmethod
    def _displayed_segmented_dataframe(
        cls, df: pd.DataFrame, from_: int, to_: int
    ) -> pd.DataFrame:
        """Chart and return rows ``from_``..``to_`` (positions, inclusive).

        The returned frame is indexed by ``ds``.
        """
        indexed = df.set_index("ds")
        segment = indexed.iloc[from_ : to_ + 1]
        st.line_chart(segment)
        return segment

    @staticmethod
    def _parse_params(raw_params: str) -> dict:
        """Parse ``"name=value, name=value"`` text into a dict.

        Numeric values (integers, decimals, negatives) become ``float``; all
        other values stay strings. Whitespace around names and values is
        ignored and empty entries (e.g. a trailing comma) are skipped.

        Raises:
            ValueError: when an entry is not of the form ``name=value``.
        """
        params = {}
        for entry in raw_params.split(","):
            entry = entry.strip()
            if not entry:
                continue
            key, sep, value = entry.partition("=")
            key, value = key.strip(), value.strip()
            if not sep or not key or not value:
                raise ValueError(f"Invalid param '{entry}': expected name=value")
            try:
                params[key] = float(value)
            except ValueError:
                params[key] = value
        return params

    @classmethod
    def user_input_model_params(cls):
        """Read optional model parameters typed by the user.

        Returns:
            Dict of parameter name to value (``float`` for numeric values,
            ``str`` otherwise); empty when nothing was typed. For
            ``growth=logistic`` a numeric ``cap`` entry is also required and
            added. Invalid input shows a warning and halts the script run.
        """
        raw_params = st.text_input(
            "Model params. Type param name and its value e.g. growth=logistic"
        )
        if not raw_params:
            return {}

        try:
            params = cls._parse_params(raw_params)
        except ValueError as err:
            st.warning(str(err))
            st.stop()

        if params.get("growth") == "logistic":
            cap = st.text_input("cap")
            if not cap:
                st.warning("Cap is required for logistic growth")
                st.stop()
            try:
                params["cap"] = float(cap)
            except ValueError:
                st.warning("Cap must be a number")
                st.stop()

        st.write("Your input params:")
        st.write(params)
        return params

    @classmethod
    def display_prophet_docs(cls):
        """Show the ``Prophet`` class docstring in a collapsible section."""
        from prophet import Prophet

        with st.expander("View model params"):
            st.write(Prophet.__doc__)
            st.markdown(
                "> more details: [visit](https://facebook.github.io/prophet/)",
                unsafe_allow_html=True,
            )
def app():
    """Drive the page: load the dataset, normalise it, then run the forecast UI."""
    raw_frame = InputData.get_data()
    prepared_frame = InputData.preprocess_data(raw_frame)
    PredictionApp.run_prediction(prepared_frame)
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    app()