File size: 2,285 Bytes
a2d3623
1e716d6
 
ed74d2e
 
1e716d6
 
a2d3623
 
 
 
ed74d2e
 
 
 
a2d3623
 
ed74d2e
a2d3623
 
 
 
 
 
ed74d2e
 
a2d3623
ed74d2e
a2d3623
 
ed74d2e
 
 
a2d3623
ed74d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2d3623
ed74d2e
 
 
 
 
 
 
 
a2d3623
 
 
ed74d2e
a2d3623
ed74d2e
a2d3623
 
ddd2d15
 
 
ed74d2e
a2d3623
ed74d2e
a2d3623
 
 
ed74d2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import yfinance as yf
from sklearn.linear_model import LinearRegression
import plotly.graph_objects as go


def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
    """
    Download closing prices, fit a linear trend, and plot it against the data.

    The last ``prediction_days`` rows of the downloaded history are held out.
    A linear regression of closing price on date is fitted to the earlier
    rows and then used to predict the held-out dates. The predicted line can
    therefore be compared with the actual prices over that window.

    Args:
        ticker: The ticker symbol of the stock.
        start_date: The start date for the data (YYYY-MM-DD format).
        end_date: The end date for the data (YYYY-MM-DD format); passed to
            ``yf.download`` as ``end``.
        prediction_days: Number of trailing rows (trading days) to hold out
            and predict. Coerced to ``int`` so a float such as ``5.0`` from a
            UI slider can still be used for slicing.

    Returns:
        A plotly figure with the historical and predicted prices.

    Raises:
        gr.Error: If the ticker is blank, ``prediction_days`` is below 1, or
            the download does not contain enough closing prices to leave at
            least two training rows.
    """
    ticker = (ticker or "").strip()
    if not ticker:
        raise gr.Error("Please enter a ticker symbol.")

    prediction_days = int(prediction_days)
    if prediction_days < 1:
        raise gr.Error("Prediction days must be at least 1.")

    # Download stock data
    data = yf.download(ticker, start=start_date, end=end_date)
    if data.empty or "Close" not in data:
        raise gr.Error(
            f"No price data returned for {ticker!r} "
            f"between {start_date} and {end_date}."
        )

    close = data["Close"]
    # Depending on the yfinance version, the columns may be a MultiIndex
    # (field, ticker). In that case data["Close"] is a one-column DataFrame,
    # so reduce it to a Series to keep every array below 1-D.
    if close.ndim == 2:
        close = close.iloc[:, 0]
    close = close.dropna()

    # Need at least two rows to fit a line, plus the held-out rows.
    if len(close) < prediction_days + 2:
        raise gr.Error(
            f"Only {len(close)} closing prices available; need at least "
            f"{prediction_days + 2} to train on and predict {prediction_days}."
        )

    # Express each date as a whole number of days since 1970-01-01, so the
    # regression gets a plain numeric feature instead of datetime64 values.
    day_numbers = close.index.values.astype("datetime64[D]").astype("int64")
    X = day_numbers.reshape(-1, 1)
    y = close.to_numpy()

    # Train on everything except the last `prediction_days` rows.
    model = LinearRegression()
    model.fit(X[:-prediction_days], y[:-prediction_days])

    # Predict the held-out trailing dates.
    future_dates = close.index[-prediction_days:]
    predicted_prices = model.predict(X[-prediction_days:])

    # Prepare data for plotting
    historical_prices = go.Scatter(
        x=close.index,
        y=y,
        mode="lines",
        line_color="green",
        name="Historical Prices",
    )
    predicted_prices_trace = go.Scatter(
        x=future_dates,
        y=predicted_prices,
        mode="lines",
        line_color="gold",
        line_width=3,
        name="Predicted Prices",
    )

    # Plot data
    fig = go.Figure(data=[historical_prices, predicted_prices_trace])
    fig.update_layout(
        title="Stock Price Prediction",
        xaxis_title="Date",
        yaxis_title="Price",
        legend_title_text="Data",
    )

    return fig


# --- Gradio UI ---------------------------------------------------------------
# Three free-text fields and a day-count slider, listed in the same order as
# train_predict_wrapper's parameters (ticker, start_date, end_date,
# prediction_days). The figure it returns is shown in a plot output.
interface = gr.Interface(
    fn=train_predict_wrapper,
    inputs=[
        *(
            gr.Textbox(label=field_label)
            for field_label in (
                "Ticker Symbol",
                "Start Date (YYYY-MM-DD)",
                "End Date (YYYY-MM-DD)",
            )
        ),
        gr.Slider(label="Prediction Days", minimum=1, maximum=30, step=1),
    ],
    outputs="plot",
)

# Start the web app.
interface.launch()