|
import gradio as gr |
|
import yfinance as yf |
|
import pandas as pd |
|
from sklearn.linear_model import LinearRegression |
|
import plotly.graph_objects as go |
|
|
|
|
|
def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
    """
    Downloads stock data, fits a linear trend to the closing prices, and
    forecasts future prices.

    The model is an ordinary least-squares line of closing price against the
    number of calendar days elapsed since the first downloaded trading day.
    It is fit on the full downloaded history and evaluated on the next
    ``prediction_days`` business days after the last available close.

    Args:
        ticker: The ticker symbol of the stock.
        start_date: The start date for the data (YYYY-MM-DD format).
        end_date: The end date for the data (YYYY-MM-DD format).
        prediction_days: The number of business days to predict. Gradio
            sliders may deliver this as a float, so it is coerced to int.

    Returns:
        A plotly figure with the historical and predicted prices.

    Raises:
        gr.Error: If ``prediction_days`` is not positive, or if fewer than two
            closing prices are available for the requested ticker and range.
    """
    ticker = str(ticker).strip()
    # Slider values can arrive as floats (e.g. 5.0); slicing/periods need int.
    prediction_days = int(prediction_days)
    if prediction_days < 1:
        raise gr.Error("Prediction Days must be at least 1.")

    data = yf.download(ticker, start=start_date, end=end_date)
    if data.empty or "Close" not in data.columns:
        raise gr.Error(
            f"No price data found for '{ticker}' between {start_date} and {end_date}."
        )

    close = data["Close"]
    # Newer yfinance releases return (field, ticker) MultiIndex columns even
    # for a single ticker, so "Close" may be a one-column DataFrame.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    close = close.dropna()
    if len(close) < 2:
        raise gr.Error(
            f"Not enough price data for '{ticker}' between {start_date} and "
            f"{end_date} to fit a trend."
        )

    # Feature: whole days since the first available close. Using the data's
    # own first timestamp (rather than the user's string) keeps the
    # subtraction valid whatever timezone the index carries.
    origin = close.index[0]
    X = (close.index - origin).days.to_numpy().reshape(-1, 1)
    y = close.to_numpy()

    # Fit on the full history: the most recent closes are the most relevant
    # ones for an out-of-sample forecast, so none are held back.
    model = LinearRegression()
    model.fit(X, y)

    # Forecast the business days following the last close; markets are
    # closed on weekends, so calendar days would waste forecast points.
    last_date = close.index[-1]
    future_dates = pd.date_range(
        start=last_date + pd.offsets.BDay(1),
        periods=prediction_days,
        freq="B",
    )
    X_future = (future_dates - origin).days.to_numpy().reshape(-1, 1)
    predicted_prices = model.predict(X_future)

    # Plotly needs a concrete color here; the original per-point lambda is
    # not a valid value (and would always have been green, as prices > 0).
    historical_prices = go.Scatter(
        x=close.index,
        y=y,
        mode="lines",
        line_color="green",
        name="Historical Prices",
    )
    predicted_prices_trace = go.Scatter(
        x=future_dates,
        y=predicted_prices,
        mode="lines",
        line_color="gold",
        line_width=3,
        marker_line_width=3,
        marker_color="black",
        name="Predicted Prices",
    )

    fig = go.Figure()
    fig.add_trace(historical_prices)
    fig.add_trace(predicted_prices_trace)
    fig.update_layout(
        title="Stock Price Prediction",
        xaxis_title="Date",
        yaxis_title="Price",
        legend_title_text="Data",
    )

    return fig
|
|
|
|
|
|
|
# Web UI: ticker and date range as free text, forecast horizon as a slider;
# the output component renders the plotly figure returned by the wrapper.
interface = gr.Interface(
    fn=train_predict_wrapper,
    inputs=[
        gr.Textbox(label="Ticker Symbol"),
        gr.Textbox(label="Start Date (YYYY-MM-DD)"),
        gr.Textbox(label="End Date (YYYY-MM-DD)"),
        gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
    ],
    outputs="plot",
)

# Only start the web server when executed as a script, so importing this
# module (from tests or another app) does not block inside launch().
if __name__ == "__main__":
    interface.launch()