|
import pandas as pd |
|
from datetime import datetime |
|
from datetime import timedelta |
|
import numpy as np |
|
import statsmodels.api as sm |
|
|
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
|
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
|
|
# Load the semicolon-separated daily share-price table once, at import time.
df = pd.read_csv(filepath_or_buffer='us-shareprices-daily.csv', sep=';')
|
|
|
def get_model_accuracy(data, ticker_symbol):
    """Walk-forward evaluate an ARIMA(1,1,1) model on one ticker's closes.

    The rows of ``data`` whose ``'Ticker'`` equals ``ticker_symbol`` are split
    85% / 15% in their existing row order.  For every test observation a
    SARIMAX model with ``order=(1, 1, 1)`` is refitted on all values seen so
    far, its one-step-ahead forecast is recorded, and the true value is then
    appended to the history before the next step.

    Args:
        data: DataFrame with at least ``'Ticker'`` and ``'Close'`` columns.
        ticker_symbol: Value to match in the ``'Ticker'`` column.

    Returns:
        The string ``'Testing Mean Squared Error is <mse>'``.

    Raises:
        ValueError: If the ticker has too few rows to leave any training data.
    """
    stock_data = data[data['Ticker'] == ticker_symbol]

    # First 85% of the rows train the model; the remainder is held out.
    split_at = int(len(stock_data) * 0.85)
    closes = stock_data['Close'].values
    training_data = closes[:split_at]
    test_data = closes[split_at:]

    if len(training_data) == 0:
        raise ValueError(
            'Not enough rows for ticker {!r} to build a training set'.format(
                ticker_symbol
            )
        )

    history = list(training_data)
    model_predictions = []
    for true_test_value in test_data:
        model = sm.tsa.statespace.SARIMAX(history, order=(1, 1, 1))
        model_fit = model.fit(disp=0)
        model_predictions.append(model_fit.forecast()[0])
        # Feed the real observation back in so the next fit sees it.
        history.append(true_test_value)

    # ``mean_squared_error`` is not imported anywhere in the visible file, so
    # the mean of squared residuals is computed directly with numpy instead.
    residuals = (
        np.asarray(test_data, dtype=float)
        - np.asarray(model_predictions, dtype=float)
    )
    mse_error = float(np.mean(residuals ** 2))
    return 'Testing Mean Squared Error is {}'.format(mse_error)
|
|
|
|
|
def arima_chart(tickers):
    """Plot price history plus a 7-step ARIMA forecast for each ticker.

    Reads ``data_and_sp500.csv``, draws the ``'Date'`` column against the
    requested ticker columns, fits a SARIMAX model with ``order=(21, 1, 7)``
    per ticker, overlays each 7-step forecast as an extra line, and passes the
    figure to ``st.plotly_chart``.

    Args:
        tickers: List of column names in ``data_and_sp500.csv`` to chart.

    Returns:
        None.
    """
    # NOTE(review): ``st`` is not imported in the visible part of this file
    # (presumably ``import streamlit as st``) -- confirm it is in scope.
    df = pd.read_csv('data_and_sp500.csv')
    df = df[['Date'] + tickers]
    fig = px.line(df, x='Date', y=df.columns)

    # Defined before the loop so the axis range below is also available when
    # ``tickers`` is empty.
    begin_date = datetime.strptime('2021-10-22', '%Y-%m-%d')
    # NOTE(review): 1258 is presumably the row position in the CSV that lines
    # up with ``begin_date`` -- confirm against the data file.
    index_offset = 1258

    for ticker in tickers:
        model = sm.tsa.statespace.SARIMAX(df[ticker], order=(21, 1, 7))
        model_fit = model.fit(disp=-1)

        forecast = model_fit.forecast(7, alpha=0.05)
        # Map each forecast index label to a calendar date relative to
        # ``begin_date``.
        forecast_dates = [
            begin_date + timedelta(days=i - index_offset)
            for i in forecast.index
        ]
        fig.add_trace(
            go.Scatter(
                x=forecast_dates,
                y=forecast.to_list(),
                mode='lines',
                name='{} forecast'.format(ticker),
            )
        )

    # Show the 120 days before ``begin_date`` through 10 days after it.
    fig.update_xaxes(
        range=[begin_date - timedelta(days=120), begin_date + timedelta(days=10)]
    )
    st.plotly_chart(fig, use_container_width=True)
|
|