File size: 2,763 Bytes
b21a05e ec29b9c af6ca4b b21a05e af6ca4b b6df53a 4c984de b21a05e b6df53a b21a05e a3399ca 7e8e2d0 4c984de a3399ca b21a05e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import pandas as pd
import streamlit as st
import vnquant.data as dt
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import statsmodels.api as sm
from statsmodels.tsa.arima.model import ARIMA
from prophet import Prophet
from datetime import datetime, timedelta
import pytz
# Download window shared by every query: the trailing 365 days ending "today"
# in the Asia/Ho_Chi_Minh time zone, both bounds formatted as YYYY-MM-DD.
# A single timestamp is taken so both bounds come from the same instant
# (two separate now() calls could straddle midnight).
_now_vn = datetime.now(pytz.timezone('Asia/Ho_Chi_Minh'))
start_date = (_now_vn - timedelta(days=365)).strftime("%Y-%m-%d")
end_date = _now_vn.strftime("%Y-%m-%d")
def prophet_ts(symbol, periods = 10):
    """Fit Prophet to ``symbol``'s closing prices and plot history vs. forecast.

    Prices are downloaded with vnquant for the module-level
    ``start_date``..``end_date`` window.

    Args:
        symbol: Ticker passed to ``vnquant.data.DataLoader``; also used to
            name the two plot traces.
        periods: Number of rows Prophet appends past the last observed date
            (forwarded to ``make_future_dataframe``).

    Returns:
        A ``plotly.graph_objects.Figure`` holding a ``<symbol>_true`` trace of
        observed closes and a ``<symbol>_pred`` trace of Prophet's ``yhat``.
    """
    raw = dt.DataLoader(symbol, start_date, end_date).download()
    # Keep only the first element of each column label; this assumes tuple
    # labels (a multi-level header from vnquant) — TODO confirm.
    raw.columns = [label[0] for label in raw.columns]

    # Prophet expects a date column named "ds" and a value column named "y".
    history = pd.DataFrame({"ds": raw.index, "y": raw.close.values})

    forecaster = Prophet()
    forecaster.fit(history)
    horizon = forecaster.make_future_dataframe(periods=periods)
    predicted = forecaster.predict(horizon)

    figure = go.Figure()
    series = (
        (history.ds, history.y, f"{symbol}_true"),
        (predicted.ds, predicted.yhat, f"{symbol}_pred"),
    )
    for xs, ys, label in series:
        figure.add_trace(go.Scatter(x=xs, y=ys, name=label))
    return figure
class TS:
    """Download one ticker's price history and forecast it with Prophet.

    Prices are fetched through vnquant for the module-level
    ``start_date``..``end_date`` window.
    """

    def __init__(self, symbol):
        """Store the ticker to download and to label plot traces with."""
        self.symbol = symbol

    @staticmethod
    def _to_prophet_frame(data):
        """Convert a downloaded price frame into Prophet's ``ds``/``y`` input.

        Args:
            data: DataFrame whose index holds the dates and which has a
                ``close`` column, either directly or as the first level of a
                MultiIndex column header. It is not modified.

        Returns:
            A new DataFrame with a default integer index, ``ds`` taken from
            ``data.index`` and ``y`` from the ``close`` column.
        """
        flat = data.copy()
        if isinstance(flat.columns, pd.MultiIndex):
            # Keep only the top header level (e.g. "close"); presumably the
            # other level is the ticker — confirm against vnquant's output.
            # Plain (non-MultiIndex) columns are left as-is instead of being
            # truncated to their first character.
            flat.columns = flat.columns.get_level_values(0)
        return pd.DataFrame({"ds": flat.index, "y": flat["close"].to_numpy()})

    def get_data(self):
        """Download ``self.symbol`` and return it in Prophet's ``ds``/``y`` layout."""
        loader = dt.DataLoader(self.symbol, start_date, end_date)
        return self._to_prophet_frame(loader.download())

    def prophet(self, period = 28):
        """Fit Prophet on the downloaded closes and plot history vs. forecast.

        Args:
            period: Number of rows appended past the last observed date
                (forwarded to ``make_future_dataframe``).

        Returns:
            The figure produced by :meth:`viz`.
        """
        df = self.get_data()
        model = Prophet()
        model.fit(df)
        future = model.make_future_dataframe(periods=period)
        forecast = model.predict(future)
        return self.viz(df, forecast)

    def viz(self, data, future):
        """Plot observed ``data.y`` and forecast ``future.yhat`` against ``ds``.

        Returns:
            A ``plotly.graph_objects.Figure`` with traces named
            ``<symbol>_true`` and ``<symbol>_pred``.
        """
        fig = go.Figure()
        # Use the instance's ticker; the bare name `symbol` is not defined here.
        fig.add_trace(go.Scatter(x=data.ds,
                                 y=data.y,
                                 name=f"{self.symbol}_true"))
        fig.add_trace(go.Scatter(x=future.ds,
                                 y=future.yhat,
                                 name=f"{self.symbol}_pred"))
        return fig
# Streamlit page: pick a model, a ticker and a forecast horizon, then plot.
st.title("Vietnam Trading by Time Series")
# The option label must match the comparison below exactly.
model = st.selectbox("model", ["ARIMA", "Prophet"])
sb = st.text_input('Symbol', 'FPT')
period = st.slider('Period', 1, 365, 28)
symbol = sb.strip()
# TS requires the ticker; constructing it without one raises TypeError.
ts = TS(symbol)
fig = None
if not symbol:
    st.warning("Enter a ticker symbol to forecast.")
elif model == "Prophet":
    fig = ts.prophet(period=period)
else:
    # Only Prophet is wired up so far; don't hand None to st.plotly_chart.
    st.warning(f"{model} forecasting is not implemented yet.")
if fig is not None:
    st.plotly_chart(fig, use_container_width=True)
|