Spaces:
Runtime error
Runtime error
File size: 3,679 Bytes
9af85d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import datetime
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import yfinance as yf
import plotly.graph_objs as go
import plotly.express as px
from prophet import Prophet
from workcell.integrations.types import PlotlyPlot
class Input(BaseModel):
    # Request payload for `stock_predictor`: the ticker symbol to forecast.
    # Documented with a comment rather than a docstring because pydantic
    # copies model docstrings into the generated JSON schema.
    ticker: str = Field(
        description="A ticker value, like `AAPL`, etc...",
        default="AAPL",
    )
def load_data(ticker, start=datetime.datetime(2022, 1, 1), end=None):
    """Download daily adjusted close prices for ``ticker`` from Yahoo Finance.

    e.g. ticker = 'AAPL'|'AMZN'|'GOOG'

    Args:
        ticker: Yahoo Finance symbol of a single instrument.
        start: Start of the download window (defaults to 2022-01-01).
        end: End of the download window; ``None`` means "now".

    Returns:
        A Series named ``'Adj Close'`` indexed like the yfinance download,
        which is the column name ``preprocess_data`` renames to ``y``.

    Raises:
        ValueError: If Yahoo Finance returned no rows (e.g. unknown ticker).
    """
    if end is None:
        end = datetime.datetime.now()  # latest
    # Request the unadjusted frame explicitly so the 'Adj Close' column exists;
    # newer yfinance releases default to auto_adjust=True and omit it.
    data = yf.download(ticker, start=start, end=end, interval='1d', auto_adjust=False)
    return _extract_adj_close(data, ticker)


def _extract_adj_close(data, ticker):
    """Return the 'Adj Close' column of a yfinance download as a named Series."""
    if data is None or data.empty:
        raise ValueError(f"No price data returned for ticker {ticker!r}")
    close = data['Adj Close']
    if close.ndim == 2:
        # Newer yfinance returns (Price, Ticker) MultiIndex columns even for a
        # single ticker, so the selection above is a one-column DataFrame.
        close = close.iloc[:, 0]
    # Downstream code locates the price column by this exact name.
    return close.rename('Adj Close')
def preprocess_data(df):
    """Reshape downloaded prices into the ``ds``/``y`` layout Prophet expects.

    The ``Date`` index becomes a ``ds`` column and the ``'Adj Close'``
    values become ``y``; any other columns are carried through unchanged.

    Args:
        df: Prices indexed by ``Date`` whose values are named/columned
            ``'Adj Close'`` (the output of ``load_data``).

    Returns:
        A new DataFrame whose ``ds`` column holds timezone-naive timestamps
        and whose ``y`` column holds the prices.
    """
    df_processed = df.reset_index().rename(columns={'Adj Close': 'y', 'Date': 'ds'})
    # Prophet rejects timezone-aware `ds` values: drop the timezone while
    # keeping the local wall-clock time. Naive datetime dtypes have no `tz`.
    if 'ds' in df_processed.columns and getattr(df_processed['ds'].dtype, 'tz', None) is not None:
        df_processed['ds'] = df_processed['ds'].dt.tz_localize(None)
    return df_processed
def predict_data(df, periods=30):
    """Fit a Prophet model on ``df`` and forecast ``periods`` steps ahead.

    Args:
        df: Frame with Prophet's ``ds`` (timestamp) and ``y`` (value)
            columns, e.g. the output of ``preprocess_data``.
        periods: Number of future rows requested from
            ``make_future_dataframe`` (Prophet's defaults apply otherwise).

    Returns:
        DataFrame restricted to ``ds``, ``yhat``, ``yhat_lower`` and
        ``yhat_upper`` for every row Prophet predicted.
    """
    prophet_model = Prophet()
    prophet_model.fit(df)
    horizon = prophet_model.make_future_dataframe(periods=periods)
    predicted = prophet_model.predict(horizon)
    wanted_columns = ['ds', 'yhat', 'yhat_lower', 'yhat_upper']
    return predicted[wanted_columns]
def visualization(df_processed, df_forecast, ticker):
    """Build an interactive Plotly figure of observed prices and the forecast.

    Args:
        df_processed: Observed prices with ``ds`` and ``y`` columns
            (the output of ``preprocess_data``).
        df_forecast: Forecast with ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper`` columns (the output of ``predict_data``).
        ticker: Symbol shown in the figure title.

    Returns:
        A ``go.Figure`` with the forecast line, its upper/lower uncertainty
        bounds and the observed data, plus a range slider and
        1m/6m/YTD/1y/all range-selector buttons on the x-axis.
    """
    trace_forecast = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat"],
        mode="lines",
        name="Forecast",
    )
    # fill="tonexty" shades from each bound to the trace added just before it.
    trace_upper = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_upper"],
        mode="lines",
        fill="tonexty",
        line={"color": "#57b8ff"},
        name="Higher uncertainty interval",
    )
    trace_lower = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_lower"],
        mode="lines",
        fill="tonexty",
        line={"color": "#57b8ff"},
        name="Lower uncertainty interval",
    )
    trace_actual = go.Scatter(
        x=df_processed["ds"],
        y=df_processed["y"],
        name="Data values",
    )
    # Generic title: the app forecasts any ticker, not one specific company.
    layout = go.Layout(title=f"Stock Price Forecast for: {ticker}")
    fig = go.Figure(
        data=[trace_forecast, trace_upper, trace_lower, trace_actual],
        layout=layout,
    )
    # Single place where the range slider and selector buttons are configured.
    fig.update_xaxes(
        rangeslider_visible=True,
        rangeselector=dict(
            buttons=[
                dict(count=1, label="1m", step="month", stepmode="backward"),
                dict(count=6, label="6m", step="month", stepmode="backward"),
                dict(count=1, label="YTD", step="year", stepmode="todate"),
                dict(count=1, label="1y", step="year", stepmode="backward"),
                dict(step="all"),
            ]
        ),
    )
    # Unified hover per x position; legend pinned to the top-left corner.
    fig.update_layout(
        hovermode="x",
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
        ),
    )
    return fig
def stock_predictor(input: Input) -> PlotlyPlot:
    """Forecast a ticker's price with Prophet and return it as a Plotly plot.

    Pipeline: download Yahoo Finance prices -> reshape them for Prophet ->
    fit and predict (``predict_data``'s default horizon of 30) -> plot,
    then wrap the figure in ``PlotlyPlot``.
    """
    symbol = input.ticker
    history = preprocess_data(load_data(symbol))
    forecast = predict_data(history)
    figure = visualization(history, forecast, symbol)
    return PlotlyPlot(data=figure)