AlexStav commited on
Commit
9e62a5b
1 Parent(s): cd0ecd7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from prophet import Prophet
4
+ import yfinance as yf
5
+ from sklearn.metrics import mean_absolute_error, mean_squared_error
6
+ from prophet.plot import plot_plotly, plot_components_plotly
7
+
8
+ # Function to fetch stock data from Yahoo Finance
9
+ def fetch_stock_data(ticker_symbol, start_date, end_date):
10
+ stock_data = yf.download(ticker_symbol, start=start_date, end=end_date)
11
+ df = stock_data[['Adj Close']].reset_index()
12
+ df = df.rename(columns={'Date': 'ds', 'Adj Close': 'y'})
13
+ return df
14
+
15
+ # Function to train the Prophet model
16
+ def train_prophet_model(df):
17
+ model = Prophet()
18
+ model.fit(df)
19
+ return model
20
+
21
+ # Function to make the forecast
22
+ def make_forecast(model, periods):
23
+ future = model.make_future_dataframe(periods=periods)
24
+ forecast = model.predict(future)
25
+ return forecast
26
+
27
+ # Function to calculate performance metrics
28
+ def calculate_performance_metrics(actual, predicted):
29
+ mae = mean_absolute_error(actual, predicted)
30
+ mse = mean_squared_error(actual, predicted)
31
+ rmse = np.sqrt(mse)
32
+ return {'MAE': mae, 'MSE': mse, 'RMSE': rmse}
33
+
34
+ # Streamlit app
35
+ def main():
36
+ st.title('Stock Forecasting with Prophet')
37
+
38
+ # Set up the layout
39
+ st.sidebar.header('User Input Parameters')
40
+ ticker_symbol = st.sidebar.text_input('Enter Ticker Symbol', 'RACE')
41
+ start_date = st.sidebar.date_input('Start Date', value=pd.to_datetime('2015-01-01'))
42
+ end_date = st.sidebar.date_input('End Date', value=pd.to_datetime('today'))
43
+
44
+ # Dropdown for forecast horizon selection
45
+ forecast_horizon = st.sidebar.selectbox('Forecast Horizon',
46
+ options=['1 year', '2 years', '3 years', '5 years'],
47
+ format_func=lambda x: x.capitalize())
48
+
49
+ # Convert the selected horizon to days
50
+ horizon_mapping = {'1 year': 365, '2 years': 730, '3 years': 1095, '5 years': 1825}
51
+ forecast_days = horizon_mapping[forecast_horizon]
52
+
53
+ if st.sidebar.button('Forecast Stock Prices'):
54
+ with st.spinner('Fetching data...'):
55
+ df = fetch_stock_data(ticker_symbol, start_date, end_date)
56
+
57
+ with st.spinner('Training model...'):
58
+ model = train_prophet_model(df)
59
+ forecast = make_forecast(model, forecast_days)
60
+
61
+ st.subheader('Forecast Data')
62
+ st.write('The table below shows the forecasted stock prices along with the lower and upper bounds of the predictions.')
63
+ st.write(forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail())
64
+
65
+ st.subheader('Forecast Plot')
66
+ st.write('The plot below visualizes the predicted stock prices with their confidence intervals.')
67
+ fig1 = plot_plotly(model, forecast)
68
+ fig1.update_traces(marker=dict(color='red'), line=dict(color='black'))
69
+ st.plotly_chart(fig1)
70
+
71
+ st.subheader('Forecast Components')
72
+ st.write('This plot breaks down the forecast into trend, weekly, and yearly components.')
73
+ fig2 = plot_components_plotly(model, forecast)
74
+ fig2.update_traces(line=dict(color='black'))
75
+ st.plotly_chart(fig2)
76
+
77
+ st.subheader('Performance Metrics')
78
+ st.write('The metrics below provide a quantitative measure of the model’s accuracy. The Mean Absolute Error (MAE) is the average absolute difference between predicted and actual values, Mean Squared Error (MSE) is the average squared difference, and Root Mean Squared Error (RMSE) is the square root of MSE, which is more interpretable in the same units as the target variable.')
79
+ actual = df['y']
80
+ predicted = forecast['yhat'][:len(df)]
81
+ metrics = calculate_performance_metrics(actual, predicted)
82
+ st.metric(label="Mean Absolute Error (MAE)", value="{:.2f}".format(metrics['MAE']), delta="Lower is better")
83
+ st.metric(label="Mean Squared Error (MSE)", value="{:.2f}".format(metrics['MSE']), delta="Lower is better")
84
+ st.metric(label="Root Mean Squared Error (RMSE)", value="{:.2f}".format(metrics['RMSE']), delta="Lower is better")
85
+
86
+ # Run the main function
87
+ if __name__ == "__main__":
88
+ main()