File size: 2,179 Bytes
a2d3623 1e716d6 38f860b 1e716d6 ed74d2e 1e716d6 a2d3623 aadb13b 87d437c aadb13b a2d3623 87d437c ed74d2e a2d3623 ed74d2e a2d3623 aadb13b 87d437c aadb13b ed74d2e 87d437c ed74d2e 87d437c ed74d2e 87d437c ed74d2e 87d437c ed74d2e a2d3623 ed74d2e 87d437c ed74d2e a2d3623 ed74d2e a2d3623 ddd2d15 ed74d2e a2d3623 87d437c a2d3623 ed74d2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import gradio as gr
import yfinance as yf
import pandas as pd
from sklearn.linear_model import LinearRegression
import plotly.graph_objects as go
def _to_unix_seconds(index):
    """Return *index* (datetime-like) as integer seconds since 1970-01-01 UTC."""
    index = pd.DatetimeIndex(index)
    if index.tz is not None:
        # Subtracting a naive epoch from a tz-aware index raises TypeError,
        # so normalise to naive UTC first.
        index = index.tz_convert("UTC").tz_localize(None)
    return (index - pd.Timestamp("1970-01-01")) // pd.Timedelta("1s")


def _load_close_prices(ticker, start_date, end_date):
    """Download *ticker* via yfinance and return its closing prices as a 1-D Series.

    Raises:
        gr.Error: if the download returned no rows.
    """
    data = yf.download(ticker, start=start_date, end=end_date)
    if data.empty:
        raise gr.Error(
            f"No price data returned for {ticker!r} between {start_date} and {end_date}."
        )
    close = data["Close"]
    # Newer yfinance releases return MultiIndex (field, ticker) columns even for a
    # single symbol, making data["Close"] a one-column DataFrame; reduce it to a
    # Series so the regression targets and plot values stay 1-D.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    # NaN targets would make LinearRegression.fit fail.
    return close.dropna()


def _build_figure(history_dates, history_prices, future_dates, predicted_prices):
    """Return a plotly Figure overlaying historical and predicted price lines."""
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=history_dates,
            y=history_prices,
            mode="lines",
            name="Historical Prices",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=future_dates,
            y=predicted_prices,
            mode="lines",
            name="Predicted Prices",
        )
    )
    fig.update_layout(
        title="Stock Price Prediction",
        xaxis_title="Date",
        yaxis_title="Price",
        legend_title_text="Data",
    )
    return fig


def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
    """Fit a linear trend to a stock's closing prices and plot a forecast.

    The regression uses the observation time (seconds since the Unix epoch) as
    its single feature and the closing price as the target.

    Args:
        ticker: Symbol passed to ``yf.download``.
        start_date: Start of the download window (the UI asks for YYYY-MM-DD).
        end_date: End of the download window (the UI asks for YYYY-MM-DD).
        prediction_days: Number of daily calendar dates to forecast after the
            last observed date. It is also the number of most recent
            observations excluded from training. May arrive as a float from
            ``gr.Slider``.

    Returns:
        A plotly ``Figure`` with a "Historical Prices" and a
        "Predicted Prices" line trace.

    Raises:
        gr.Error: if ``prediction_days`` < 1, no data was downloaded, or there
            are not more observations than ``prediction_days``.
    """
    # gr.Slider hands its value over as a float; slicing and date_range need an int.
    prediction_days = int(prediction_days)
    if prediction_days < 1:
        raise gr.Error("Prediction Days must be at least 1.")

    close = _load_close_prices(ticker, start_date, end_date)
    if len(close) <= prediction_days:
        raise gr.Error(
            f"Need more than {prediction_days} price observations to train; "
            f"got {len(close)}. Widen the date range."
        )
    timestamps = _to_unix_seconds(close.index)

    # NOTE(review): the most recent `prediction_days` observations are left out of
    # the fit but never used for evaluation — confirm this hold-out is intended.
    X = timestamps.values[:-prediction_days].reshape(-1, 1)
    y = close.values[:-prediction_days]
    model = LinearRegression()
    model.fit(X, y)

    # Daily calendar dates strictly after the last observation. This matches the
    # former date_range(start=last, periods=n + 1, closed="right"); the `closed`
    # keyword was removed in pandas 2.0, so the first date is shifted instead.
    last_date = pd.to_datetime(timestamps[-1], unit="s")
    future_dates = pd.date_range(
        start=last_date + pd.Timedelta(days=1), periods=prediction_days
    )
    X_future = _to_unix_seconds(future_dates).values.reshape(-1, 1)
    predicted_prices = model.predict(X_future)

    return _build_figure(
        pd.to_datetime(timestamps, unit="s"),
        close.values,
        future_dates,
        predicted_prices,
    )
# Gradio UI: three text fields and a slider feed train_predict_wrapper, whose
# figure is rendered by a plot output component.
input_components = [
    gr.Textbox(label="Ticker Symbol"),
    gr.Textbox(label="Start Date (YYYY-MM-DD)"),
    gr.Textbox(label="End Date (YYYY-MM-DD)"),
    gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
]

interface = gr.Interface(
    fn=train_predict_wrapper,
    inputs=input_components,
    outputs="plot",
)

# Start the Gradio app.
interface.launch()