tonne commited on
Commit
a3399ca
1 Parent(s): 7e8e2d0
Files changed (1) hide show
  1. app.py +37 -2
app.py CHANGED
@@ -37,10 +37,45 @@ def prophet_ts(symbol, periods = 10):
37
  name = f"{symbol}_pred"
38
  ))
39
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  st.title("Vietnam Trading by Time Series")
42
  model = st.selectbox("model", ["ARIMA", "Propphet"])
43
  sb = st.text_input('Symbol', 'FPT')
44
- periods = st.slider('Period', 1, 365, 28)
45
- fig = prophet_ts(symbol=sb, periods = periods)
 
 
 
 
46
  st.plotly_chart(fig, use_container_width=True)
 
37
  name = f"{symbol}_pred"
38
  ))
39
  return fig
40
+ class TS:
41
+ def __init__(self, symbol):
42
+ self.symbol = symbol
43
+ def get_data(self):
44
+ loader = dt.DataLoader(self.symbol, start_date, end_date)
45
+ data = loader.download()
46
+ data.columns = [col[0] for col in data.columns]
47
+ pdf = pd.DataFrame()
48
+ pdf['ds'] = data.index
49
+ pdf['y'] = data.close.values
50
+ return pdf
51
+ def prophet(self, period = 28):
52
+ df = self.get_data()
53
+ model = Prophet()
54
+ model.fit(df)
55
+ future = model.make_future_dataframe(periods=period)
56
+ forecast = model.predict(future)
57
+ return self.viz(df, forecast)
58
+ def viz(self, data, future):
59
+ fig = go.Figure()
60
+ fig.add_trace(go.Scatter(x= data.ds,
61
+ y=data.y,
62
+ name = f"{symbol}_true"
63
+ ))
64
+ fig.add_trace(go.Scatter(x= future.ds,
65
+ y=future.yhat,
66
+ name = f"{symbol}_pred"
67
+ ))
68
+ return fig
69
+
70
+
71
 
72
  st.title("Vietnam Trading by Time Series")
73
  model = st.selectbox("model", ["ARIMA", "Propphet"])
74
  sb = st.text_input('Symbol', 'FPT')
75
+ period = st.slider('Period', 1, 365, 28)
76
+ # fig = prophet_ts(symbol=sb, periods = periods)
77
+ ts = TS()
78
+ fig = None
79
+ if model == "Prophet":
80
+ fig = ts.prophet(period = period)
81
  st.plotly_chart(fig, use_container_width=True)