|
import gradio as gr |
|
import yfinance as yf |
|
from prophet import Prophet |
|
from sklearn.linear_model import LinearRegression |
|
import pandas as pd |
|
from datetime import datetime |
|
import plotly.graph_objects as go |
|
|
|
def download_data(ticker, start_date='2010-01-01'):
    """Download daily price history for a ticker and reshape it for Prophet.

    Args:
        ticker: Yahoo Finance symbol passed straight to ``yf.download``,
            e.g. ``"NVDA"``. A single symbol is assumed; with several
            symbols the 'Adj Close' selection would yield multiple columns.
        start_date: first date to request, forwarded as ``yf.download``'s
            ``start`` argument (default ``'2010-01-01'``).

    Returns:
        pandas.DataFrame with exactly two columns, ``ds`` (timezone-naive
        datetimes) and ``y`` (adjusted close). Rows with a missing ``y``
        are dropped and the index is reset to 0..n-1.

    Raises:
        ValueError: if no rows are returned, or the result has no
            'Adj Close' column.
    """
    # Recent yfinance releases default auto_adjust to True, which folds the
    # adjustment into 'Close' and omits 'Adj Close'; request it explicitly.
    data = yf.download(ticker, start=start_date, auto_adjust=False)

    if data.empty:
        raise ValueError(f"No data returned for {ticker}")

    # Recent yfinance releases also return (field, ticker) MultiIndex columns
    # even for one symbol; keep only the field level so 'Adj Close' is a
    # plain column label.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)

    if 'Adj Close' not in data.columns:
        raise ValueError("Expected 'Adj Close' in columns")

    # Name the date index 'ds' directly rather than relying on it being 'Date'.
    data = data.rename_axis('ds').reset_index()

    # Build a fresh frame (not a slice of `data`) so later assignments are safe.
    prices = pd.DataFrame({
        'ds': pd.to_datetime(data['ds']),
        'y': data['Adj Close'],
    })

    # Prophet rejects timezone-aware 'ds' values, so strip any tz info.
    if prices['ds'].dt.tz is not None:
        prices['ds'] = prices['ds'].dt.tz_localize(None)

    # Rows without a price (e.g. an incomplete latest bar) would make the
    # downstream linear-regression fit fail.
    return prices.dropna(subset=['y']).reset_index(drop=True)
|
|
|
def _days_since(dates, origin):
    """Return whole days from *origin* to each of *dates* as an (n, 1) array."""
    return (pd.Series(dates) - origin).dt.days.to_numpy().reshape(-1, 1)


def predict_future_prices(ticker, periods=1825):
    """Forecast a ticker's price with Prophet and a linear trend, and plot both.

    Args:
        ticker: Yahoo Finance symbol, forwarded to ``download_data``.
        periods: number of calendar days to forecast past the last observed
            date (default 1825). Gradio's ``gr.Number`` delivers a float,
            so the value is truncated to ``int``.

    Returns:
        A 3-tuple ``(fig, prophet_table, lr_table)``:
          * ``fig`` -- plotly Figure with the Prophet forecast (blue), the
            linear-regression forecast (red) and the observed prices (black);
          * ``prophet_table`` -- DataFrame ``ds, yhat, yhat_lower, yhat_upper``
            covering history plus the horizon, ``ds`` as 'YYYY-MM-DD' strings;
          * ``lr_table`` -- DataFrame ``ds, yhat`` for the horizon only, ``ds``
            as 'YYYY-MM-DD' strings.

    Raises:
        ValueError: if ``periods`` is missing or less than 1. Errors from
            ``download_data`` or model fitting propagate unchanged.
    """
    if periods is None:
        raise ValueError("Forecast period is required")
    periods = int(periods)
    if periods < 1:
        raise ValueError("Forecast period must be at least 1 day")

    data = download_data(ticker)

    # Prophet with yearly seasonality only (daily/weekly components disabled).
    model_prophet = Prophet(
        daily_seasonality=False,
        weekly_seasonality=False,
        yearly_seasonality=True,
    )
    model_prophet.fit(data)
    future = model_prophet.make_future_dataframe(periods=periods, freq='D')
    forecast_prophet = model_prophet.predict(future)

    # Linear trend on calendar days elapsed since the first observation, used
    # for both fitting and forecasting. Using the row index instead would
    # count trading days for history but calendar days for the future,
    # stretching the extrapolated slope by roughly 365/252.
    history = data.dropna(subset=['y'])
    origin = history['ds'].iloc[0]
    model_lr = LinearRegression()
    model_lr.fit(_days_since(history['ds'], origin), history['y'].to_numpy())

    # Same horizon as Prophet: the `periods` calendar days after the last date.
    future_dates = pd.date_range(
        start=data['ds'].iloc[-1], periods=periods + 1, freq='D'
    )[1:]
    future_lr = pd.DataFrame({
        'ds': future_dates,
        'yhat': model_lr.predict(_days_since(future_dates, origin)),
    })

    # Plot on real datetimes; string formatting is only for the tables below.
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=forecast_prophet['ds'], y=forecast_prophet['yhat'], mode='lines',
        name='Prophet Forecast (Blue)', line=dict(color='blue'),
    ))
    fig.add_trace(go.Scatter(
        x=future_lr['ds'], y=future_lr['yhat'], mode='lines',
        name='Linear Regression Forecast (Red)', line=dict(color='red'),
    ))
    fig.add_trace(go.Scatter(
        x=data['ds'], y=data['y'], mode='lines',
        name='Actual (Black)', line=dict(color='black'),
    ))

    prophet_table = forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].copy()
    prophet_table['ds'] = prophet_table['ds'].dt.strftime('%Y-%m-%d')
    lr_table = future_lr.assign(ds=future_lr['ds'].dt.strftime('%Y-%m-%d'))

    return fig, prophet_table, lr_table[['ds', 'yhat']]
|
|
|
css = """footer { visibility: hidden; }""" |
|
|
|
with gr.Blocks(css=css) as app: |
|
gr.Markdown(""" |
|
<style> |
|
.markdown-text h2 { |
|
font-size: 12px; # ํฐํธ ํฌ๊ธฐ๋ฅผ 12px๋ก ์ค์ |
|
} |
|
</style> |
|
<h2>AIQ StockAI: ๊ธ๋ก๋ฒ ์์ฐ(์ฃผ์, ์ง์, BTC, ์ํ ๋ฑ) ๋ฏธ๋ ์ฃผ๊ฐ ์์ธก AI ์๋น์ค</h2> |
|
<h2>์ ์ธ๊ณ ๋ชจ๋ ํฐ์ปค ๋ณด๊ธฐ(์ผํ ํ์ด๋ธ์ค): <a href="https://finance.yahoo.com/most-active" target="_blank">์ฌ๊ธฐ๋ฅผ ํด๋ฆญ</a></h2> |
|
""") |
|
|
|
with gr.Row(): |
|
ticker_input = gr.Textbox(value="NVDA", label="Enter Stock Ticker for Forecast") |
|
periods_input = gr.Number(value=1825, label="Forecast Period (days)") |
|
forecast_button = gr.Button("Generate Forecast") |
|
|
|
forecast_chart = gr.Plot(label="Forecast Chart") |
|
forecast_data_prophet = gr.Dataframe(label="Prophet Forecast Data") |
|
forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data") |
|
|
|
forecast_button.click( |
|
fn=predict_future_prices, |
|
inputs=[ticker_input, periods_input], |
|
outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr] |
|
) |
|
|
|
app.launch() |