predictaoil / 🏠_Home.py
janrswong's picture
added wti
3d367a6
import streamlit as st
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
# import numpy as np
import plotly.express as px
from st_aggrid import GridOptionsBuilder, AgGrid
import plotly.graph_objects as go
from style import add_logo
from Brent import displayBrent
from WTI import displayWTI
from st_btn_select import st_btn_select
from streamlit_option_menu import option_menu
hide_menu_style = """
<style>
#MainMenu{visibility: hidden;}
footer{visibility:hidden;}
</style>
"""
# page expands to full width
st.set_page_config(page_title="Predicta.oil | Home",
layout='wide', page_icon="β›½")
st.markdown(hide_menu_style, unsafe_allow_html=True)
add_logo()
# PAGE LAYOUT
# heading
st.title("Crude Oil Benchmark Stock Price Prediction LSTM and ARIMA Models")
selection = option_menu(None, ["Brent", "WTI"],
icons=['droplet', 'droplet'],
menu_icon="cast", default_index=0, orientation="horizontal")
if selection == 'Brent':
displayBrent()
if selection == 'WTI':
displayWTI()
# selection = st_btn_select(('Brent', 'WTI'), index=0)
# st.write(f'Selected option: {selection}')
# if selection == 'Brent':
# displayBrent()
# if selection == 'WTI':
# st.header("Raw Data")
# with st.container():
# col1, col2 = st.columns(2)
# with col1:
# brent = st.button("Brent Crude Oil")
# st.write(brent)
# with col2:
# wti = st.button("WTI Crude Oil")
# st.write(wti)
# brent, wti = st.columns([.5, 1])
# brent, wti = st.columns([2])
# with brent:
# st.button('1')
# with wti:
# st.button('2')
# RawData = st.selectbox(
# "Select Data Set:", ('Brent', 'WTI'), key='dataTypeKey')
# st.write(RawData)
# if RawData == 'Brent':
# displayBrent()
# elif RawData == 'WTI':
# # select time interval
# interv = st.select_slider('Select Time Series Data Interval for Prediction', options=[
# 'Daily', 'Weekly', 'Monthly', 'Quarterly'], value='Weekly')
# # st.write(interv[0])
# # Function to convert time series to interval
# @st.cache(persist=True, allow_output_mutation=True)
# def getInterval(argument):
# switcher = {
# "W": "1wk",
# "M": "1mo",
# "Q": "3mo",
# "D": "1d"
# }
# return switcher.get(argument, "1wk")
# # show raw data
# # st.header("Raw Data")
# # using button
# # if st.button('Press to see Brent Crude Oil Raw Data'):
# df = yf.download('BZ=F', interval=getInterval(interv[0]), end="2022-06-30")
# # st.dataframe(df.head())
# df = df.reset_index()
# def pagination(df):
# gb = GridOptionsBuilder.from_dataframe(df)
# gb.configure_pagination(paginationAutoPageSize=True)
# return gb.build()
# # enable enterprise modules for trial only
# # raw data
# page = pagination(df)
# # AgGrid(df, enable_enterprise_modules=True,
# # theme='streamlit', gridOptions=page, fit_columns_on_grid_load=True, key='data')
# # st.dataframe(df, width=2000, height=600)
# # st.write(df)
# st.table(df.head())
# # download full data
# @st.cache
# def convert_df(df):
# # IMPORTANT: Cache the conversion to prevent computation on every rerun
# return df.to_csv().encode('utf-8')
# csv = convert_df(df)
# st.download_button(
# label="Download data as CSV",
# data=csv,
# file_name='Brent Oil Prices.csv',
# mime='text/csv',
# )
# st.header("Standard Deviation of Brent Crude Oil")
# sd = pd.read_csv('StandardDeviation.csv')
# sd.drop("Unnamed: 0", axis=1, inplace=True)
# # sd = sd.reset_index()
# AgGrid(sd, key='SD1', enable_enterprise_modules=True,
# fit_columns_on_grid_load=True, theme='streamlit')
# st.write("Note: All entries end on 2022-6-30.")
# sd = sd.pivot(index='Start Date', columns='Interval',
# values='Standard Deviation')
# sd = sd.reset_index()
# # table
# # AgGrid(sd, key='SD', enable_enterprise_modules=True,
# # fit_columns_on_grid_load=True, domLayout='autoHeight', theme='streamlit')
# # visualization
# fig = px.line(sd, x=sd.index, y=['1d', '1wk', '1mo', '3mo'],
# title="STANDARD DEVIATION OF BRENT CRUDE OIL PRICES", width=1000)
# st.plotly_chart(fig, use_container_width=True)
# # accuracy metrics
# st.header("Accuracy Metric Comparison")
# intervals = st.selectbox(
# "Select Interval:", ('Weekly', 'Monthly', 'Quarterly', 'Daily'), key='metricKey')
# with st.container():
# col1, col2 = st.columns(2)
# # LSTM METRICS
# # st.write("LSTM Metrics")
# readfile = pd.read_csv('LSTM.csv')
# # readfile = readfile[readfile['Interval'] == intervals.upper()]
# readfile = readfile[readfile['Interval'] == st.session_state.metricKey.upper()]
# # readfile[readfile['Interval'] == intervals.upper()]
# # readfile = updatefile(readfile)
# readfile.drop("Unnamed: 0", axis=1, inplace=True)
# with col1:
# st.write("LSTM Metrics")
# AgGrid(readfile, key=st.session_state.metricKey, fit_columns_on_grid_load=True,
# enable_enterprise_modules=True, theme='streamlit')
# # st.write(st.session_state.metricKey)
# # ARIMA METRICS
# # st.write("ARIMA Metrics")
# # intervals = st.selectbox(
# # "Select Interval:", ('Weekly', 'Monthly', 'Quarterly', 'Daily'))
# if intervals == 'Weekly':
# file = pd.read_csv('ARIMAMetrics/ARIMA-WEEKLY.csv')
# file.drop("Unnamed: 0", axis=1, inplace=True)
# page = pagination(file)
# with col2:
# st.write("ARIMA Metrics")
# AgGrid(file, width='100%', theme='streamlit', enable_enterprise_modules=True,
# fit_columns_on_grid_load=True, key='weeklyMetric', gridOptions=page)
# elif intervals == 'Monthly':
# file = pd.read_csv('ARIMAMetrics/ARIMA-MONTHLY.csv')
# file.drop("Unnamed: 0", axis=1, inplace=True)
# page = pagination(file)
# with col2:
# st.write("ARIMA Metrics")
# AgGrid(file, key='monthlyMetric', fit_columns_on_grid_load=True,
# enable_enterprise_modules=True, theme='streamlit', gridOptions=page)
# elif intervals == 'Quarterly':
# file = pd.read_csv('ARIMAMetrics/ARIMA-QUARTERLY.csv')
# file.drop("Unnamed: 0", axis=1, inplace=True)
# page = pagination(file)
# with col2:
# st.write("ARIMA Metrics")
# AgGrid(file, key='quarterlyMetric', fit_columns_on_grid_load=True,
# enable_enterprise_modules=True, theme='streamlit', gridOptions=page)
# elif intervals == 'Daily':
# file = pd.read_csv('ARIMAMetrics/ARIMA-DAILY.csv')
# file.drop("Unnamed: 0", axis=1, inplace=True)
# page = pagination(file)
# with col2:
# st.write("ARIMA Metrics")
# AgGrid(file, key='dailyMetric', width='100%', fit_columns_on_grid_load=True,
# enable_enterprise_modules=True, theme='streamlit', gridOptions=page)
# # MODEL OUTPUT TABLE
# st.header("Model Output (Close Prices vs. Predicted Prices)")
# interval = st.selectbox("Select Interval:", ('Weekly',
# 'Monthly', 'Quarterly', 'Daily'), key='bestmodels')
# if interval == 'Weekly':
# file = pd.read_csv('bestWeekly.csv')
# page = pagination(file)
# AgGrid(file, key='weeklycombined', fit_columns_on_grid_load=True,
# enable_enterprise_modules=True, theme='streamlit', gridOptions=page)
# # Visualization
# st.header("Visualization")
# fig = px.line(file, x=file["Date"], y=["Close Prices", "ARIMA_50.0_(1, 0, 0)_Predictions",
# "LSTM_80.0_Predictions"], title="BOTH PREDICTED BRENT CRUDE OIL PRICES", width=1000)
# st.plotly_chart(fig, use_container_width=True)
# elif interval == 'Monthly':
# file = pd.read_csv('bestMonthly.csv')
# page = pagination(file)
# AgGrid(file, key='monthlyCombined', fit_columns_on_grid_load=True,
# enable_enterprise_modules=True, theme='streamlit', gridOptions=page)
# # Visualization
# st.header("Visualization")
# fig = px.line(file, x=file["Date"], y=["Close Prices", "ARIMA_60.0_(0, 1, 1)_Predictions", # find file
# "LSTM_80.0_Predictions"], title="BOTH PREDICTED BRENT CRUDE OIL PRICES", width=1000)
# st.plotly_chart(fig, use_container_width=True)
# elif interval == 'Quarterly':
# file = pd.read_csv('bestQuarterly.csv')
# page = pagination(file)
# AgGrid(file, key='quarterlyCombined', fit_columns_on_grid_load=True,
# enable_enterprise_modules=True, theme='streamlit', gridOptions=page)
# # Visualization
# st.header("Visualization")
# fig = px.line(file, x=file["Date"], y=["Close Prices", "ARIMA_50.0_(0, 1, 0)_Predictions", # find file
# "LSTM_80.0_Predictions"], title="BOTH PREDICTED BRENT CRUDE OIL PRICES", width=1000)
# st.plotly_chart(fig, use_container_width=True)
# elif interval == 'Daily':
# file = pd.read_csv('bestDaily.csv')
# page = pagination(file)
# AgGrid(file, key='dailyCombined', fit_columns_on_grid_load=True,
# enable_enterprise_modules=True, theme='streamlit', gridOptions=page)
# # Visualization
# st.header("Visualization")
# fig = px.line(file, x=file["Date"], y=["Close Prices", "ARIMA_50.0_(0, 1, 0)_Predictions", # find file
# "LSTM_60.0_Predictions"], title="BOTH PREDICTED BRENT CRUDE OIL PRICES", width=1000)
# st.plotly_chart(fig, use_container_width=True)