Spaces:
Sleeping
Sleeping
# %% | |
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import yfinance as yf | |
from prophet import Prophet | |
import matplotlib.pyplot as plt | |
import io | |
from PIL import Image | |
from datetime import datetime | |
import plotly.express as px # Make sure plotly.express is imported | |
import warnings | |
# Silence every FutureWarning for the whole process; the filter is not
# restricted to any particular module or message.
warnings.simplefilter("ignore", FutureWarning)
def _apply_dark_layout(fig):
    """Give a Plotly figure a black background with white text (in place) and return it."""
    fig.update_layout(
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(color='white'),
    )
    return fig


def _prepare_prophet_frame(df):
    """Convert a yfinance price frame into Prophet's two-column ``ds``/``y`` frame.

    NOTE(review): some yfinance versions return ``(Price, Ticker)`` MultiIndex
    columns even for a single ticker; the ticker level is dropped so that
    ``'Close'`` selects a plain Series. Prophet rejects timezone-aware ``ds``
    values, so any timezone on the date index is stripped (wall time kept).
    """
    if isinstance(df.columns, pd.MultiIndex):
        df = df.copy()
        df.columns = df.columns.get_level_values(0)
    ds = pd.to_datetime(df.index)
    if ds.tz is not None:
        ds = ds.tz_localize(None)
    return pd.DataFrame({'ds': ds, 'y': df['Close'].to_numpy()})


def _render_components(model, forecast):
    """Render Prophet's component plot in a dark theme and return it as a PIL image."""
    # A style *context* (rather than plt.style.use) keeps the dark theme from
    # leaking into every later Matplotlib figure in this process. Saving
    # happens inside the context so lazily created tick labels also get it.
    with plt.style.context('dark_background'):
        # plot_components builds and returns its own figure, so the styling
        # must be applied to *that* figure, not to a separately created one.
        comp_fig = model.plot_components(forecast)
        try:
            comp_fig.patch.set_facecolor('black')
            for ax in comp_fig.axes:
                ax.set_facecolor('black')
                for line in ax.get_lines():
                    line.set_color('purple')
            buf = io.BytesIO()
            comp_fig.savefig(buf, format='png')
        finally:
            # Close the figure Matplotlib actually holds; otherwise one figure
            # leaks per request.
            plt.close(comp_fig)
    buf.seek(0)
    return Image.open(buf)


def stock_forecast(ticker, start_date='2022-01-01', end_date=None):
    """Download prices for *ticker*, fit Prophet, and build three charts.

    Args:
        ticker: Symbol passed to ``yf.download``; surrounding whitespace is
            ignored.
        start_date: First date to download, ``'YYYY-MM-DD'``. An empty value
            falls back to ``'2022-01-01'``.
        end_date: End date, ``'YYYY-MM-DD'``. ``None`` or an empty string
            (what a cleared Gradio textbox sends) means today.
            NOTE(review): yfinance presumably treats ``end`` as exclusive, so
            today's bar may be excluded — confirm.

    Returns:
        ``(fig1, fig2, img)``: a Plotly line chart of the last 90 closing
        prices; a Plotly chart of the last 40 forecast rows (``yhat`` with
        ``yhat_lower``/``yhat_upper``), whose final 30 rows are the future
        days added by ``make_future_dataframe(periods=30)``; and a PIL image
        of Prophet's component plot.

    Raises:
        ValueError: if the download returns no rows (e.g. an unknown ticker
            or an empty date range).
    """
    ticker = str(ticker).strip()
    if not start_date:
        start_date = '2022-01-01'
    if not end_date:
        end_date = datetime.today().strftime('%Y-%m-%d')

    df = yf.download(ticker, start=start_date, end=end_date)
    if df is None or df.empty:
        raise ValueError(
            f"No price data returned for {ticker!r} between {start_date} and {end_date}."
        )
    df1 = _prepare_prophet_frame(df)

    # Fit the model and forecast 30 days past the last observation.
    m = Prophet()
    m.fit(df1)
    future = m.make_future_dataframe(periods=30)
    forecast = m.predict(future)

    fig1 = _apply_dark_layout(
        px.line(df1.tail(90), x="ds", y="y", title='Stock Closing Prices by Date')
    )

    forecast_40 = forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(40)
    fig2 = _apply_dark_layout(
        px.line(
            forecast_40,
            x="ds",
            y=['yhat', 'yhat_lower', 'yhat_upper'],
            title=f'{ticker} 30 Days Forecast',
        )
    )

    img = _render_components(m, forecast)
    return fig1, fig2, img
# %% | |
# Create the Gradio interface.
def _today_iso():
    """Return today's date as 'YYYY-MM-DD' for the end-date textbox default."""
    return datetime.today().strftime('%Y-%m-%d')


interface = gr.Interface(
    fn=stock_forecast,  # Function to run
    inputs=[
        gr.Textbox(label="Enter Stock Ticker", value="NVDA"),  # Input: stock ticker
        gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2022-01-01"),  # Input: start date
        # Input: end date. Passing a callable makes Gradio re-evaluate the
        # default on each page load; a plain string would freeze it at the
        # date the app process was started.
        gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value=_today_iso),
    ],
    outputs=[
        gr.Plot(label="Stock Closing Prices"),  # Output: Plotly chart
        gr.Plot(label="30 Days Forecast"),  # Output: Plotly chart
        gr.Image(label="Forecast Components"),  # Output: PIL image
    ],
    title="Stock Market Forecast",
    description="Enter a stock ticker symbol, a start date, and an end date (in YYYY-MM-DD format) to view the historical closing prices, a 30-day forecast, and the forecast components.",
)

# Launch the Gradio app.
interface.launch()