Spaces:
Sleeping
Sleeping
init forecaster
Browse files- app.py +186 -1
- assets/logo.png +0 -0
- forecast/fbprophet/model.py +55 -0
- forecast/page_config.py +34 -0
- forecast/utils.py +38 -0
- requirements.txt +11 -1
app.py
CHANGED
@@ -1,8 +1,193 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
def app():
|
5 |
-
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
if __name__ == "__main__":
|
|
|
1 |
+
from forecast.page_config import APP_PAGE_HEADER
|
2 |
+
|
3 |
import streamlit as st
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
APP_PAGE_HEADER()
|
7 |
+
|
8 |
+
|
9 |
+
class InputData:
|
10 |
+
@classmethod
|
11 |
+
def get_data(cls) -> pd.DataFrame:
|
12 |
+
"""
|
13 |
+
Datasets sources:
|
14 |
+
avg_daily_air_temp_celsius_helsinki: http://shorturl.at/gBR06
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
|
18 |
+
"""
|
19 |
+
sample = st.selectbox(
|
20 |
+
"Sample datasets",
|
21 |
+
options=["", "sample1", "sample2", "avg_daily_air_temp_celsius_helsinki"],
|
22 |
+
)
|
23 |
+
if sample:
|
24 |
+
file_ = f"data/{sample}.csv"
|
25 |
+
return pd.read_csv(file_)
|
26 |
+
|
27 |
+
uploaded_data = cls.read_file()
|
28 |
+
if uploaded_data is not None:
|
29 |
+
return uploaded_data
|
30 |
+
|
31 |
+
@classmethod
|
32 |
+
def read_file(cls):
|
33 |
+
global DATE_FORMATTER
|
34 |
+
|
35 |
+
file_ = st.file_uploader("Upload your dataset (csv file)")
|
36 |
+
if not file_:
|
37 |
+
st.stop()
|
38 |
+
|
39 |
+
if file_:
|
40 |
+
sep = st.selectbox("column sep", options=[",", ";", "|"])
|
41 |
+
df = pd.read_csv(file_, sep=sep)
|
42 |
+
|
43 |
+
cols = df.columns.tolist()
|
44 |
+
|
45 |
+
# -- choose date/target columns
|
46 |
+
st1, st3, st2 = st.columns(3)
|
47 |
+
date_col = st1.selectbox(
|
48 |
+
"Date column (x-axis / index)", options=[""] + cols
|
49 |
+
)
|
50 |
+
st.session_state.date_formatter = st2.text_input(
|
51 |
+
"Optional: Date format e.g. %Y-%m-%d"
|
52 |
+
)
|
53 |
+
target = st3.selectbox(
|
54 |
+
"Target column (y-axis / target variable)", options=[""] + cols
|
55 |
+
)
|
56 |
+
|
57 |
+
# -- display
|
58 |
+
if date_col and target:
|
59 |
+
df = df[[date_col, target]]
|
60 |
+
return df
|
61 |
+
st.write(df)
|
62 |
+
st.stop()
|
63 |
+
|
64 |
+
return file_
|
65 |
+
|
66 |
+
@classmethod
|
67 |
+
def preprocess_data(cls, df: pd.DataFrame) -> pd.DataFrame:
|
68 |
+
df.columns = ["ds", "y"]
|
69 |
+
# -- date column
|
70 |
+
try:
|
71 |
+
if st.session_state.date_formatter:
|
72 |
+
df["ds"] = pd.to_datetime(
|
73 |
+
df["ds"], format=st.session_state.date_formatter
|
74 |
+
)
|
75 |
+
else:
|
76 |
+
df["ds"] = pd.to_datetime(df["ds"]).dt.date
|
77 |
+
except:
|
78 |
+
st.error("Date column is not in correct format")
|
79 |
+
st.write(df["ds"])
|
80 |
+
st.stop()
|
81 |
+
# -- target column
|
82 |
+
df["y"] = df["y"].apply(lambda x: float(str(x).replace(",", "")))
|
83 |
+
df["y"] = df["y"].astype(float)
|
84 |
+
|
85 |
+
return df
|
86 |
+
|
87 |
+
|
88 |
+
class PredictionApp:
|
89 |
+
@staticmethod
|
90 |
+
def run_prediction(df: pd.DataFrame) -> pd.DataFrame:
|
91 |
+
# -- prepare data and user input
|
92 |
+
|
93 |
+
st1, st2 = st.columns(2)
|
94 |
+
st1.write(df)
|
95 |
+
st2.line_chart(df.set_index("ds"))
|
96 |
+
|
97 |
+
segmented_df = PredictionApp.split_df_by_date(df)
|
98 |
+
|
99 |
+
# -- future date picker
|
100 |
+
n = segmented_df.shape[0]
|
101 |
+
future = st.slider("Number of days to predict", 7, n * 2, value=int(n / 2))
|
102 |
+
|
103 |
+
params = PredictionApp.user_input_model_params()
|
104 |
+
run = st.button("Run")
|
105 |
+
PredictionApp.display_prophet_docs()
|
106 |
+
|
107 |
+
if not run:
|
108 |
+
return
|
109 |
+
|
110 |
+
# -- run prediction
|
111 |
+
with st.spinner("running prediction engine .."):
|
112 |
+
from forecast.fbprophet.model import ProphetModel
|
113 |
+
|
114 |
+
model = ProphetModel()
|
115 |
+
pred = model.predict(segmented_df, period=future, **params)
|
116 |
+
return pred
|
117 |
+
|
118 |
+
@staticmethod
|
119 |
+
def split_df_by_date(df: pd.DataFrame) -> pd.DataFrame:
|
120 |
+
# -- split dataframe by date
|
121 |
+
st.caption("Choose the target fitting period")
|
122 |
+
st1, st2 = st.columns(2)
|
123 |
+
|
124 |
+
from_ = st1.date_input(
|
125 |
+
"from", min_value=df.ds.min(), max_value=df.ds.max(), value=df.ds.min()
|
126 |
+
)
|
127 |
+
to_ = st2.date_input(
|
128 |
+
"to", min_value=df.ds.min(), max_value=df.ds.max(), value=df.ds.max()
|
129 |
+
)
|
130 |
+
ix1 = df.index[df.ds == from_][0]
|
131 |
+
ix2 = df.index[df.ds == to_][0]
|
132 |
+
new_df = PredictionApp._displayed_segmented_dataframe(df, from_=ix1, to_=ix2)
|
133 |
+
|
134 |
+
new_df.reset_index(inplace=True)
|
135 |
+
st1.write(f"{new_df['ds'].min()}")
|
136 |
+
st2.write(f"{new_df['ds'].max()}")
|
137 |
+
return new_df
|
138 |
+
|
139 |
+
@classmethod
|
140 |
+
def _displayed_segmented_dataframe(
|
141 |
+
cls, df: pd.DataFrame, from_: int, to_: int
|
142 |
+
) -> pd.DataFrame:
|
143 |
+
df = df.set_index("ds")
|
144 |
+
df_ = df[from_ : to_ + 1]
|
145 |
+
st.line_chart(df_)
|
146 |
+
return df_
|
147 |
+
|
148 |
+
@classmethod
|
149 |
+
def user_input_model_params(cls):
|
150 |
+
raw_params = st.text_input(
|
151 |
+
"Model params. Type param name and its value e.g. growth=logistic"
|
152 |
+
)
|
153 |
+
|
154 |
+
if raw_params:
|
155 |
+
in_params = [x.strip() for x in raw_params.split(",")]
|
156 |
+
params = {}
|
157 |
+
for param in in_params:
|
158 |
+
k, v = param.split("=")
|
159 |
+
params[k] = float(v) if v.isdigit() else v
|
160 |
+
|
161 |
+
if "growth" in params and params["growth"] == "logistic":
|
162 |
+
cap = st.text_input("cap")
|
163 |
+
if cap:
|
164 |
+
params["cap"] = float(cap)
|
165 |
+
else:
|
166 |
+
st.warning("Cap is required for logistic growth")
|
167 |
+
st.stop()
|
168 |
+
st.write("Your input params:")
|
169 |
+
st.write(params)
|
170 |
+
return params
|
171 |
+
return {}
|
172 |
+
|
173 |
+
@classmethod
|
174 |
+
def display_prophet_docs(cls):
|
175 |
+
from prophet import Prophet
|
176 |
+
|
177 |
+
with st.expander("View model params"):
|
178 |
+
|
179 |
+
st.write(Prophet.__doc__)
|
180 |
+
st.markdown(
|
181 |
+
"> more details: [visit](https://facebook.github.io/prophet/)",
|
182 |
+
unsafe_allow_html=True,
|
183 |
+
)
|
184 |
|
185 |
|
186 |
def app():
|
187 |
+
data = InputData.get_data()
|
188 |
+
data = InputData.preprocess_data(data)
|
189 |
+
|
190 |
+
PredictionApp.run_prediction(data)
|
191 |
|
192 |
|
193 |
if __name__ == "__main__":
|
assets/logo.png
ADDED
forecast/fbprophet/model.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
from prophet import Prophet
|
4 |
+
|
5 |
+
|
6 |
+
@st.experimental_singleton
|
7 |
+
class ProphetModel:
|
8 |
+
@staticmethod
|
9 |
+
def predict(df: pd.DataFrame, **kwargs) -> pd.DataFrame:
|
10 |
+
st1, st2 = st.columns(2)
|
11 |
+
params = {
|
12 |
+
"growth": kwargs.get("growth", "linear"),
|
13 |
+
"interval_width": kwargs.get("interval_width", 0.95),
|
14 |
+
}
|
15 |
+
# st.write(params)
|
16 |
+
st.write(kwargs)
|
17 |
+
if "cap" in kwargs:
|
18 |
+
df["cap"] = float(kwargs.get("cap"))
|
19 |
+
period = kwargs.get("period", 7)
|
20 |
+
|
21 |
+
# -- train model
|
22 |
+
m = Prophet(**params)
|
23 |
+
m.fit(df)
|
24 |
+
|
25 |
+
future = m.make_future_dataframe(periods=period)
|
26 |
+
if "cap" in kwargs:
|
27 |
+
future["cap"] = float(kwargs.get("cap"))
|
28 |
+
forecast = m.predict(future)
|
29 |
+
|
30 |
+
# -- display output
|
31 |
+
cols = ["ds", "yhat", "yhat_lower", "yhat_upper"]
|
32 |
+
|
33 |
+
temp_ = forecast.copy()
|
34 |
+
temp_["ds"] = temp_["ds"].apply(lambda x: x.strftime("%Y-%m-%d"))
|
35 |
+
st.write(f"future={period}days")
|
36 |
+
st.write(temp_[cols])
|
37 |
+
|
38 |
+
fig1 = m.plot(forecast)
|
39 |
+
fig2 = m.plot_components(forecast)
|
40 |
+
|
41 |
+
from prophet.plot import plot_plotly, plot_components_plotly
|
42 |
+
|
43 |
+
st1.markdown("> forecasts")
|
44 |
+
st1.plotly_chart(plot_plotly(m, forecast, trend=True), use_container_width=True)
|
45 |
+
st2.markdown("> forecast components")
|
46 |
+
st2.plotly_chart(plot_components_plotly(m, forecast), use_container_width=True)
|
47 |
+
|
48 |
+
# -- download results
|
49 |
+
from forecast.utils import get_table_download_link
|
50 |
+
|
51 |
+
st.markdown(get_table_download_link(forecast), unsafe_allow_html=True)
|
52 |
+
|
53 |
+
st.success("Forecast completed ✨")
|
54 |
+
|
55 |
+
return df
|
forecast/page_config.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
from datetime import date
|
3 |
+
|
4 |
+
import streamlit as st
|
5 |
+
|
6 |
+
|
7 |
+
def APP_PAGE_HEADER():
|
8 |
+
st.set_page_config(
|
9 |
+
page_title="Simple Forecaster",
|
10 |
+
page_icon=":camel:",
|
11 |
+
layout="wide",
|
12 |
+
initial_sidebar_state="collapsed",
|
13 |
+
)
|
14 |
+
|
15 |
+
hide_style = """
|
16 |
+
<style>
|
17 |
+
#MainMenu {visibility: hidden;}
|
18 |
+
footer {visibility: hidden;}
|
19 |
+
</style>
|
20 |
+
"""
|
21 |
+
st.markdown(hide_style, unsafe_allow_html=True)
|
22 |
+
HEADER()
|
23 |
+
|
24 |
+
|
25 |
+
def HEADER():
|
26 |
+
st_ = st.columns(3)
|
27 |
+
st_[0].markdown("> ## Simple Time-Series Forecast")
|
28 |
+
today = date.today()
|
29 |
+
|
30 |
+
st_[1].image(
|
31 |
+
"./assets/logo.png",
|
32 |
+
caption=f"{today.strftime('%B %d, %Y')}",
|
33 |
+
use_column_width=True,
|
34 |
+
)
|
forecast/utils.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
|
6 |
+
def get_table_download_link(df: pd.DataFrame) -> str:
|
7 |
+
"""Generates a link for download the `df` locally as a csv file.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
df: the dataframe to download
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
href link
|
14 |
+
"""
|
15 |
+
import base64
|
16 |
+
|
17 |
+
csv = df.to_csv(index=False)
|
18 |
+
b64 = base64.b64encode(csv.encode()).decode()
|
19 |
+
return f'<a href="data:file/csv;base64,{b64}" download="data_{datetime.now()}.csv">download</a>'
|
20 |
+
|
21 |
+
|
22 |
+
def filter_df(df: pd.DataFrame, filter_: str) -> pd.DataFrame:
|
23 |
+
"""Takes a dataframe and a `filter_` keyword, returns all the rows that contain the value `filter_` in any column
|
24 |
+
|
25 |
+
Args:
|
26 |
+
df: pandas dataframe
|
27 |
+
filter_: the string to search in the dataframe
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
filtered dataframe
|
31 |
+
"""
|
32 |
+
import numpy as np
|
33 |
+
|
34 |
+
mask = np.column_stack(
|
35 |
+
[df[col].astype(str).str.contains(filter_, na=False) for col in df]
|
36 |
+
)
|
37 |
+
filtered_df = df.loc[mask.any(axis=1)]
|
38 |
+
return filtered_df
|
requirements.txt
CHANGED
@@ -1 +1,11 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Python3.8
|
2 |
+
|
3 |
+
# -- data and processing
|
4 |
+
pandas
|
5 |
+
|
6 |
+
# -- UI
|
7 |
+
streamlit==1.2.0
|
8 |
+
|
9 |
+
# -- model
|
10 |
+
pystan==2.19.1.1
|
11 |
+
prophet
|