File size: 2,169 Bytes
a2d3623
1e716d6
38f860b
1e716d6
ed74d2e
 
1e716d6
a2d3623
 
 
aadb13b
87d437c
 
aadb13b
a2d3623
87d437c
ed74d2e
a2d3623
ed74d2e
a2d3623
aadb13b
87d437c
 
1c9bac5
87d437c
 
aadb13b
 
 
 
ed74d2e
 
87d437c
ed74d2e
 
87d437c
ed74d2e
 
87d437c
ed74d2e
 
87d437c
ed74d2e
 
 
a2d3623
ed74d2e
 
 
 
 
 
87d437c
ed74d2e
a2d3623
 
 
 
ed74d2e
a2d3623
 
ddd2d15
 
 
ed74d2e
a2d3623
87d437c
a2d3623
 
 
ed74d2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import gradio as gr
import yfinance as yf
import pandas as pd
from sklearn.linear_model import LinearRegression
import plotly.graph_objects as go

def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
    """Fit a linear trend to a ticker's closing prices and plot a projection.

    Closing prices for ``ticker`` between ``start_date`` and ``end_date`` are
    downloaded with yfinance. A ``LinearRegression`` of price against time
    (Unix seconds) is fitted on every row except the last
    ``prediction_days`` rows, and the fitted line is then evaluated on
    ``prediction_days`` consecutive calendar days beginning at the last
    downloaded date.

    Args:
        ticker: Ticker symbol passed to ``yf.download``; surrounding
            whitespace is ignored.
        start_date: Start of the download window, forwarded unchanged as
            ``start=`` (the UI labels it ``YYYY-MM-DD``).
        end_date: End of the download window, forwarded unchanged as
            ``end=``.
        prediction_days: Number of days to project, and also the number of
            trailing rows left out of the fit. Coerced to ``int``.

    Returns:
        A plotly ``Figure`` with two line traces: "Historical Prices" and
        "Predicted Prices".

    Raises:
        ValueError: If the ticker is blank, ``prediction_days`` is below 1,
            no closing prices were downloaded, or the history is not longer
            than ``prediction_days`` (nothing would be left to fit on).
    """
    ticker = str(ticker).strip()
    if not ticker:
        raise ValueError("Ticker symbol must not be empty.")

    # UI widgets may deliver the slider value as a float; negative slicing
    # and date_range(periods=...) both require a true integer.
    prediction_days = int(prediction_days)
    if prediction_days < 1:
        raise ValueError("prediction_days must be at least 1.")

    # Download stock data
    data = yf.download(ticker, start=start_date, end=end_date)
    if data is None or data.empty or "Close" not in data.columns:
        raise ValueError(
            f"No closing prices returned for {ticker!r} "
            f"between {start_date!r} and {end_date!r}."
        )

    close = data["Close"]
    if isinstance(close, pd.DataFrame):
        # With (field, ticker) MultiIndex columns, selecting "Close" yields a
        # one-column frame; reduce it to a Series so y and the plot are 1-D.
        close = close.iloc[:, 0]
    # Missing prices would make the regression fit fail.
    close = close.dropna()

    if len(close) <= prediction_days:
        raise ValueError(
            f"Need more than {prediction_days} rows of price history to fit "
            f"the model, but only {len(close)} were downloaded."
        )

    # Convert the date index to Unix timestamps (seconds)
    epoch = pd.Timestamp("1970-01-01")
    one_second = pd.Timedelta("1s")
    dates = pd.DatetimeIndex(close.index)
    if dates.tz is not None:
        # A tz-aware index cannot be subtracted from the naive epoch.
        dates = dates.tz_convert("UTC").tz_localize(None)
    seconds = ((dates - epoch) // one_second).to_numpy()
    prices = close.to_numpy()

    # Train linear regression model.
    # NOTE(review): the last `prediction_days` rows are excluded from the
    # fit, as in the original implementation - confirm this is intended.
    X = seconds[:-prediction_days].reshape(-1, 1)
    y = prices[:-prediction_days]
    model = LinearRegression()
    model.fit(X, y)

    # Prepare data for prediction: daily steps starting at the last
    # downloaded date.
    future_dates = pd.date_range(
        start=pd.to_datetime(seconds[-1], unit="s"),
        periods=prediction_days,
        freq="D",
    )
    future_seconds = ((future_dates - epoch) // one_second).to_numpy()

    # Predict future prices
    predicted_prices = model.predict(future_seconds.reshape(-1, 1))

    # Prepare data for plotting
    historical_prices = go.Scatter(
        x=pd.to_datetime(seconds, unit="s"),
        y=prices,
        mode="lines",
        name="Historical Prices",
    )
    predicted_prices_trace = go.Scatter(
        x=future_dates,
        y=predicted_prices,
        mode="lines",
        name="Predicted Prices",
    )

    # Plot data
    fig = go.Figure()
    fig.add_trace(historical_prices)
    fig.add_trace(predicted_prices_trace)
    fig.update_layout(
        title="Stock Price Prediction",
        xaxis_title="Date",
        yaxis_title="Price",
        legend_title_text="Data",
    )

    return fig

# Input widgets, listed in the positional order that
# train_predict_wrapper receives its arguments.
_input_widgets = [
    gr.Textbox(label="Ticker Symbol"),
    gr.Textbox(label="Start Date (YYYY-MM-DD)"),
    gr.Textbox(label="End Date (YYYY-MM-DD)"),
    gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
]

# Wire the widgets to the prediction function; the figure it returns is
# shown through Gradio's "plot" output.
interface = gr.Interface(
    fn=train_predict_wrapper,
    inputs=_input_widgets,
    outputs="plot",
)

# Start the app as soon as this module is executed.
interface.launch()