# Stock Analysis Tool (copied from a Hugging Face Space page; status shown: Running)
import gradio as gr | |
import plotly.graph_objects as go | |
import yfinance as yf | |
import pandas as pd | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.keras.models import Sequential | |
from tensorflow.keras.layers import Dense, LSTM | |
from sklearn.preprocessing import MinMaxScaler | |
from sklearn.linear_model import LinearRegression | |
from sklearn.ensemble import RandomForestRegressor | |
from sklearn.metrics import mean_squared_error | |
# Human-readable period labels shown in the UI dropdowns, mapped to the
# `period` strings passed to yfinance. Insertion order is the display order.
PERIODS = dict(
    [
        ("1 Month", "1mo"),
        ("3 Months", "3mo"),
        ("6 Months", "6mo"),
        ("1 Year", "1y"),
        ("5 Years", "5y"),
        ("10 Years", "10y"),
        ("Max", "max"),
    ]
)
def fetch_data(ticker, period):
    """Download price history for *ticker* over a UI period label.

    Args:
        ticker: Yahoo Finance symbol as typed by the user; surrounding
            whitespace is ignored. Non-string values are treated as empty.
        period: One of the keys of ``PERIODS`` (e.g. ``"1 Year"``).

    Returns:
        ``(data, None)`` on success, where ``data`` is the DataFrame returned
        by ``yf.download`` with single-level columns (``"Close"``, ...), or
        ``(None, message)`` when the input is invalid, no rows came back, or
        the download raised.
    """
    symbol = ticker.strip() if isinstance(ticker, str) else ""
    if not symbol or period not in PERIODS:
        return None, "Invalid input: Check the ticker symbol and period."
    try:
        data = yf.download(symbol, period=PERIODS[period])
    except Exception as e:  # network / parsing failures are reported to the UI
        return None, str(e)
    if data is None or data.empty:
        return None, "No data found for the given ticker and period."
    # Recent yfinance releases return (field, ticker) MultiIndex columns even
    # for a single symbol; flatten them so data['Close'] is a Series, which
    # is what every caller expects.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    return data, None
def update_output_type(analysis_type):
    """Toggle which fundamental-analysis output component is visible.

    "Dividends" is rendered as a chart, every other view as an HTML table.

    Returns:
        ``(plot_update, html_update)`` visibility updates, in that order.
    """
    wants_chart = analysis_type == "Dividends"
    return gr.update(visible=wants_chart), gr.update(visible=not wants_chart)
def _format_news_article(article):
    """Render one yfinance news item (a dict) as a Markdown section.

    Accepts both a flat item (``title``, ``publisher``, ``link``,
    ``providerPublishTime``, ``summary``) and an item whose fields are nested
    under a ``content`` dict. NOTE(review): the nested key names used as
    fallbacks (``provider.displayName``, ``canonicalUrl.url``, ``pubDate``)
    reflect newer Yahoo payloads; confirm against the installed yfinance.
    Missing fields fall back to placeholders instead of raising KeyError.
    """
    from datetime import datetime, timezone

    nested = article.get("content")
    item = nested if isinstance(nested, dict) else article

    def sub(key, subkey):
        value = item.get(key)
        return value.get(subkey) if isinstance(value, dict) else None

    title = item.get("title") or "Untitled"
    publisher = item.get("publisher") or sub("provider", "displayName") or "N/A"
    published = item.get("providerPublishTime") or item.get("pubDate") or "N/A"
    if isinstance(published, (int, float)) and not isinstance(published, bool):
        # Numeric publish times are treated as Unix epoch seconds.
        try:
            published = datetime.fromtimestamp(published, tz=timezone.utc).strftime(
                "%Y-%m-%d %H:%M UTC"
            )
        except (OverflowError, OSError, ValueError):
            pass  # keep the raw value if it is not a valid timestamp
    summary = item.get("summary") or ""
    link = item.get("link") or sub("canonicalUrl", "url")

    # Two trailing spaces force a Markdown line break; a single space does not.
    lines = [
        f"### {title}\n",
        f"**Source:** {publisher}  \n",
        f"**Published At:** {published}  \n",
    ]
    if summary:
        lines.append(f"{summary}  \n")
    if link:
        lines.append(f"[Read more]({link})  \n")
    lines.append("\n")
    return "".join(lines)


def fetch_news(ticker, limit=5):
    """Return Markdown for the most recent news items about *ticker*.

    Args:
        ticker: Yahoo Finance symbol; surrounding whitespace is ignored.
        limit: Maximum number of articles to include (default 5).

    Returns:
        The Markdown text, a validation or "no news" message, or the text of
        the exception raised while fetching the news list.
    """
    symbol = ticker.strip() if isinstance(ticker, str) else ""
    if not symbol:
        return "Invalid input: Ticker symbol cannot be empty."
    try:
        news = yf.Ticker(symbol).news or []
    except Exception as e:  # network / parsing errors are shown to the user
        return str(e)
    sections = [
        _format_news_article(article)
        for article in news[:limit]
        if isinstance(article, dict)
    ]
    if not sections:
        return f"No news found for {symbol}."
    return "".join(sections)
def summarize_asset(ticker):
    """Return a Markdown summary of key statistics for *ticker*.

    Reads ``yf.Ticker(...).info`` and lists each statistic that is present.
    Missing or falsy values (``None``, ``0``) are omitted. ETFs and mutual
    funds (``quoteType`` "ETF"/"MUTUALFUND") also get total assets, NAV and
    expense ratio.

    Args:
        ticker: Yahoo Finance symbol; surrounding whitespace is ignored.

    Returns:
        The Markdown summary, a validation message for an empty ticker, or
        the text of any exception raised while reading ``info``.
    """
    symbol = ticker.strip() if isinstance(ticker, str) else ""
    if not symbol:
        return "Invalid input: Ticker symbol cannot be empty."

    def is_number(value):
        # bool is an int subclass but never a meaningful statistic here.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def plain(value):
        # Thousands separators for counts, two decimals for prices/ratios;
        # non-numeric values (if Yahoo returns any) are shown verbatim.
        if not is_number(value):
            return str(value)
        return f"{value:,}" if isinstance(value, int) else f"{value:,.2f}"

    def money(value):
        return f"${plain(value)}"

    def percent(value):
        # Fraction -> percentage with two decimals, avoiding float noise
        # such as "0.5599999999999999%".
        return f"{value * 100:.2f}%" if is_number(value) else f"{value}%"

    # (info key, label, formatter), in display order.
    general_fields = (
        ("currentPrice", "Current Price", money),
        ("marketCap", "Market Cap", money),
        ("trailingPE", "PE Ratio (TTM)", plain),
        ("trailingEps", "EPS (TTM)", plain),
        # NOTE(review): assumes dividendYield is a fraction (0.0055); some
        # yfinance/Yahoo versions may report it already in percent — confirm.
        ("dividendYield", "Dividend Yield", percent),
        ("fiftyTwoWeekHigh", "52 Week High", money),
        ("fiftyTwoWeekLow", "52 Week Low", money),
        ("volume", "Volume", plain),
    )
    fund_fields = (
        ("totalAssets", "Total Assets", money),
        ("navPrice", "NAV", money),
        ("annualReportExpenseRatio", "Expense Ratio", percent),
    )

    stock = yf.Ticker(symbol)
    try:
        info = stock.info
        fields = general_fields
        if info.get("quoteType", "N/A") in ("ETF", "MUTUALFUND"):
            fields = fields + fund_fields
        lines = [f"**Summary for {info.get('shortName', 'N/A')} ({symbol}):**\n\n"]
        for key, label, fmt in fields:
            value = info.get(key)
            if value:
                lines.append(f"- **{label}:** {fmt(value)}\n")
        return "".join(lines)
    except Exception as e:
        return str(e)
def _ta_window(value):
    """Return *value* as a rolling-window length (int >= 1), or None if invalid.

    gr.Number delivers floats such as 20.0 by default, but pandas ``rolling``
    only accepts integer windows.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    if not number.is_integer() or number < 1:
        return None
    return int(number)


def _ta_rsi(close, window=14):
    """RSI computed from simple rolling means of gains and losses."""
    delta = close.diff(1)
    gain = delta.where(delta > 0, 0)
    loss = -delta.where(delta < 0, 0)
    rs = gain.rolling(window=window).mean() / loss.rolling(window=window).mean()
    return 100 - (100 / (1 + rs))


def _ta_adx(data, window):
    """ADX over *window* rows, using rolling sums in place of Wilder smoothing.

    True range is max(high - low, |high - prev close|, |low - prev close|).
    Directional movement is zero unless that move is both the larger of the
    up/down moves and positive.
    """
    high, low, close = data["High"], data["Low"], data["Close"]
    prev_close = close.shift(1)
    true_range = pd.concat(
        [high - low, (high - prev_close).abs(), (low - prev_close).abs()], axis=1
    ).max(axis=1)
    up_move = high - high.shift(1)
    down_move = low.shift(1) - low
    plus_dm = up_move.where((up_move > down_move) & (up_move > 0), 0.0)
    minus_dm = down_move.where((down_move > up_move) & (down_move > 0), 0.0)
    tr_sum = true_range.rolling(window=window).sum()
    plus_di = 100 * plus_dm.rolling(window=window).sum() / tr_sum
    minus_di = 100 * minus_dm.rolling(window=window).sum() / tr_sum
    dx = 100 * (plus_di - minus_di).abs() / (plus_di + minus_di)
    return dx.rolling(window=window).mean()


def plot_technical_analysis(ticker, period, analysis_type, ma_length, candle_period, adx_period):
    """Plot the selected technical indicator for *ticker* over *period*.

    Args:
        ticker: Yahoo Finance symbol.
        period: Key of ``PERIODS``.
        analysis_type: One of "Candlestick", "Moving Average",
            "Bollinger Bands", "RSI", "MACD", "ADX". Any other value yields
            an empty figure.
        ma_length: Window (rows) for "Moving Average"; whole-number floats
            such as 20.0 are accepted.
        candle_period: Currently unused; accepted to match the UI wiring.
        adx_period: Window (rows) for "ADX"; whole-number floats accepted.

    Returns:
        A Plotly figure.

    Raises:
        gr.Error: when the data cannot be fetched or a window is invalid.
            Raised rather than returned because a gr.Plot output cannot
            display a plain string.
    """
    data, error = fetch_data(ticker, period)
    if error:
        raise gr.Error(error)

    close = data["Close"]
    fig = go.Figure()
    y_title = "Price"

    if analysis_type == "Candlestick":
        fig.add_trace(go.Candlestick(
            x=data.index, open=data["Open"], high=data["High"], low=data["Low"],
            close=close, name="Candlestick",
        ))
    elif analysis_type == "Moving Average":
        window = _ta_window(ma_length)
        if window is None:
            raise gr.Error("Moving average length must be a positive integer.")
        fig.add_trace(go.Scatter(x=data.index, y=close, mode="lines", name="Close Price"))
        fig.add_trace(go.Scatter(
            x=data.index, y=close.rolling(window=window).mean(),
            mode="lines", name=f"{window}-day MA",
        ))
    elif analysis_type == "Bollinger Bands":
        middle = close.rolling(window=20).mean()
        spread = 2 * close.rolling(window=20).std()
        fig.add_trace(go.Scatter(x=data.index, y=close, mode="lines", name="Close Price"))
        fig.add_trace(go.Scatter(x=data.index, y=middle + spread, mode="lines", name="Upper Band"))
        fig.add_trace(go.Scatter(x=data.index, y=middle - spread, mode="lines", name="Lower Band"))
    elif analysis_type == "RSI":
        y_title = "RSI"
        fig.add_trace(go.Scatter(x=data.index, y=_ta_rsi(close), mode="lines", name="RSI"))
    elif analysis_type == "MACD":
        y_title = "MACD"
        macd = close.ewm(span=12, adjust=False).mean() - close.ewm(span=26, adjust=False).mean()
        signal = macd.ewm(span=9, adjust=False).mean()
        fig.add_trace(go.Scatter(x=data.index, y=macd, mode="lines", name="MACD"))
        fig.add_trace(go.Scatter(x=data.index, y=signal, mode="lines", name="Signal Line"))
    elif analysis_type == "ADX":
        window = _ta_window(adx_period)
        if window is None:
            raise gr.Error("ADX period must be a positive integer.")
        y_title = "ADX"
        fig.add_trace(go.Scatter(x=data.index, y=_ta_adx(data, window), mode="lines", name="ADX"))

    fig.update_layout(
        title=f"{analysis_type} Analysis for {ticker}",
        xaxis_title="Date",
        yaxis_title=y_title,
        xaxis_rangeslider_visible=False,
    )
    return fig
def plot_fundamental_analysis(ticker, analysis_type):
    """Fetch one fundamental-data view for *ticker*.

    Returns a ``(figure, html)`` pair matching the ``[plot_output,
    html_output]`` components:

    * table views ("Financials", "Balance Sheet", "Cash Flow") give
      ``(None, <html table>)``;
    * "Dividends" gives ``(<plotly figure>, None)``;
    * an empty ticker, an unsupported view, or any exception while reading
      the data gives ``(None, <message>)``.
    """
    if not ticker.strip():
        return None, "Invalid input: Ticker symbol cannot be empty."
    stock = yf.Ticker(ticker)
    # (UI label, yfinance Ticker attribute) for the views rendered as tables.
    table_views = (
        ("Financials", "financials"),
        ("Balance Sheet", "balance_sheet"),
        ("Cash Flow", "cashflow"),
    )
    try:
        for label, attribute in table_views:
            if analysis_type == label:
                return None, getattr(stock, attribute).to_html()
        if analysis_type != "Dividends":
            return None, "Analysis type not supported in this example."

        dividends = stock.dividends
        chart = go.Figure()
        chart.add_trace(
            go.Scatter(x=dividends.index, y=dividends, mode="lines+markers", name="Dividends")
        )
        chart.update_layout(
            title=f"Dividends for {ticker}",
            xaxis_title="Date",
            yaxis_title="Amount",
            xaxis_rangeslider_visible=False,
            width=1920,
            # Horizontal legend anchored just above the plot's right edge.
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        )
        return chart, None
    except Exception as e:
        return None, str(e)
def train_lstm_model(ticker, period, epochs, batch_size, future_days, layers, units):
    """Train a stacked-LSTM one-step price model; plot backtest and forecast.

    Closing prices are min-max scaled, with the scaler fitted on the training
    rows only. They are then cut into windows of ``time_step`` consecutive
    prices, and the model learns to predict the price immediately following
    each window.

    The last 20% of windows are held out and plotted as "Actual" vs
    "Predicted". The forecast then feeds each prediction back in as the
    newest input for *future_days* steps.

    Args:
        ticker: Yahoo Finance symbol.
        period: Key of ``PERIODS``.
        epochs, batch_size, future_days, layers, units: Positive whole
            numbers; floats such as ``10.0`` from ``gr.Number`` are accepted.
            ``layers`` is the total number of stacked LSTM layers.

    Returns:
        A Plotly figure.

    Raises:
        gr.Error: on invalid input, download failure, or too little history.
            Raised rather than returned because a gr.Plot output cannot
            display text.
    """
    def as_count(value):
        # gr.Number yields floats; Keras, range() and pandas need ints.
        # Non-numeric or non-whole values map to 0 so the checks below fail.
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0
        return int(number) if number.is_integer() else 0

    if not isinstance(ticker, str) or not ticker.strip() or period not in PERIODS:
        raise gr.Error("Invalid input: Check the ticker symbol and period.")
    epochs, batch_size, future_days, layers, units = (
        as_count(v) for v in (epochs, batch_size, future_days, layers, units)
    )
    if min(epochs, batch_size, future_days) < 1:
        raise gr.Error("Epochs, batch size, and future days must be positive integers.")
    if min(layers, units) < 1:
        raise gr.Error("Number of layers and units must be positive integers.")

    data, error = fetch_data(ticker, period)
    if error:
        raise gr.Error(error)

    time_step = 10  # input window length, in rows of price history
    # Forward-fill gaps, then drop leading NaNs that ffill cannot fill.
    close = data["Close"].ffill().dropna()
    n_samples = len(close) - time_step
    if n_samples < 2:  # need at least one training and one test window
        raise gr.Error(
            f"Not enough price history to train: need at least {time_step + 2} rows, "
            f"got {len(close)}."
        )
    split = int(0.8 * n_samples)

    raw = close.to_numpy(dtype=float).reshape(-1, 1)
    # Fit the scaler only on rows covered by training windows/targets so the
    # held-out period does not leak into the normalization.
    scaler = MinMaxScaler().fit(raw[: time_step + split])
    scaled = scaler.transform(raw)

    # Window i covers rows [i - time_step, i) and its target is row i: a
    # one-step-ahead model, matching how the recursive forecast uses it.
    X = np.stack([scaled[i - time_step:i] for i in range(time_step, len(scaled))])
    y = scaled[time_step:]
    X_train, X_test = X[:split], X[split:]
    y_train, y_test = y[:split], y[split:]

    model = Sequential()
    model.add(tf.keras.Input(shape=(time_step, 1)))
    for depth in range(layers):
        # All LSTMs but the last return sequences so the next one gets 3-D input.
        model.add(LSTM(units=units, return_sequences=depth < layers - 1))
    model.add(Dense(units=25))
    model.add(Dense(units=1))
    model.compile(optimizer="adam", loss="mean_squared_error")
    model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size)

    predictions = scaler.inverse_transform(model.predict(X_test, verbose=0))
    actual = scaler.inverse_transform(y_test)

    # Recursive forecast: slide the window forward by appending each prediction.
    window = scaled[-time_step:].reshape(1, time_step, 1)
    future_scaled = []
    for _ in range(future_days):
        next_value = model.predict(window, verbose=0)  # shape (1, 1)
        future_scaled.append(next_value[0, 0])
        window = np.concatenate([window[:, 1:, :], next_value.reshape(1, 1, 1)], axis=1)
    future_predictions = scaler.inverse_transform(np.array(future_scaled).reshape(-1, 1))

    test_index = close.index[time_step + split:]  # dates of the held-out targets
    future_dates = pd.date_range(start=close.index[-1], periods=future_days + 1, inclusive="right")

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=close.index, y=raw.ravel(), mode="lines", name="Historical Data"))
    fig.add_trace(go.Scatter(x=test_index, y=actual.ravel(), mode="lines", name="Actual"))
    fig.add_trace(go.Scatter(x=test_index, y=predictions.ravel(), mode="lines", name="Predicted"))
    fig.add_trace(go.Scatter(x=future_dates, y=future_predictions.ravel(), mode="lines", name="Future Predictions"))
    fig.update_layout(
        title=f"Predicted vs Actual and Future Forecast for {ticker}",
        xaxis_title="Date",
        yaxis_title="Price",
        xaxis_rangeslider_visible=False,
        width=1920,
    )
    return fig
def train_linear_regression_model(ticker, period, future_days):
    """Fit a straight-line trend of closing price against time and extend it.

    The regressor is the row position (0, 1, 2, ...) of each closing price.
    The forecast is the fitted line continued for *future_days* more rows,
    drawn on the calendar dates following the last observation.

    Args:
        ticker: Yahoo Finance symbol.
        period: Key of ``PERIODS``.
        future_days: Number of points to forecast; whole-number floats such
            as ``30.0`` from ``gr.Number`` are accepted.

    Returns:
        A Plotly figure with the actual, fitted and forecast series.

    Raises:
        gr.Error: on invalid input or when no usable prices were downloaded.
            Raised rather than returned because a gr.Plot output cannot
            display text.
    """
    def as_count(value):
        # gr.Number yields floats; range() and pandas need ints. Non-numeric
        # or non-whole values map to 0 so the check below rejects them.
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0
        return int(number) if number.is_integer() else 0

    if not isinstance(ticker, str) or not ticker.strip() or period not in PERIODS:
        raise gr.Error("Invalid input: Check the ticker symbol and period.")
    future_days = as_count(future_days)
    if future_days < 1:
        raise gr.Error("Future days must be a positive integer.")

    data, error = fetch_data(ticker, period)
    if error:
        raise gr.Error(error)

    # Forward-fill gaps, then drop leading NaNs (LinearRegression rejects NaN).
    close = data["Close"].ffill().dropna()
    if close.empty:
        raise gr.Error("No closing prices available for the given ticker and period.")
    prices = close.to_numpy(dtype=float).reshape(-1, 1)
    n_rows = len(prices)

    X = np.arange(n_rows).reshape(-1, 1)
    model = LinearRegression().fit(X, prices)
    fitted = model.predict(X)
    forecast = model.predict(np.arange(n_rows, n_rows + future_days).reshape(-1, 1))
    future_dates = pd.date_range(start=close.index[-1], periods=future_days + 1, inclusive="right")

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=close.index, y=prices.ravel(), mode="lines", name="Actual"))
    fig.add_trace(go.Scatter(x=close.index, y=fitted.ravel(), mode="lines", name="Predicted"))
    fig.add_trace(go.Scatter(x=future_dates, y=forecast.ravel(), mode="lines", name="Future Predictions"))
    fig.update_layout(
        title=f"Linear Regression Historical and Future Forecast for {ticker}",
        xaxis_title="Date",
        yaxis_title="Price",
        xaxis_rangeslider_visible=False,
    )
    return fig
def train_random_forest_model(ticker, period, future_days, n_estimators):
    """Fit a random forest of closing price against row position and extend it.

    The only feature is the row index (0, 1, 2, ...). Every future index is
    larger than any index seen in training, so each tree routes all of them
    to the same leaf and the forecast is a flat line. The model is mainly
    useful here as an in-sample fit, not as a trend extrapolator.

    Args:
        ticker: Yahoo Finance symbol.
        period: Key of ``PERIODS``.
        future_days: Number of points to forecast.
        n_estimators: Number of trees in the forest.
        Both counts accept whole-number floats such as ``30.0`` from
        ``gr.Number``.

    Returns:
        A Plotly figure with the actual, fitted and forecast series.

    Raises:
        gr.Error: on invalid input or when no usable prices were downloaded.
            Raised rather than returned because a gr.Plot output cannot
            display text.
    """
    def as_count(value):
        # gr.Number yields floats; range(), pandas and sklearn's n_estimators
        # need ints. Non-numeric or non-whole values map to 0 (rejected below).
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0
        return int(number) if number.is_integer() else 0

    if not isinstance(ticker, str) or not ticker.strip() or period not in PERIODS:
        raise gr.Error("Invalid input: Check the ticker symbol and period.")
    future_days, n_estimators = as_count(future_days), as_count(n_estimators)
    if future_days < 1 or n_estimators < 1:
        raise gr.Error("Future days and number of estimators must be positive integers.")

    data, error = fetch_data(ticker, period)
    if error:
        raise gr.Error(error)

    # Forward-fill gaps, then drop leading NaNs (sklearn rejects NaN targets).
    close = data["Close"].ffill().dropna()
    if close.empty:
        raise gr.Error("No closing prices available for the given ticker and period.")
    prices = close.to_numpy(dtype=float).reshape(-1, 1)
    n_rows = len(prices)

    X = np.arange(n_rows).reshape(-1, 1)
    model = RandomForestRegressor(n_estimators=n_estimators)
    model.fit(X, prices.ravel())
    fitted = model.predict(X)
    forecast = model.predict(np.arange(n_rows, n_rows + future_days).reshape(-1, 1))
    future_dates = pd.date_range(start=close.index[-1], periods=future_days + 1, inclusive="right")

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=close.index, y=prices.ravel(), mode="lines", name="Actual"))
    fig.add_trace(go.Scatter(x=close.index, y=fitted.ravel(), mode="lines", name="Predicted"))
    fig.add_trace(go.Scatter(x=future_dates, y=forecast.ravel(), mode="lines", name="Future Predictions"))
    fig.update_layout(
        title=f"Random Forest Historical and Future Forecast for {ticker}",
        xaxis_title="Date",
        yaxis_title="Price",
        xaxis_rangeslider_visible=False,
    )
    return fig
# Gradio Interface
# Each tab uses its own component variable names so every event handler is
# visibly wired to the components of its own tab.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Analysis Tool")

    with gr.Tab("Summary"):
        summary_ticker = gr.Textbox(label="Enter Yahoo Finance Asset Ticker", value="AAPL")
        summary_button = gr.Button("Generate Summary")
        summary_output = gr.Markdown()
        summary_button.click(fn=summarize_asset, inputs=summary_ticker, outputs=summary_output)

    with gr.Tab("Technical Analysis"):
        with gr.Column():
            ta_ticker = gr.Textbox(label="Enter Yahoo Finance Asset Ticker", value="AAPL")
            ta_period = gr.Dropdown(label="Period", choices=list(PERIODS.keys()), value="1 Year")
            ta_type = gr.Radio(
                label="Analysis Type",
                choices=["Candlestick", "Moving Average", "Bollinger Bands", "RSI", "MACD", "ADX"],
                value="Candlestick",
            )
            # precision=0 makes gr.Number deliver ints; pandas rolling windows,
            # range() and Keras all reject floats such as 20.0.
            ma_length = gr.Number(label="Moving Average Length", value=20, precision=0, visible=False)
            candle_period = gr.Number(label="Candlestick Period", value=1, precision=0, visible=True)
            adx_period = gr.Number(label="ADX Period", value=14, precision=0, visible=False)

            def update_visibility(analysis_type):
                """Show only the numeric input relevant to the selected indicator."""
                return (
                    gr.update(visible=analysis_type == "Moving Average"),
                    gr.update(visible=analysis_type == "Candlestick"),
                    gr.update(visible=analysis_type == "ADX"),
                )

            ta_type.change(
                fn=update_visibility,
                inputs=ta_type,
                outputs=[ma_length, candle_period, adx_period],
            )
            ta_button = gr.Button("Plot Technical Analysis")
            ta_output = gr.Plot()
            ta_button.click(
                fn=plot_technical_analysis,
                inputs=[ta_ticker, ta_period, ta_type, ma_length, candle_period, adx_period],
                outputs=ta_output,
            )

    with gr.Tab("Fundamental Analysis"):
        fa_ticker = gr.Textbox(label="Enter Yahoo Finance Asset Ticker", value="AAPL")
        fa_type = gr.Radio(
            label="Analysis Type",
            choices=["Financials", "Balance Sheet", "Cash Flow", "Dividends"],
            value="Financials",
        )
        fa_button = gr.Button("Show Fundamental Data")
        # Table views render as HTML; only "Dividends" uses the plot.
        fa_plot = gr.Plot(visible=False)
        fa_html = gr.HTML(visible=True)
        fa_type.change(fn=update_output_type, inputs=fa_type, outputs=[fa_plot, fa_html])
        fa_button.click(fn=plot_fundamental_analysis, inputs=[fa_ticker, fa_type], outputs=[fa_plot, fa_html])

    with gr.Tab("Predictive Model"):
        pm_ticker = gr.Textbox(label="Enter Yahoo Finance Asset Ticker", value="AAPL")
        pm_period = gr.Dropdown(label="Period", choices=list(PERIODS.keys()), value="1 Year")

        epochs = gr.Number(label="Epochs", value=10, precision=0)
        batch_size = gr.Number(label="Batch Size", value=32, precision=0)
        future_days = gr.Number(label="Days to Predict", value=30, precision=0)
        layers = gr.Number(label="Number of LSTM Layers", value=2, precision=0)
        units = gr.Number(label="Number of Units per Layer", value=50, precision=0)
        train_lstm_button = gr.Button("Train LSTM Model")
        lstm_output = gr.Plot()
        train_lstm_button.click(
            fn=train_lstm_model,
            inputs=[pm_ticker, pm_period, epochs, batch_size, future_days, layers, units],
            outputs=lstm_output,
        )

        future_days_lr = gr.Number(label="Days to Predict (Linear Regression)", value=30, precision=0)
        train_lr_button = gr.Button("Train Linear Regression Model")
        lr_output = gr.Plot()
        train_lr_button.click(
            fn=train_linear_regression_model,
            inputs=[pm_ticker, pm_period, future_days_lr],
            outputs=lr_output,
        )

        future_days_rf = gr.Number(label="Days to Predict (Random Forest)", value=30, precision=0)
        n_estimators = gr.Number(label="Number of Estimators", value=100, precision=0)
        train_rf_button = gr.Button("Train Random Forest Model")
        rf_output = gr.Plot()
        train_rf_button.click(
            fn=train_random_forest_model,
            inputs=[pm_ticker, pm_period, future_days_rf, n_estimators],
            outputs=rf_output,
        )

# Launch the interface only when run as a script (e.g. `python app.py`), not
# when the module is imported.
if __name__ == "__main__":
    demo.launch(debug=True)