Spaces:
Sleeping
Sleeping
File size: 5,930 Bytes
674f526 44ebc9f 674f526 b0174c1 674f526 b0174c1 674f526 b0174c1 674f526 b0174c1 674f526 2a32d74 674f526 44ebc9f 674f526 44ebc9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
from forecast.page_config import APP_PAGE_HEADER
import streamlit as st
import pandas as pd
# Call the page-header hook imported from forecast.page_config. This runs at
# module level, i.e. on every execution of the script and before any of the
# classes below draw their widgets.
# NOTE(review): presumably configures/renders the shared page header — confirm
# against forecast.page_config.
APP_PAGE_HEADER()
class InputData:
    """Collect the raw time series from the user and normalise it for forecasting."""

    @classmethod
    def get_data(cls) -> pd.DataFrame:
        """Return the dataset selected by the user.

        Datasets sources:
            avg_daily_air_temp_celsius_helsinki: http://shorturl.at/gBR06

        A bundled sample chosen in the left column takes precedence; when no
        sample is chosen the upload widget in the right column is used.

        Returns:
            The sample read from ``data/<sample>.csv``, or the two-column
            frame produced by :meth:`read_file`.
        """
        st1, st2 = st.columns(2)
        sample = st1.selectbox(
            "Sample datasets",
            options=["", "sample1", "sample2", "avg_daily_air_temp_celsius_helsinki"],
        )
        if sample:
            # The date-format input is only rendered by the upload flow; clear
            # any value left over from an earlier upload so it is not applied
            # to the bundled sample.
            st.session_state.date_formatter = ""
            return pd.read_csv(f"data/{sample}.csv")
        return cls.read_file(st2)

    @classmethod
    def read_file(cls, st_):
        """Let the user upload a CSV and pick its date and target columns.

        Args:
            st_: Streamlit container in which the file uploader is drawn.

        Returns:
            A DataFrame holding only the chosen date and target columns, in
            that order. Until a file is uploaded and both columns are chosen,
            the script run is halted with ``st.stop()`` instead of returning.
        """
        file_ = st_.file_uploader("Upload your dataset (csv file)")
        if not file_:
            # Nothing uploaded yet: halt this run until the user acts.
            st.stop()

        sep = st.selectbox("column sep", options=[",", ";", "|"])
        df = pd.read_csv(file_, sep=sep)
        cols = df.columns.tolist()

        # -- choose date/target columns
        st1, st3, st2 = st.columns(3)
        date_col = st1.selectbox(
            "Date column (x-axis / index)", options=[""] + cols
        )
        # Stored in session state so preprocess_data can read it later.
        st.session_state.date_formatter = st2.text_input(
            "Optional: Date format e.g. %Y-%m-%d"
        )
        target = st3.selectbox(
            "Target column (y-axis / target variable)", options=[""] + cols
        )

        # -- display
        if date_col and target:
            return df[[date_col, target]]

        # Columns not chosen yet: show the raw table to help the user decide.
        st.write(df)
        st.stop()

    @classmethod
    def preprocess_data(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Rename the two columns to ``ds``/``y`` and coerce their values.

        ``ds`` is parsed into ``datetime.date`` objects, using the format
        stored in ``st.session_state.date_formatter`` when one was entered.
        ``y`` is converted to float after removing ``,`` thousands separators.

        Args:
            df: Two-column frame ordered as (date, target).

        Returns:
            A new frame with columns ``ds`` and ``y``; the input is not
            modified. If the dates cannot be parsed, an error is shown and
            the script run is halted.
        """
        # Work on a copy: the input may be a column selection of a larger
        # frame, and the caller's object should stay untouched.
        df = df.copy()
        df.columns = ["ds", "y"]

        # -- date column
        date_formatter = (
            st.session_state.date_formatter
            if "date_formatter" in st.session_state
            else None
        )
        try:
            df["ds"] = cls._parse_dates(df["ds"], date_formatter)
        except (ValueError, TypeError, OverflowError, AttributeError):
            st.error("Date column is not in correct format")
            st.write(df["ds"])
            st.stop()

        # -- target column
        df["y"] = df["y"].apply(cls._to_float).astype(float)
        return df

    @staticmethod
    def _parse_dates(dates: pd.Series, date_format=None) -> pd.Series:
        """Parse *dates* into ``datetime.date`` objects, optionally with a strftime format."""
        # An empty format string means "not provided": let pandas infer it.
        parsed = pd.to_datetime(dates, format=date_format or None)
        # Both paths return plain dates so that later comparisons against
        # date-picker values behave the same with or without a format.
        return parsed.dt.date

    @staticmethod
    def _to_float(value) -> float:
        """Convert one target cell to float, dropping ``,`` thousands separators."""
        return float(str(value).replace(",", ""))
class PredictionApp:
    """Streamlit workflow: pick a fitting window, collect model params, run the forecast."""

    # Smallest forecast horizon (in days) offered by the slider.
    _MIN_HORIZON = 7

    @staticmethod
    def run_prediction(df: pd.DataFrame) -> "pd.DataFrame | None":
        """Draw the prediction UI and run the model once "Run" is pressed.

        Args:
            df: Frame with a ``ds`` (date) column and a ``y`` (target) column.

        Returns:
            Whatever ``ProphetModel.predict`` returns for the selected period,
            or ``None`` while the "Run" button has not been pressed.
        """
        # -- prepare data and user input
        st1, st2 = st.columns(2)
        st1.write(df)
        st2.line_chart(df.set_index("ds"))
        segmented_df = PredictionApp.split_df_by_date(df)

        # -- future date picker
        low, high, default = PredictionApp._horizon_bounds(segmented_df.shape[0])
        future = st.slider("Number of days to predict", low, high, value=default)
        params = PredictionApp.user_input_model_params()
        run = st.button("Run")
        PredictionApp.display_prophet_docs()
        if not run:
            return None

        # -- run prediction
        with st.spinner("running prediction engine .."):
            # Imported lazily, only once a run is actually requested.
            from forecast.fbprophet.model import ProphetModel

            model = ProphetModel()
            pred = model.predict(segmented_df, period=future, **params)
        return pred

    @staticmethod
    def split_df_by_date(df: pd.DataFrame) -> pd.DataFrame:
        """Let the user restrict *df* to a fitting period and chart the result.

        Args:
            df: Frame with ``ds`` and ``y`` columns.

        Returns:
            The rows whose ``ds`` lies between the chosen "from" and "to"
            dates (both inclusive), re-indexed from zero with ``ds`` as a
            regular column. If the period contains no rows, a warning is shown
            and the script run is halted.
        """
        # -- split dataframe by date
        st.caption("Choose the target fitting period")
        st1, st2 = st.columns(2)
        first_day, last_day = df.ds.min(), df.ds.max()
        from_ = st1.date_input(
            "from", min_value=first_day, max_value=last_day, value=first_day
        )
        to_ = st2.date_input(
            "to", min_value=first_day, max_value=last_day, value=last_day
        )

        # Select by range rather than by exact match: the picker allows any
        # day between min and max, including days missing from the data.
        selected = PredictionApp._select_period(df, from_, to_)
        if selected.empty:
            st.warning("No data in the selected period")
            st.stop()

        new_df = PredictionApp._displayed_segmented_dataframe(
            selected, from_=0, to_=len(selected) - 1
        )
        new_df = new_df.reset_index()
        st1.write(f"{new_df['ds'].min()}")
        st2.write(f"{new_df['ds'].max()}")
        return new_df

    @staticmethod
    def _select_period(df: pd.DataFrame, from_, to_) -> pd.DataFrame:
        """Return the rows of *df* whose ``ds`` lies in the inclusive range [from_, to_]."""
        in_range = (df["ds"] >= from_) & (df["ds"] <= to_)
        return df.loc[in_range]

    @staticmethod
    def _horizon_bounds(n_rows: int):
        """Return ``(min, max, default)`` for the horizon slider given the sample size."""
        low = PredictionApp._MIN_HORIZON
        # Keep max above min and the default inside the range even for very
        # short series; for 14+ rows this is exactly (7, 2 * n, n // 2).
        high = max(n_rows * 2, low + 1)
        default = min(max(n_rows // 2, low), high)
        return low, high, default

    @classmethod
    def _displayed_segmented_dataframe(
        cls, df: pd.DataFrame, from_: int, to_: int
    ) -> pd.DataFrame:
        """Chart and return rows ``from_``..``to_`` (positional, inclusive) indexed by ``ds``."""
        df = df.set_index("ds")
        # .iloc makes the positional intent explicit.
        df_ = df.iloc[from_ : to_ + 1]
        st.line_chart(df_)
        return df_

    @classmethod
    def user_input_model_params(cls):
        """Read optional model parameters typed as ``name=value, name=value``.

        Numeric values are converted to float, everything else stays a string.
        When ``growth=logistic`` is given, an additional ``cap`` input is
        required and stored under the ``cap`` key.

        Returns:
            Dict of parsed parameters; empty when nothing was typed. On
            invalid input a warning is shown and the script run is halted.
        """
        raw_params = st.text_input(
            "Model params. Type param name and its value e.g. growth=logistic"
        )
        if not raw_params:
            return {}

        try:
            params = cls._parse_model_params(raw_params)
        except ValueError as err:
            st.warning(str(err))
            st.stop()

        if params.get("growth") == "logistic":
            cap = st.text_input("cap")
            if not cap:
                st.warning("Cap is required for logistic growth")
                st.stop()
            try:
                params["cap"] = float(cap)
            except ValueError:
                st.warning("Cap must be a number")
                st.stop()

        st.write("Your input params:")
        st.write(params)
        return params

    @staticmethod
    def _parse_model_params(raw_params: str) -> dict:
        """Parse ``"a=1, b=x"`` into ``{"a": 1.0, "b": "x"}``.

        Raises:
            ValueError: if a non-empty token is not of the form ``name=value``.
        """
        params = {}
        for token in raw_params.split(","):
            token = token.strip()
            if not token:
                # Tolerate trailing or doubled commas.
                continue
            name, sep, value = token.partition("=")
            name, value = name.strip(), value.strip()
            if not sep or not name or not value:
                raise ValueError(
                    f"Invalid model param '{token}': expected name=value"
                )
            try:
                # float() also accepts decimals and signs, which
                # str.isdigit() rejects (e.g. "0.05", "-1").
                params[name] = float(value)
            except ValueError:
                params[name] = value
        return params

    @classmethod
    def display_prophet_docs(cls):
        """Show the ``Prophet`` class docstring in an expander, plus a link to the online docs."""
        from prophet import Prophet

        with st.expander("View model params"):
            st.write(Prophet.__doc__)
        st.markdown(
            "> more details: [visit](https://facebook.github.io/prophet/)",
            unsafe_allow_html=True,
        )
def app():
    """Run the page end to end: load a dataset, normalise it, start the prediction workflow."""
    raw_frame = InputData.get_data()
    prepared_frame = InputData.preprocess_data(raw_frame)
    PredictionApp.run_prediction(prepared_frame)


# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    app()
|