# cse6242-dataminers / arima.py
# Lirsen Myrtaj
# Update arima.py
# dc95b73
import pandas as pd
from datetime import datetime
from datetime import timedelta
import numpy as np
import statsmodels.api as sm
import plotly.express as px
import plotly.graph_objects as go
import warnings
# Silence every warning category process-wide for the rest of the run.
warnings.filterwarnings("ignore")
# from sklearn.metrics import mean_squared_error
# Load the semicolon-delimited daily share-price table once, at import time.
# NOTE(review): none of the functions visible in this file read this
# module-level `df` (get_model_accuracy takes its frame as an argument and
# arima_chart binds its own local `df`) — presumably it is meant to be passed
# in by callers; confirm before removing.
df = pd.read_csv('us-shareprices-daily.csv', sep=';')
def get_model_accuracy(data, ticker_symbol):
    """Report the walk-forward test MSE of a SARIMAX(1,1,1) model for a ticker.

    The rows of ``data`` whose ``'Ticker'`` equals ``ticker_symbol`` are split
    85/15 by row position. For every test observation the model is refitted
    on all values seen so far, a one-step-ahead forecast is recorded, and the
    true value is then appended to the history before the next step.

    Args:
        data: DataFrame with at least ``'Ticker'`` and ``'Close'`` columns.
            Assumes the rows are already in chronological order — TODO confirm
            against the caller.
        ticker_symbol: Value to match in the ``'Ticker'`` column.

    Returns:
        The string ``'Testing Mean Squared Error is <mse>'``.

    Raises:
        ValueError: If the ticker has too few rows (fewer than two) for the
            85/15 split to leave any training data.
    """
    stock_data = data[data['Ticker'] == ticker_symbol]

    # First 85% of the rows train the model; the remainder is held out.
    split_at = int(len(stock_data) * 0.85)
    if split_at == 0:
        raise ValueError(
            'Not enough rows for ticker {!r} to build an 85/15 '
            'train/test split'.format(ticker_symbol)
        )

    closes = stock_data['Close'].values
    history = list(closes[:split_at])
    test_data = closes[split_at:]

    model_predictions = []
    for true_test_value in test_data:
        # Refit from scratch on everything observed so far, then predict
        # exactly one step ahead.
        model = sm.tsa.statespace.SARIMAX(history, order=(1, 1, 1))
        model_fit = model.fit(disp=0)
        model_predictions.append(model_fit.forecast()[0])
        # Reveal the real observation only after the prediction is stored.
        history.append(true_test_value)

    # The sklearn import of mean_squared_error is commented out in this file,
    # so the mean of squared residuals is computed directly with numpy.
    residuals = (np.asarray(test_data, dtype=float)
                 - np.asarray(model_predictions, dtype=float))
    MSE_error = np.mean(residuals ** 2)
    return 'Testing Mean Squared Error is {}'.format(MSE_error)
def arima_chart(tickers):
    """Render price history plus a 7-step SARIMAX forecast for each ticker.

    Reads ``data_and_sp500.csv``, keeps the ``'Date'`` column and the
    requested ticker columns, draws them as a line chart, fits a
    SARIMAX(21,1,7) model per ticker, overlays each 7-step forecast as an
    extra trace, and finally hands the figure to ``st.plotly_chart``.

    Args:
        tickers: List of column names in ``data_and_sp500.csv`` to chart and
            forecast. Must be a ``list`` because it is concatenated with
            ``['Date']``.

    Returns:
        None. The chart is emitted through ``st.plotly_chart``.
    """
    df = pd.read_csv('data_and_sp500.csv')
    df = df[['Date'] + tickers]
    fig = px.line(df, x='Date', y=df.columns)

    # Loop-invariant anchor date. It is defined before the loop so the axis
    # range below still works when `tickers` is empty.
    # NOTE(review): the date and the 1258 offset used below are hard-coded;
    # presumably 1258 is the positional index of the CSV's last row and
    # 2021-10-22 is its date — confirm against the data file.
    begin_date = datetime.strptime('2021-10-22', '%Y-%m-%d')

    for ticker in tickers:
        model = sm.tsa.statespace.SARIMAX(df[ticker], order=(21, 1, 7))
        model_fit = model.fit(disp=-1)
        forecast = model_fit.forecast(7, alpha=0.05)

        # Translate the forecast's integer index labels into calendar dates
        # relative to the anchor date (label 1258 maps to begin_date itself).
        forecast_dates = [begin_date + timedelta(days=i - 1258)
                          for i in forecast.index]
        fig.add_trace(go.Scatter(x=forecast_dates, y=forecast.to_list(),
                                 mode='lines',
                                 name='{} forecast'.format(ticker)))

    # Zoom to the last 120 days of history plus 10 days past the anchor date.
    fig.update_xaxes(range=[begin_date - timedelta(days=120),
                            begin_date + timedelta(days=10)])
    # NOTE(review): `st` is not imported in the visible part of this file —
    # presumably `import streamlit as st`; confirm it is in scope.
    st.plotly_chart(fig, use_container_width=True)