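# Page header captured along with this file (appears to be a Hugging Face file view):
#   author: dibend
#   last commit: 38f860b — "Update app.py"
#   views: raw | history | blame
#   security scan: No virus
#   size: 2.84 kB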
import gradio as gr
import yfinance as yf
import pandas as pd
from sklearn.linear_model import LinearRegression
import plotly.graph_objects as go
def _extract_close_series(data):
    """Return the closing prices of a ``yf.download`` result as a 1-D Series.

    Args:
        data: DataFrame returned by ``yf.download``.

    Returns:
        A pandas Series of closing prices indexed by date, with missing
        values removed.

    Raises:
        ValueError: if no usable closing prices are present (e.g. an unknown
            ticker or a date range containing no trading days).
    """
    if data is None or data.empty:
        raise ValueError("No price data was returned for the requested ticker and date range.")
    try:
        close = data["Close"]
    except KeyError as err:
        raise ValueError("Downloaded data has no 'Close' column.") from err
    if isinstance(close, pd.DataFrame):
        # Some yfinance versions return MultiIndex columns such as
        # ("Close", ticker); selecting "Close" then yields a one-column frame.
        close = close.iloc[:, 0]
    # LinearRegression rejects NaN inputs, so drop incomplete rows.
    close = close.dropna()
    if close.empty:
        raise ValueError("No closing prices available for the requested ticker and date range.")
    return close


def _future_dates(last_date, prediction_days):
    """Return ``prediction_days`` consecutive calendar days strictly after ``last_date``.

    Equivalent to ``date_range(last_date, periods=n + 1, closed='right')``,
    whose ``closed`` keyword no longer exists in pandas 2.x.
    """
    return pd.date_range(
        start=last_date + pd.Timedelta(days=1),
        periods=prediction_days,
        freq="D",
    )


def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
    """
    Downloads stock data, trains a linear regression model, and predicts future prices.

    The single model feature is the number of days elapsed since the first
    downloaded observation. The model is fit on every available closing
    price and extrapolated over the calendar days that follow the last one.

    Args:
        ticker: The ticker symbol of the stock.
        start_date: The start date for the data (YYYY-MM-DD format).
        end_date: The end date for the data (YYYY-MM-DD format).
        prediction_days: The number of calendar days to predict (>= 1). Float
            values such as those delivered by a slider are truncated to int.

    Returns:
        A plotly figure with the historical and predicted prices.

    Raises:
        ValueError: if ``prediction_days`` is below 1 or no closing prices
            could be downloaded.
    """
    # Slicing and date_range need an int; a UI slider may deliver a float.
    prediction_days = int(prediction_days)
    if prediction_days < 1:
        raise ValueError("prediction_days must be at least 1.")

    # Download stock data
    data = yf.download(ticker, start=start_date, end=end_date)
    close = _extract_close_series(data)

    # Feature: days since the first observation. Measuring from the index's
    # own first entry keeps both subtraction operands in the same timezone.
    # Shifting X only changes the intercept, so the forecast line is the same
    # as when measuring from start_date.
    origin = close.index[0]
    X = (close.index - origin).days.to_numpy().reshape(-1, 1)
    y = close.to_numpy()

    # Train on all available data so the most recent prices inform the forecast.
    model = LinearRegression()
    model.fit(X, y)

    # Predict prices for the days following the last observation.
    future_dates = _future_dates(close.index[-1], prediction_days)
    X_future = (future_dates - origin).days.to_numpy().reshape(-1, 1)
    predicted_prices = model.predict(X_future)

    # Prepare data for plotting. Plotly needs one color per line trace; a
    # callable is rejected, and prices are positive so the former
    # "green if p > 0" rule always resolved to green anyway.
    historical_prices = go.Scatter(
        x=close.index,
        y=y,
        mode="lines",
        line_color="green",
        name="Historical Prices",
    )
    predicted_prices_trace = go.Scatter(
        x=future_dates,
        y=predicted_prices,
        mode="lines",
        line_color="gold",
        line_width=3,
        marker_line_width=3,
        marker_color="black",
        name="Predicted Prices",
    )

    # Plot data
    fig = go.Figure()
    fig.add_trace(historical_prices)
    fig.add_trace(predicted_prices_trace)
    fig.update_layout(
        title="Stock Price Prediction",
        xaxis_title="Date",
        yaxis_title="Price",
        legend_title_text="Data",
    )
    return fig
# Gradio UI: the four inputs map positionally onto
# train_predict_wrapper(ticker, start_date, end_date, prediction_days),
# and the returned Plotly figure is shown in a plot output.
_inputs = [
    gr.Textbox(label="Ticker Symbol"),
    gr.Textbox(label="Start Date (YYYY-MM-DD)"),
    gr.Textbox(label="End Date (YYYY-MM-DD)"),
    gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
]

interface = gr.Interface(fn=train_predict_wrapper, inputs=_inputs, outputs="plot")

# Start the Gradio server. This call is not guarded by __main__, so it runs
# whenever this module is executed or imported.
interface.launch()