# trader/app.py — Hugging Face file view residue: commit e2fb478 ("add model") by tonne, 2.77 kB
import pandas as pd
import streamlit as st
import vnquant.data as dt
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import statsmodels.api as sm
from statsmodels.tsa.arima.model import ARIMA
from prophet import Prophet
from datetime import datetime, timedelta
import pytz
# Default download window for price history: the trailing LOOKBACK_DAYS up to
# "today" in Asia/Ho_Chi_Minh local time, as YYYY-MM-DD strings.
# The clock is read once so both bounds derive from the same instant (two
# separate now() calls could straddle midnight and skew the window by a day).
LOOKBACK_DAYS = 365
_LOCAL_TZ = pytz.timezone('Asia/Ho_Chi_Minh')
_now = datetime.now(_LOCAL_TZ)
start_date = (_now - timedelta(days=LOOKBACK_DAYS)).strftime("%Y-%m-%d")
end_date = _now.strftime("%Y-%m-%d")
def prophet_ts(symbol, periods=10):
    """Fit Prophet to ``symbol``'s closing prices and return a Plotly figure.

    Kept as a function-style entry point. Its download, fit and plot steps were
    a line-for-line copy of ``TS.get_data`` / ``TS.prophet`` / ``TS.viz``, so
    it delegates to ``TS`` instead of maintaining a second copy of that logic.

    Args:
        symbol: Ticker passed to ``vnquant``'s ``DataLoader`` (e.g. ``"FPT"``).
        periods: Number of future rows Prophet should forecast.

    Returns:
        A ``plotly.graph_objects.Figure`` with a ``<symbol>_true`` trace of
        observed closes and a ``<symbol>_pred`` trace of Prophet's ``yhat``.
    """
    return TS(symbol).prophet(period=periods)
class TS:
    """Forecast one ticker's closing price with Prophet and plot it with Plotly.

    Prices are downloaded through ``vnquant.data.DataLoader`` for the window
    ``[start, end]`` (``YYYY-MM-DD`` strings). When a bound is omitted, the
    module-level ``start_date`` / ``end_date`` are used, looked up at download
    time rather than at construction time.
    """

    def __init__(self, symbol, start=None, end=None):
        """Store the ticker and an optional date window.

        Args:
            symbol: Ticker passed to ``DataLoader`` (e.g. ``"FPT"``).
            start: Optional ``YYYY-MM-DD`` lower bound; ``None`` means ``start_date``.
            end: Optional ``YYYY-MM-DD`` upper bound; ``None`` means ``end_date``.
        """
        self.symbol = symbol
        self.start = start
        self.end = end

    @staticmethod
    def _to_prophet_frame(raw):
        """Return a two-column ``ds``/``y`` frame (Prophet's input layout).

        ``raw`` is expected to be date-indexed with a ``close`` column. When its
        columns are a ``MultiIndex``, only the first level is kept. Flat columns
        are used as-is, because taking ``col[0]`` of a plain string label would
        turn ``"close"`` into ``"c"``. ``raw`` itself is not modified.
        """
        if isinstance(raw.columns, pd.MultiIndex):
            raw = raw.copy()
            raw.columns = raw.columns.get_level_values(0)
        return pd.DataFrame({"ds": raw.index, "y": raw["close"].to_numpy()})

    def get_data(self):
        """Download the configured window and return it as a ``ds``/``y`` frame."""
        start = start_date if self.start is None else self.start
        end = end_date if self.end is None else self.end
        raw = dt.DataLoader(self.symbol, start, end).download()
        return self._to_prophet_frame(raw)

    def prophet(self, period=28):
        """Fit Prophet on the history and plot it with ``period`` forecast rows.

        ``make_future_dataframe`` is called without ``freq``, so Prophet's
        default (daily) step is used; the returned frame includes the history.
        """
        history = self.get_data()
        model = Prophet()
        model.fit(history)
        future = model.make_future_dataframe(periods=period)
        return self.viz(history, model.predict(future))

    def viz(self, data, future):
        """Overlay observed values and the forecast on one Plotly figure.

        Args:
            data: Frame with ``ds`` and ``y`` columns (observed closes).
            future: Prophet forecast frame with ``ds`` and ``yhat`` columns.
        """
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=data.ds, y=data.y, name=f"{self.symbol}_true"))
        fig.add_trace(go.Scatter(x=future.ds, y=future.yhat, name=f"{self.symbol}_pred"))
        return fig
st.title("Vietnam Trading by Time Series")
# Option labels must match the branch comparisons below exactly.
model = st.selectbox("model", ["ARIMA", "Prophet"])
sb = st.text_input('Symbol', 'FPT').strip()
period = st.slider('Period', 1, 365, 28)

if not sb:
    # Nothing to download for an empty ticker; end this script run cleanly.
    st.info("Enter a ticker symbol to forecast.")
    st.stop()

ts = TS(sb)  # TS needs the ticker to know what to download.
fig = None
if model == "Prophet":
    fig = ts.prophet(period=period)

if fig is None:
    # Only Prophet is wired up; never hand None to st.plotly_chart.
    st.warning(f"The {model} model is not implemented yet.")
else:
    st.plotly_chart(fig, use_container_width=True)