File size: 3,679 Bytes
9af85d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import datetime
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import yfinance as yf
import plotly.graph_objs as go
import plotly.express as px
from prophet import Prophet
from workcell.integrations.types import PlotlyPlot


class Input(BaseModel):
    # Request payload for stock_predictor: the ticker symbol whose prices
    # are downloaded and forecast. Defaults to Apple.
    ticker: str = Field(
        "AAPL",
        description="A ticker value, like `AAPL`, etc...",
    )


def load_data(ticker, start=None, end=None):
    """Download daily adjusted close prices for ``ticker`` from Yahoo Finance.

    e.g. ticker = 'AAPL'|'AMZN'|'GOOG'

    Args:
        ticker: Yahoo Finance symbol to download.
        start: First date of the download window; defaults to 2022-01-01.
        end: End of the download window, passed straight to yfinance;
            defaults to the current time.

    Returns:
        A Series of adjusted close prices indexed by date and named
        ``'Adj Close'`` (the column ``preprocess_data`` renames to ``y``).

    Raises:
        ValueError: if Yahoo Finance returns no rows (e.g. unknown ticker).
    """
    if start is None:
        start = datetime.datetime(2022, 1, 1)
    if end is None:
        end = datetime.datetime.now()  # latest available data
    # Request unadjusted OHLC explicitly so the 'Adj Close' column is present:
    # recent yfinance releases default to auto_adjust=True, which drops it.
    data = yf.download(ticker, start=start, end=end, interval='1d',
                       auto_adjust=False)
    if data is None or data.empty:
        raise ValueError("No price data returned for ticker {!r}".format(ticker))
    close = data['Adj Close']
    # Recent yfinance returns (field, ticker) MultiIndex columns even for a
    # single ticker, so the selection above can be a one-column DataFrame;
    # collapse it to a Series so downstream renaming works.
    if close.ndim > 1:
        close = close.iloc[:, 0]
    return close.rename('Adj Close')


def preprocess_data(df):
    """
    Reshape price data into the two-column frame Prophet expects.

    - Move the date index into a column named ``ds`` (accepts an index named
      ``Date`` or ``Datetime``, yfinance's name for intraday data).
    - Rename the ``'Adj Close'`` price column to ``y``.
    - Drop any timezone from ``ds``, since Prophet rejects tz-aware dates.

    Args:
        df: Series/DataFrame of ``'Adj Close'`` prices indexed by date, as
            returned by ``load_data``.

    Returns:
        A new DataFrame with ``ds`` and ``y`` columns; ``df`` is not modified.
    """
    df_processed = df.reset_index().rename(
        columns={'Adj Close': 'y', 'Date': 'ds', 'Datetime': 'ds'}
    )
    # Prophet raises on a tz-aware ``ds`` column; tz_localize(None) removes
    # the zone while keeping the local wall-clock timestamps.
    if 'ds' in df_processed.columns and getattr(df_processed['ds'].dtype, 'tz', None) is not None:
        df_processed['ds'] = df_processed['ds'].dt.tz_localize(None)
    return df_processed


def predict_data(df, periods=30, freq='D'):
    """Fit a Prophet model on ``df`` and forecast ``periods`` steps ahead.

    e.g. df_forecast = predict_data(preprocess_data(load_data('AAPL')))

    Args:
        df: DataFrame with ``ds`` (date) and ``y`` (price) columns, as
            produced by ``preprocess_data``.
        periods: Number of future rows to forecast.
        freq: pandas frequency alias for the future dates, forwarded to
            ``Prophet.make_future_dataframe``. The default ``'D'`` steps by
            calendar day (weekends included); ``'B'`` steps by business day.

    Returns:
        DataFrame with ``ds``, ``yhat``, ``yhat_lower`` and ``yhat_upper``
        columns covering the historical dates plus the forecast horizon.
    """
    model = Prophet()
    model.fit(df)
    # make_future_dataframe keeps the historical dates by default, so the
    # prediction spans the fitted range followed by ``periods`` new rows.
    future_dates = model.make_future_dataframe(periods=periods, freq=freq)
    forecast = model.predict(future_dates)
    return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]


def visualization(df_processed, df_forecast, ticker):
    """Build an interactive Plotly chart of observed and forecast prices.

    Args:
        df_processed: DataFrame with ``ds``/``y`` columns (observed prices).
        df_forecast: DataFrame with ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper`` columns, as returned by ``predict_data``.
        ticker: Symbol shown in the chart title.

    Returns:
        A ``go.Figure`` with a range slider and range-selector buttons.
    """
    band_line = {"color": "#57b8ff"}

    # Uncertainty band: the upper bound is drawn without fill and the lower
    # bound fills "tonexty" up to it, so the band is shaded exactly once.
    trace_upper = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_upper"],
        mode='lines',
        line=band_line,
        name="Higher uncertainty interval",
    )
    trace_lower = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_lower"],
        mode='lines',
        fill="tonexty",
        line=band_line,
        name="Lower uncertainty interval",
    )

    # Lines are added after the band so they render on top of the shading.
    trace_forecast = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat"],
        mode='lines',
        name="Forecast",
    )
    trace_actual = go.Scatter(
        x=df_processed["ds"],
        y=df_processed["y"],
        name="Data values",
    )

    layout = go.Layout(title="Stock Price Forecast for: {}".format(ticker))
    fig = go.Figure(
        data=[trace_upper, trace_lower, trace_forecast, trace_actual],
        layout=layout,
    )
    fig.update_xaxes(
        rangeslider_visible=True,
        rangeselector=dict(
            buttons=[
                dict(count=1, label="1m", step="month", stepmode="backward"),
                dict(count=6, label="6m", step="month", stepmode="backward"),
                dict(count=1, label="YTD", step="year", stepmode="todate"),
                dict(count=1, label="1y", step="year", stepmode="backward"),
                dict(step="all"),
            ]
        ),
    )
    fig.update_layout(
        hovermode="x",
        # Legend pinned inside the top-left corner of the plot area.
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
    )
    return fig


def stock_predictor(input: Input) -> PlotlyPlot:
    """Forecast a ticker's price 30 days ahead with Prophet and chart it.

    Prices for ``input.ticker`` are downloaded from Yahoo Finance; the
    resulting Plotly figure is returned wrapped in a ``PlotlyPlot``.
    """
    symbol = input.ticker
    # 1. Fetch prices and reshape them into Prophet's ds/y layout.
    history = preprocess_data(load_data(symbol))
    # 2. Fit the model and forecast the default horizon.
    forecast = predict_data(history)
    # 3. Render the chart, then 4. wrap it as the tool's output.
    chart = visualization(history, forecast, symbol)
    return PlotlyPlot(data=chart)