import gradio as gr
import yfinance as yf
import pandas as pd
from sklearn.linear_model import LinearRegression
import plotly.graph_objects as go


def _load_close_prices(ticker, start_date, end_date):
    """Download price data for *ticker* and return its closing prices as a 1-D Series.

    Args:
        ticker: The ticker symbol of the stock.
        start_date: The start date for the data (YYYY-MM-DD format).
        end_date: The end date for the data (YYYY-MM-DD format).

    Returns:
        A pandas Series of closing prices indexed by timezone-naive timestamps,
        with missing values dropped.

    Raises:
        gr.Error: If the download yields no usable closing prices.
    """
    data = yf.download(ticker, start=start_date, end=end_date)
    if data is None or data.empty:
        raise gr.Error(
            f"No price data found for {ticker!r} between {start_date} and {end_date}."
        )

    close = data["Close"]
    # Newer yfinance versions return MultiIndex columns (field, ticker), in which
    # case data["Close"] is a one-column DataFrame rather than a Series.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    close = close.dropna()
    if close.empty:
        raise gr.Error(f"No closing prices available for {ticker!r}.")

    # Strip any timezone so the date arithmetic below works on naive timestamps.
    if getattr(close.index, "tz", None) is not None:
        close = close.tz_localize(None)
    return close


def _fit_and_forecast(close, prediction_days):
    """Fit a linear trend of price against elapsed days and extrapolate it.

    Args:
        close: Closing prices indexed by timestamp (at least two rows).
        prediction_days: Number of calendar days after the last observation to forecast.

    Returns:
        A ``(future_dates, predicted_prices)`` tuple: a DatetimeIndex of
        ``prediction_days`` daily dates following the last observation, and the
        model's predictions for those dates.
    """
    # The regression feature is "days since the first observation". A linear model
    # is unaffected by the choice of origin, so any fixed reference date works.
    origin = close.index[0]
    days = (close.index - origin).days.to_numpy().reshape(-1, 1)

    model = LinearRegression()
    model.fit(days, close.to_numpy())

    # periods + 1 dates starting at the last observation, minus the first one.
    # This is the version-independent equivalent of the old ``closed='right'``,
    # which was removed from pd.date_range in pandas 2.0.
    future_dates = pd.date_range(start=close.index[-1], periods=prediction_days + 1)[1:]
    future_days = (future_dates - origin).days.to_numpy().reshape(-1, 1)
    return future_dates, model.predict(future_days)


def train_predict_wrapper(ticker, start_date, end_date, prediction_days):
    """
    Downloads stock data, trains a linear regression model, and predicts future prices.

    The model is fitted on the full downloaded closing-price history and then
    extrapolated ``prediction_days`` calendar days beyond the last observation.

    Args:
        ticker: The ticker symbol of the stock.
        start_date: The start date for the data (YYYY-MM-DD format).
        end_date: The end date for the data (YYYY-MM-DD format).
        prediction_days: The number of days to predict.

    Returns:
        A plotly figure with the historical and predicted prices.

    Raises:
        gr.Error: On an empty ticker, a non-positive prediction horizon, or too
            little price data to fit a trend.
    """
    ticker = (ticker or "").strip()
    if not ticker:
        raise gr.Error("Please enter a ticker symbol.")

    # Gradio sliders may deliver floats; date_range's `periods` needs an int.
    prediction_days = int(prediction_days)
    if prediction_days < 1:
        raise gr.Error("Prediction Days must be at least 1.")

    close = _load_close_prices(ticker, start_date, end_date)
    if len(close) < 2:
        raise gr.Error("At least two closing prices are needed to fit a trend.")

    future_dates, predicted_prices = _fit_and_forecast(close, prediction_days)

    historical_prices = go.Scatter(
        x=close.index,
        y=close.to_numpy(),
        mode="lines",
        line_color="green",  # plotly needs a concrete colour, not a callable
        name="Historical Prices",
    )
    predicted_prices_trace = go.Scatter(
        x=future_dates,
        y=predicted_prices,
        mode="lines",
        line_color="gold",
        line_width=3,
        name="Predicted Prices",
    )

    fig = go.Figure()
    fig.add_trace(historical_prices)
    fig.add_trace(predicted_prices_trace)
    fig.update_layout(
        title="Stock Price Prediction",
        xaxis_title="Date",
        yaxis_title="Price",
        legend_title_text="Data",
    )
    return fig


# Define Gradio interface
interface = gr.Interface(
    fn=train_predict_wrapper,
    inputs=[
        gr.Textbox(label="Ticker Symbol"),
        gr.Textbox(label="Start Date (YYYY-MM-DD)"),
        gr.Textbox(label="End Date (YYYY-MM-DD)"),
        gr.Slider(minimum=1, maximum=30, step=1, label="Prediction Days"),
    ],
    outputs="plot",
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()