azizalto committed on
Commit
674f526
1 Parent(s): 44ebc9f

init forecaster

Browse files
app.py CHANGED
@@ -1,8 +1,193 @@
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  def app():
5
- st.write("Hello, World!")
 
 
 
6
 
7
 
8
  if __name__ == "__main__":
 
1
+ from forecast.page_config import APP_PAGE_HEADER
2
+
3
  import streamlit as st
4
+ import pandas as pd
5
+
6
+ APP_PAGE_HEADER()
7
+
8
+
9
class InputData:
    """Streamlit widgets and helpers that load and normalise the input dataset."""

    @classmethod
    def get_data(cls) -> pd.DataFrame:
        """Return the dataset chosen by the user.

        A bundled sample (read from ``data/<name>.csv``) takes precedence; when
        no sample is selected the user is asked to upload a CSV file instead.

        Datasets sources:
            avg_daily_air_temp_celsius_helsinki: http://shorturl.at/gBR06

        Returns:
            The selected sample, or the uploaded file reduced to its date and
            target columns.
        """
        sample = st.selectbox(
            "Sample datasets",
            options=["", "sample1", "sample2", "avg_daily_air_temp_celsius_helsinki"],
        )
        if sample:
            # Samples have no user-supplied date format. Reset the value so
            # that `preprocess_data` neither fails on a missing key nor reuses
            # a format left over from an earlier upload.
            st.session_state.date_formatter = ""
            return pd.read_csv(f"data/{sample}.csv")

        return cls.read_file()

    @classmethod
    def read_file(cls) -> pd.DataFrame:
        """Ask for a CSV upload and let the user pick the date/target columns.

        The script run is halted (``st.stop()``) until a file is uploaded and
        both columns are chosen; meanwhile the raw upload is displayed.

        Returns:
            A two-column frame: the chosen date column followed by the target.
        """
        file_ = st.file_uploader("Upload your dataset (csv file)")
        if not file_:
            st.stop()

        sep = st.selectbox("column sep", options=[",", ";", "|"])
        df = pd.read_csv(file_, sep=sep)
        cols = df.columns.tolist()

        # -- choose date/target columns
        st1, st3, st2 = st.columns(3)
        date_col = st1.selectbox("Date column (x-axis / index)", options=[""] + cols)
        st.session_state.date_formatter = st2.text_input(
            "Optional: Date format e.g. %Y-%m-%d"
        )
        target = st3.selectbox(
            "Target column (y-axis / target variable)", options=[""] + cols
        )

        if date_col and target:
            if date_col == target:
                # The same column cannot be both the dates and the values.
                st.warning("Date and target columns must be different")
                st.stop()
            return df[[date_col, target]]

        # -- display the raw upload until both columns are chosen
        st.write(df)
        st.stop()

    @classmethod
    def preprocess_data(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Rename the columns to ``ds``/``y`` and coerce them to dates/floats.

        Args:
            df: two-column frame (date column first, target second); it is
                modified in place.

        Returns:
            The same frame with ``ds`` holding ``datetime.date`` objects and
            ``y`` holding floats.
        """
        df.columns = ["ds", "y"]

        # -- date column
        date_format = st.session_state.get("date_formatter") or None
        try:
            df["ds"] = cls._parse_dates(df["ds"], date_format)
        except (ValueError, TypeError):
            st.error("Date column is not in correct format")
            st.write(df["ds"])
            st.stop()

        # -- target column
        df["y"] = cls._parse_target(df["y"])
        return df

    @staticmethod
    def _parse_dates(values: pd.Series, date_format=None) -> pd.Series:
        """Parse `values` into ``datetime.date`` objects.

        Both the explicit-format and the inferred-format path return plain
        dates so that later comparisons against date-picker values behave the
        same regardless of how the column was parsed.
        """
        return pd.to_datetime(values, format=date_format or None).dt.date

    @staticmethod
    def _parse_target(values: pd.Series) -> pd.Series:
        """Convert `values` to floats, dropping thousands separators (``,``)."""
        return values.astype(str).str.replace(",", "", regex=False).astype(float)
86
+
87
+
88
class PredictionApp:
    """Streamlit front-end that collects forecast options and runs the model."""

    @staticmethod
    def run_prediction(df: pd.DataFrame) -> "pd.DataFrame | None":
        """Render the forecasting UI and run the model once "Run" is pressed.

        Args:
            df: preprocessed data with a ``ds`` (date) and a ``y`` column.

        Returns:
            Whatever the model's ``predict`` returns, or ``None`` while the
            "Run" button has not been pressed.
        """
        # -- prepare data and user input
        st1, st2 = st.columns(2)
        st1.write(df)
        st2.line_chart(df.set_index("ds"))

        segmented_df = PredictionApp.split_df_by_date(df)

        # -- future date picker
        low, high, default = PredictionApp._future_slider_bounds(segmented_df.shape[0])
        future = st.slider("Number of days to predict", low, high, value=default)

        params = PredictionApp.user_input_model_params()
        run = st.button("Run")
        PredictionApp.display_prophet_docs()

        if not run:
            return None

        # -- run prediction
        with st.spinner("running prediction engine .."):
            from forecast.fbprophet.model import ProphetModel

            model = ProphetModel()
            # The slider value wins over a user-typed "period" entry; passing
            # both as keywords would raise a TypeError.
            model_kwargs = {**params, "period": future}
            return model.predict(segmented_df, **model_kwargs)

    @staticmethod
    def _future_slider_bounds(n: int) -> tuple:
        """Return ``(min, max, default)`` for the forecast-horizon slider.

        For series of 14 rows or more this is ``(7, 2 * n, n // 2)``. For
        shorter series the maximum and default are clamped so the default never
        falls outside the slider range.
        """
        low = 7
        high = max(2 * n, low + 1)
        default = min(max(n // 2, low), high)
        return low, high, default

    @staticmethod
    def split_df_by_date(df: pd.DataFrame) -> pd.DataFrame:
        """Let the user pick a fitting period and return the matching rows.

        Args:
            df: data with a ``ds`` column; rows are assumed to be in
                chronological order.

        Returns:
            The rows between the chosen dates (inclusive) with a fresh index.
        """
        # -- split dataframe by date
        st.caption("Choose the target fitting period")
        st1, st2 = st.columns(2)

        if df.empty:
            st.warning("The dataset is empty")
            st.stop()

        # Normalise to plain dates so the bounds match what date_input returns,
        # whether `ds` holds date objects or datetime64 values.
        days = pd.to_datetime(df["ds"]).dt.date
        first_day, last_day = days.min(), days.max()

        from_ = st1.date_input(
            "from", min_value=first_day, max_value=last_day, value=first_day
        )
        to_ = st2.date_input(
            "to", min_value=first_day, max_value=last_day, value=last_day
        )

        try:
            ix1, ix2 = PredictionApp._segment_bounds(df["ds"], from_, to_)
        except ValueError as err:
            st.warning(str(err))
            st.stop()

        new_df = PredictionApp._displayed_segmented_dataframe(df, from_=ix1, to_=ix2)
        new_df = new_df.reset_index()
        st1.write(f"{new_df['ds'].min()}")
        st2.write(f"{new_df['ds'].max()}")
        return new_df

    @staticmethod
    def _segment_bounds(dates: pd.Series, from_, to_) -> tuple:
        """Return the first/last row positions falling inside ``[from_, to_]``.

        Unlike an exact-match lookup this also works when the picked dates are
        not present in the data (e.g. gaps on weekends).

        Raises:
            ValueError: if no row lies inside the requested period.
        """
        days = pd.to_datetime(dates).dt.date
        on_or_after = (days >= from_).to_numpy().nonzero()[0]
        on_or_before = (days <= to_).to_numpy().nonzero()[0]
        if (
            len(on_or_after) == 0
            or len(on_or_before) == 0
            or on_or_after[0] > on_or_before[-1]
        ):
            raise ValueError("No data in the selected period")
        return int(on_or_after[0]), int(on_or_before[-1])

    @classmethod
    def _displayed_segmented_dataframe(
        cls, df: pd.DataFrame, from_: int, to_: int
    ) -> pd.DataFrame:
        """Plot and return rows ``from_``..``to_`` (positions, inclusive), indexed by ``ds``."""
        indexed = df.set_index("ds")
        # Explicit positional slicing: the index holds dates, not integers.
        segment = indexed.iloc[from_ : to_ + 1]
        st.line_chart(segment)
        return segment

    @classmethod
    def user_input_model_params(cls) -> dict:
        """Read optional model parameters typed as ``name=value, name=value``.

        Choosing ``growth=logistic`` additionally asks for the ``cap`` value.
        Invalid input is reported in the UI and halts the script run.

        Returns:
            The parsed parameters, or an empty dict when nothing was typed.
        """
        raw_params = st.text_input(
            "Model params. Type param name and its value e.g. growth=logistic"
        )
        if not raw_params:
            return {}

        try:
            params = cls._parse_model_params(raw_params)
        except ValueError as err:
            st.error(str(err))
            st.stop()

        if params.get("growth") == "logistic":
            cap = st.text_input("cap")
            if not cap:
                st.warning("Cap is required for logistic growth")
                st.stop()
            try:
                params["cap"] = float(cap)
            except ValueError:
                st.error("Cap must be a number")
                st.stop()

        st.write("Your input params:")
        st.write(params)
        return params

    @staticmethod
    def _parse_model_params(raw: str) -> dict:
        """Parse ``"a=1, b=text"`` into ``{"a": 1.0, "b": "text"}``.

        Numeric values (including decimals such as ``0.95``) become floats;
        everything else stays a string. Empty entries, e.g. from a trailing
        comma, are ignored.

        Raises:
            ValueError: if an entry is not of the form ``name=value``.
        """
        params = {}
        for entry in raw.split(","):
            entry = entry.strip()
            if not entry:
                continue
            name, sep, value = entry.partition("=")
            name, value = name.strip(), value.strip()
            if not sep or not name or not value:
                raise ValueError(f"Invalid model param {entry!r}; expected name=value")
            try:
                params[name] = float(value)
            except ValueError:
                params[name] = value
        return params

    @classmethod
    def display_prophet_docs(cls):
        """Show the Prophet class docstring inside a collapsible section."""
        from prophet import Prophet

        with st.expander("View model params"):
            st.write(Prophet.__doc__)
            st.markdown(
                "> more details: [visit](https://facebook.github.io/prophet/)",
                unsafe_allow_html=True,
            )
184
 
185
 
186
  def app():
187
+ data = InputData.get_data()
188
+ data = InputData.preprocess_data(data)
189
+
190
+ PredictionApp.run_prediction(data)
191
 
192
 
193
  if __name__ == "__main__":
assets/logo.png ADDED
forecast/fbprophet/model.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from prophet import Prophet
4
+
5
+
6
@st.experimental_singleton
class ProphetModel:
    """Fits a Prophet model and renders its forecast in the Streamlit page."""

    @staticmethod
    def predict(df: pd.DataFrame, **kwargs) -> pd.DataFrame:
        """Fit Prophet on `df`, display the forecast and return it.

        Args:
            df: training data with the ``ds`` and ``y`` columns Prophet
                expects; it is not modified.
            **kwargs: optional settings:
                growth: Prophet ``growth`` argument (default ``"linear"``).
                interval_width: Prophet ``interval_width`` argument
                    (default ``0.95``).
                cap: value written to the ``cap`` column of both the training
                    and the future frame.
                period: number of future periods to forecast (default ``7``).

        Returns:
            The forecast frame produced by ``Prophet.predict``.
        """
        st1, st2 = st.columns(2)
        params = {
            "growth": kwargs.get("growth", "linear"),
            # Values typed in the UI may arrive as strings.
            "interval_width": float(kwargs.get("interval_width", 0.95)),
        }
        st.write(kwargs)

        period = int(kwargs.get("period", 7))
        cap = kwargs.get("cap")
        if cap is not None:
            cap = float(cap)
            # Work on a copy so the caller's frame does not gain a "cap" column.
            df = df.copy()
            df["cap"] = cap

        # -- train model
        m = Prophet(**params)
        m.fit(df)

        future = m.make_future_dataframe(periods=period)
        if cap is not None:
            future["cap"] = cap
        forecast = m.predict(future)

        # -- display output
        cols = ["ds", "yhat", "yhat_lower", "yhat_upper"]
        shown = forecast[cols].copy()
        shown["ds"] = shown["ds"].dt.strftime("%Y-%m-%d")
        st.write(f"future={period}days")
        st.write(shown)

        from prophet.plot import plot_plotly, plot_components_plotly

        st1.markdown("> forecasts")
        st1.plotly_chart(plot_plotly(m, forecast, trend=True), use_container_width=True)
        st2.markdown("> forecast components")
        st2.plotly_chart(plot_components_plotly(m, forecast), use_container_width=True)

        # -- download results
        from forecast.utils import get_table_download_link

        st.markdown(get_table_download_link(forecast), unsafe_allow_html=True)

        st.success("Forecast completed ✨")

        return forecast
forecast/page_config.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ from datetime import date
3
+
4
+ import streamlit as st
5
+
6
+
7
def APP_PAGE_HEADER():
    """Configure the page, hide Streamlit's menu and footer, then draw the header."""
    page_settings = {
        "page_title": "Simple Forecaster",
        "page_icon": ":camel:",
        "layout": "wide",
        "initial_sidebar_state": "collapsed",
    }
    st.set_page_config(**page_settings)

    # CSS injected as raw HTML to hide the default menu and footer.
    css_overrides = """
<style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style>
"""
    st.markdown(css_overrides, unsafe_allow_html=True)
    HEADER()
23
+
24
+
25
def HEADER():
    """Draw the app title and the logo (captioned with today's date)."""
    title_col, logo_col, _ = st.columns(3)
    title_col.markdown("> ## Simple Time-Series Forecast")

    caption = date.today().strftime("%B %d, %Y")
    logo_col.image("./assets/logo.png", caption=caption, use_column_width=True)
forecast/utils.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+
3
+ import pandas as pd
4
+
5
+
6
def get_table_download_link(df: pd.DataFrame) -> str:
    """Generate an HTML link that downloads `df` locally as a csv file.

    The csv content is embedded in the link itself as a base64 data URI, so no
    file is written on the server.

    Args:
        df: the dataframe to download

    Returns:
        An ``<a>`` element whose download name is ``data_<YYYYmmdd_HHMMSS>.csv``.
    """
    import base64

    csv = df.to_csv(index=False)
    b64 = base64.b64encode(csv.encode()).decode()
    # A compact timestamp keeps the filename free of spaces and colons, which
    # are awkward or invalid in filenames on some platforms.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f'<a href="data:text/csv;base64,{b64}" download="data_{stamp}.csv">download</a>'
20
+
21
+
22
def filter_df(df: pd.DataFrame, filter_: str, regex: bool = False) -> pd.DataFrame:
    """Return the rows of `df` that contain `filter_` in any column.

    Every cell is compared by its string representation.

    Args:
        df: pandas dataframe
        filter_: the string to search in the dataframe
        regex: treat `filter_` as a regular expression instead of plain text.
            Plain text is the default so that characters such as ``(`` or
            ``.`` in the search term are matched literally.

    Returns:
        filtered dataframe
    """
    import numpy as np

    if df.shape[1] == 0:
        # Nothing to search in; np.column_stack rejects an empty list.
        return df.iloc[0:0]

    mask = np.column_stack(
        [
            df[col].astype(str).str.contains(filter_, na=False, regex=regex)
            for col in df
        ]
    )
    return df.loc[mask.any(axis=1)]
requirements.txt CHANGED
@@ -1 +1,11 @@
1
- streamlit
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python3.8
2
+
3
+ # -- data and processing
4
+ pandas
5
+
6
+ # -- UI
7
+ streamlit==1.2.0
8
+
9
+ # -- model
10
+ pystan==2.19.1.1
11
+ prophet