# Page metadata captured from the Hugging Face Space file viewer:
# dibend's picture
# Update app.py
# 129ffde verified
# raw
# history blame contribute delete
# No virus
# 18.2 kB
import gradio as gr
import plotly.graph_objects as go
import yfinance as yf
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
# Period choices offered in the UI: display label -> yfinance ``period`` code.
# Insertion order matters: the dropdowns list the labels in this order.
PERIODS = dict(
    [
        ("1 Month", "1mo"),
        ("3 Months", "3mo"),
        ("6 Months", "6mo"),
        ("1 Year", "1y"),
        ("5 Years", "5y"),
        ("10 Years", "10y"),
        ("Max", "max"),
    ]
)
def fetch_data(ticker, period):
    """Download price history for *ticker* over a named UI period.

    Args:
        ticker: Yahoo Finance symbol; surrounding whitespace is ignored.
        period: A key of ``PERIODS`` such as ``"1 Year"``.

    Returns:
        ``(data, None)`` on success, where ``data`` is the DataFrame returned by
        ``yf.download`` with single-level columns (``Open``, ``High``, ``Low``,
        ``Close``, ...); otherwise ``(None, message)`` describing the problem.
    """
    # Checking ``not ticker`` first also guards against ``None`` (no .strip()).
    if not ticker or not str(ticker).strip() or period not in PERIODS:
        return None, "Invalid input: Check the ticker symbol and period."
    ticker = str(ticker).strip()
    try:
        data = yf.download(ticker, period=PERIODS[period], progress=False)
    except Exception as e:  # network/parsing failures are reported, not raised
        return None, str(e)
    if data is None or data.empty:
        return None, "No data found for the given ticker and period."
    # Recent yfinance versions return (field, ticker) MultiIndex columns even for
    # a single symbol, which would make data['Close'] a DataFrame downstream.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    return data, None
def update_output_type(analysis_type):
    """Toggle the Fundamental Analysis outputs for the selected analysis type.

    Returns a ``(plot_update, html_update)`` pair: the plot is visible only for
    "Dividends"; every other selection shows the HTML output instead.
    """
    plot_visible = analysis_type == "Dividends"
    return gr.update(visible=plot_visible), gr.update(visible=not plot_visible)
def fetch_news(ticker, max_articles=5):
    """Return a Markdown digest of recent Yahoo Finance news for *ticker*.

    This is best-effort: it never raises, and any exception is returned as text.

    Args:
        ticker: Yahoo Finance symbol; surrounding whitespace is ignored.
        max_articles: Maximum number of articles to include (default 5).

    Returns:
        Markdown with one section per article, a "no news" message, an
        input-validation message, or the text of an exception.
    """
    from datetime import datetime, timezone

    ticker = str(ticker).strip() if ticker is not None else ""
    if not ticker:
        return "Invalid input: Ticker symbol cannot be empty."
    try:
        articles = (yf.Ticker(ticker).news or [])[:max_articles]
        if not articles:
            return f"No news found for {ticker}."
        sections = []
        for article in articles:
            published = article.get("providerPublishTime")
            if isinstance(published, (int, float)) and not isinstance(published, bool):
                # NOTE(review): assumed to be Unix epoch seconds — confirm.
                published = datetime.fromtimestamp(published, tz=timezone.utc).strftime(
                    "%Y-%m-%d %H:%M UTC"
                )
            # .get() keeps one article with missing fields from aborting the digest;
            # the trailing two spaces force Markdown line breaks.
            sections.append(
                f"### {article.get('title', 'Untitled')}\n"
                f"**Source:** {article.get('publisher', 'N/A')}  \n"
                f"**Published At:** {published or 'N/A'}  \n"
                f"{article.get('summary', '')}  \n"
                f"[Read more]({article.get('link', '')})  \n\n"
            )
        return "".join(sections)
    except Exception as e:
        return str(e)
def summarize_asset(ticker):
    """Build a Markdown summary of key quote statistics for *ticker*.

    Fields absent from (or ``None`` in) Yahoo's ``info`` payload are omitted.
    ETFs and mutual funds additionally list total assets, NAV and expense ratio.

    Returns:
        Markdown text, an input-validation message, or the text of any exception
        raised while fetching or formatting.
    """
    def _num(value, spec, scale=1):
        # Format numeric values; anything unexpected falls back to str() so one odd
        # field cannot abort the whole summary (or get string-multiplied by scale).
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return str(value)
        return format(value * scale, spec)

    common_fields = [
        ("currentPrice", "Current Price", lambda v: f"${_num(v, ',.2f')}"),
        ("marketCap", "Market Cap", lambda v: f"${_num(v, ',.0f')}"),
        ("trailingPE", "PE Ratio (TTM)", lambda v: _num(v, ".2f")),
        ("trailingEps", "EPS (TTM)", lambda v: _num(v, ".2f")),
        # NOTE(review): keeps the original *100 scaling; confirm the unit Yahoo
        # currently reports for dividendYield (fraction vs. percent).
        ("dividendYield", "Dividend Yield", lambda v: f"{_num(v, '.2f', 100)}%"),
        ("fiftyTwoWeekHigh", "52 Week High", lambda v: f"${_num(v, ',.2f')}"),
        ("fiftyTwoWeekLow", "52 Week Low", lambda v: f"${_num(v, ',.2f')}"),
        ("volume", "Volume", lambda v: _num(v, ",.0f")),
    ]
    fund_fields = [
        ("totalAssets", "Total Assets", lambda v: f"${_num(v, ',.0f')}"),
        ("navPrice", "NAV", lambda v: f"${_num(v, ',.2f')}"),
        ("annualReportExpenseRatio", "Expense Ratio", lambda v: f"{_num(v, '.2f', 100)}%"),
    ]

    ticker = str(ticker).strip() if ticker is not None else ""
    if not ticker:
        return "Invalid input: Ticker symbol cannot be empty."
    try:
        info = yf.Ticker(ticker).info or {}
        fields = list(common_fields)
        if info.get("quoteType") in ("ETF", "MUTUALFUND"):
            fields += fund_fields
        lines = [f"**Summary for {info.get('shortName') or 'N/A'} ({ticker}):**\n\n"]
        for key, label, render in fields:
            value = info.get(key)
            # `is not None` (not truthiness) so legitimate zero values are shown.
            if value is not None:
                lines.append(f"- **{label}:** {render(value)}\n")
        return "".join(lines)
    except Exception as e:
        return str(e)
def _parse_window(value, label):
    """Convert a UI number (often a float from gr.Number) to a positive int window."""
    try:
        window = int(value)
    except (TypeError, ValueError, OverflowError):
        window = 0
    if window < 1:
        raise gr.Error(f"{label} must be a positive integer.")
    return window


def _rsi(close, window=14):
    """Return RSI of *close* using simple rolling means of gains and losses."""
    delta = close.diff(1)
    gain = delta.where(delta > 0, 0)
    loss = -delta.where(delta < 0, 0)
    rs = gain.rolling(window=window).mean() / loss.rolling(window=window).mean()
    return 100 - (100 / (1 + rs))


def _adx(data, window):
    """Return ADX for OHLC *data*, using simple rolling sums (not Wilder smoothing)."""
    prev_close = data["Close"].shift(1)
    # True range is the largest of the bar range and the gaps to the prior close.
    true_range = pd.concat(
        [
            data["High"] - data["Low"],
            (data["High"] - prev_close).abs(),
            (data["Low"] - prev_close).abs(),
        ],
        axis=1,
    ).max(axis=1)
    up_move = data["High"].diff()
    down_move = -data["Low"].diff()
    # A directional move only counts when it dominates the other AND is positive.
    plus_dm = up_move.where((up_move > down_move) & (up_move > 0), 0.0)
    minus_dm = down_move.where((down_move > up_move) & (down_move > 0), 0.0)
    tr_sum = true_range.rolling(window=window).sum()
    plus_di = 100 * plus_dm.rolling(window=window).sum() / tr_sum
    minus_di = 100 * minus_dm.rolling(window=window).sum() / tr_sum
    dx = 100 * (plus_di - minus_di).abs() / (plus_di + minus_di)
    return dx.rolling(window=window).mean()


def plot_technical_analysis(ticker, period, analysis_type, ma_length, candle_period, adx_period):
    """Plot a technical-analysis chart for *ticker* over *period*.

    Args:
        ticker: Yahoo Finance symbol.
        period: A key of ``PERIODS``.
        analysis_type: "Candlestick", "Moving Average", "Bollinger Bands",
            "RSI", "MACD" or "ADX".
        ma_length: Moving-average window in rows (converted to ``int``).
        candle_period: Currently unused; kept for the UI's input signature.
        adx_period: ADX look-back window in rows (converted to ``int``).

    Returns:
        A Plotly figure.

    Raises:
        gr.Error: on invalid input or download failure, so Gradio shows the
            message (a ``gr.Plot`` output cannot render a returned string).
    """
    data, error = fetch_data(ticker, period)
    if error:
        raise gr.Error(error)
    close = data["Close"]
    fig = go.Figure()
    y_title = "Price"
    if analysis_type == "Candlestick":
        fig.add_trace(go.Candlestick(x=data.index, open=data["Open"], high=data["High"], low=data["Low"], close=close, name="Candlestick"))
    elif analysis_type == "Moving Average":
        window = _parse_window(ma_length, "Moving average length")
        fig.add_trace(go.Scatter(x=data.index, y=close, mode="lines", name="Close Price"))
        fig.add_trace(go.Scatter(x=data.index, y=close.rolling(window=window).mean(), mode="lines", name=f"{window}-day MA"))
    elif analysis_type == "Bollinger Bands":
        middle = close.rolling(window=20).mean()
        spread = 2 * close.rolling(window=20).std()
        fig.add_trace(go.Scatter(x=data.index, y=close, mode="lines", name="Close Price"))
        fig.add_trace(go.Scatter(x=data.index, y=middle + spread, mode="lines", name="Upper Band"))
        fig.add_trace(go.Scatter(x=data.index, y=middle, mode="lines", name="Middle Band (20-day MA)"))
        fig.add_trace(go.Scatter(x=data.index, y=middle - spread, mode="lines", name="Lower Band"))
    elif analysis_type == "RSI":
        fig.add_trace(go.Scatter(x=data.index, y=_rsi(close), mode="lines", name="RSI"))
        y_title = "RSI"
    elif analysis_type == "MACD":
        macd = close.ewm(span=12, adjust=False).mean() - close.ewm(span=26, adjust=False).mean()
        signal = macd.ewm(span=9, adjust=False).mean()
        fig.add_trace(go.Scatter(x=data.index, y=macd, mode="lines", name="MACD"))
        fig.add_trace(go.Scatter(x=data.index, y=signal, mode="lines", name="Signal Line"))
        y_title = "MACD"
    elif analysis_type == "ADX":
        window = _parse_window(adx_period, "ADX period")
        fig.add_trace(go.Scatter(x=data.index, y=_adx(data, window), mode="lines", name="ADX"))
        y_title = "ADX"
    else:
        raise gr.Error(f"Unsupported analysis type: {analysis_type}")
    fig.update_layout(title=f"{analysis_type} Analysis for {ticker}", xaxis_title="Date", yaxis_title=y_title, xaxis_rangeslider_visible=False)
    return fig
def plot_fundamental_analysis(ticker, analysis_type):
    """Show a financial statement as HTML, or plot dividend history.

    Both outputs' visibility is set explicitly on every call. Otherwise, with
    "Dividends" selected, the HTML output stays hidden and any error message
    written to it would never be seen.

    Args:
        ticker: Yahoo Finance symbol; surrounding whitespace is ignored.
        analysis_type: "Financials", "Balance Sheet", "Cash Flow" or "Dividends".

    Returns:
        ``(plot_update, html_update)`` Gradio updates for the plot and HTML outputs.
    """
    def _show_html(html):
        return gr.update(value=None, visible=False), gr.update(value=html, visible=True)

    def _show_plot(fig):
        return gr.update(value=fig, visible=True), gr.update(value=None, visible=False)

    # UI label -> yfinance Ticker attribute holding the statement DataFrame.
    statements = {
        "Financials": "financials",
        "Balance Sheet": "balance_sheet",
        "Cash Flow": "cashflow",
    }
    ticker = str(ticker).strip() if ticker is not None else ""
    if not ticker:
        return _show_html("Invalid input: Ticker symbol cannot be empty.")
    try:
        stock = yf.Ticker(ticker)
        if analysis_type in statements:
            data = getattr(stock, statements[analysis_type])
            if data is None or data.empty:
                return _show_html(f"No {analysis_type.lower()} data found for {ticker}.")
            return _show_html(data.to_html())
        if analysis_type == "Dividends":
            dividends = stock.dividends
            if dividends is None or dividends.empty:
                return _show_html(f"No dividend data found for {ticker}.")
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=dividends.index, y=dividends, mode="lines+markers", name="Dividends"))
            fig.update_layout(
                title=f"Dividends for {ticker}",
                xaxis_title="Date",
                yaxis_title="Amount",
                xaxis_rangeslider_visible=False,
                width=1920,
                legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
            )
            return _show_plot(fig)
        return _show_html("Analysis type not supported in this example.")
    except Exception as e:
        return _show_html(str(e))
def _lstm_windows(series, time_step):
    """Split an (n, 1) array into (n - time_step, time_step, 1) windows and next-value targets."""
    inputs = np.stack([series[i - time_step:i] for i in range(time_step, len(series))])
    return inputs, series[time_step:]


def _build_lstm(time_step, layers, units):
    """Return a compiled model: *layers* stacked LSTMs followed by a small dense head."""
    model = Sequential()
    model.add(tf.keras.Input(shape=(time_step, 1)))
    for index in range(layers):
        # Every LSTM except the last must return full sequences to feed the next one.
        model.add(LSTM(units=units, return_sequences=index < layers - 1))
    model.add(Dense(units=25))
    model.add(Dense(units=1))
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


def train_lstm_model(ticker, period, epochs, batch_size, future_days, layers, units, time_step=10):
    """Train an LSTM next-value model, then plot held-out predictions and a forecast.

    The model predicts the next closing price from the previous ``time_step``
    closes. The last 20% of windows are held out for "Predicted vs Actual", and
    ``future_days`` trading days are forecast recursively by feeding each
    prediction back into the input window.

    Args:
        ticker: Yahoo Finance symbol.
        period: A key of ``PERIODS``.
        epochs, batch_size, future_days, layers, units: Positive whole numbers.
            ``gr.Number`` delivers floats, so they are converted to ``int``.
            ``layers`` is the total number of LSTM layers.
        time_step: Input window length in rows (default 10).

    Returns:
        A Plotly figure (the title includes the held-out RMSE).

    Raises:
        gr.Error: on invalid input, download failure, or too little history.
    """
    if not ticker or not str(ticker).strip() or period not in PERIODS:
        raise gr.Error("Invalid input: Check the ticker symbol and period.")
    ticker = str(ticker).strip()
    try:
        epochs, batch_size, future_days, layers, units, time_step = (
            int(value) for value in (epochs, batch_size, future_days, layers, units, time_step)
        )
    except (TypeError, ValueError, OverflowError):
        raise gr.Error("Epochs, batch size, future days, layers and units must be whole numbers.") from None
    if epochs < 1 or batch_size < 1 or future_days < 1:
        raise gr.Error("Epochs, batch size, and future days must be positive integers.")
    if layers < 1 or units < 1:
        raise gr.Error("Number of layers and units must be positive integers.")
    if time_step < 1:
        raise gr.Error("Time step must be a positive integer.")

    data, error = fetch_data(ticker, period)
    if error:
        raise gr.Error(error)
    close = data["Close"].ffill().dropna()
    prices = close.to_numpy(dtype=float).reshape(-1, 1)

    n_samples = len(prices) - time_step
    split = int(0.8 * n_samples)
    if split < 1:
        raise gr.Error("Not enough price history for this period; choose a longer period.")

    # Fit the scaler only on rows used by training windows/targets so the
    # held-out period does not leak into the scaling.
    scaler = MinMaxScaler()
    scaler.fit(prices[: split + time_step])
    scaled = scaler.transform(prices)
    X, y = _lstm_windows(scaled, time_step)

    model = _build_lstm(time_step, layers, units)
    model.fit(X[:split], y[:split], epochs=epochs, batch_size=batch_size)

    # Held-out sample j targets row j + time_step.
    test_dates = close.index[time_step + split:]
    actual = prices[time_step + split:]
    predictions = scaler.inverse_transform(model.predict(X[split:], verbose=0))
    rmse = float(np.sqrt(mean_squared_error(actual, predictions)))

    window = scaled[-time_step:].ravel()
    future_scaled = []
    for _ in range(future_days):
        next_value = float(model.predict(window.reshape(1, time_step, 1), verbose=0)[0, 0])
        future_scaled.append(next_value)
        window = np.append(window[1:], next_value)
    future_prices = scaler.inverse_transform(np.array(future_scaled).reshape(-1, 1))
    # One model step == one trading row, so place forecasts on business days.
    future_dates = pd.bdate_range(start=close.index[-1] + pd.offsets.BDay(1), periods=future_days)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=close.index, y=prices.ravel(), mode="lines", name="Historical Data"))
    fig.add_trace(go.Scatter(x=test_dates, y=actual.ravel(), mode="lines", name="Actual"))
    fig.add_trace(go.Scatter(x=test_dates, y=predictions.ravel(), mode="lines", name="Predicted"))
    fig.add_trace(go.Scatter(x=future_dates, y=future_prices.ravel(), mode="lines", name="Future Predictions"))
    fig.update_layout(title=f"Predicted vs Actual and Future Forecast for {ticker} (test RMSE: {rmse:.2f})", xaxis_title="Date", yaxis_title="Price", xaxis_rangeslider_visible=False, width=1920)
    return fig
def train_linear_regression_model(ticker, period, future_days):
    """Fit a linear trend to closing prices and extrapolate it *future_days* ahead.

    The regressor is the row position (0, 1, 2, ...), i.e. one step per trading
    row, so forecast points are placed on the following business days.

    Returns:
        A Plotly figure with actual prices, the fitted line and the forecast.

    Raises:
        gr.Error: on invalid input or download failure (a single ``gr.Plot``
            output cannot display a returned string or tuple).
    """
    if not ticker or not str(ticker).strip() or period not in PERIODS:
        raise gr.Error("Invalid input: Check the ticker symbol and period.")
    try:
        future_days = int(future_days)  # gr.Number yields floats; range() needs int
    except (TypeError, ValueError, OverflowError):
        future_days = 0
    if future_days < 1:
        raise gr.Error("Future days must be a positive integer.")
    data, error = fetch_data(ticker, period)
    if error:
        raise gr.Error(error)
    close = data["Close"].ffill().dropna()
    if close.empty:
        raise gr.Error("No closing prices available for this period.")
    prices = close.to_numpy(dtype=float).reshape(-1, 1)
    positions = np.arange(len(prices)).reshape(-1, 1)
    model = LinearRegression().fit(positions, prices)
    future_positions = np.arange(len(prices), len(prices) + future_days).reshape(-1, 1)
    fitted = model.predict(positions)
    forecast = model.predict(future_positions)
    future_dates = pd.bdate_range(start=close.index[-1] + pd.offsets.BDay(1), periods=future_days)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=close.index, y=prices.ravel(), mode="lines", name="Actual"))
    fig.add_trace(go.Scatter(x=close.index, y=fitted.ravel(), mode="lines", name="Predicted"))
    fig.add_trace(go.Scatter(x=future_dates, y=forecast.ravel(), mode="lines", name="Future Predictions"))
    fig.update_layout(title=f"Linear Regression Historical and Future Forecast for {ticker}", xaxis_title="Date", yaxis_title="Price", xaxis_rangeslider_visible=False)
    return fig
def train_random_forest_model(ticker, period, future_days, n_estimators):
    """Fit a random forest on row position -> close and plot its fit and forecast.

    Caveat: future positions all lie beyond the largest training position, so
    every tree routes them to the same leaf and the forecast is a flat line.

    Returns:
        A Plotly figure with actual prices, in-sample predictions and the forecast.

    Raises:
        gr.Error: on invalid input or download failure (a single ``gr.Plot``
            output cannot display a returned string or tuple).
    """
    if not ticker or not str(ticker).strip() or period not in PERIODS:
        raise gr.Error("Invalid input: Check the ticker symbol and period.")
    try:
        # gr.Number yields floats; range() and scikit-learn's n_estimators need int.
        future_days, n_estimators = int(future_days), int(n_estimators)
    except (TypeError, ValueError, OverflowError):
        future_days, n_estimators = 0, 0
    if future_days < 1 or n_estimators < 1:
        raise gr.Error("Future days and number of estimators must be positive integers.")
    data, error = fetch_data(ticker, period)
    if error:
        raise gr.Error(error)
    close = data["Close"].ffill().dropna()
    if close.empty:
        raise gr.Error("No closing prices available for this period.")
    prices = close.to_numpy(dtype=float)
    positions = np.arange(len(prices)).reshape(-1, 1)
    model = RandomForestRegressor(n_estimators=n_estimators).fit(positions, prices)
    future_positions = np.arange(len(prices), len(prices) + future_days).reshape(-1, 1)
    fitted = model.predict(positions)
    forecast = model.predict(future_positions)
    future_dates = pd.bdate_range(start=close.index[-1] + pd.offsets.BDay(1), periods=future_days)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=close.index, y=prices, mode="lines", name="Actual"))
    fig.add_trace(go.Scatter(x=close.index, y=fitted, mode="lines", name="Predicted"))
    fig.add_trace(go.Scatter(x=future_dates, y=forecast, mode="lines", name="Future Predictions"))
    fig.update_layout(title=f"Random Forest Historical and Future Forecast for {ticker}", xaxis_title="Date", yaxis_title="Price", xaxis_rangeslider_visible=False)
    return fig
# Gradio Interface: one tab per feature. Integer-valued inputs use precision=0 so
# handlers receive ints; otherwise gr.Number yields floats, which pandas rolling
# windows, range(), Keras layer sizes and scikit-learn's n_estimators reject.
# Each tab gets its own component names instead of rebinding shared ones.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Analysis Tool")

    with gr.Tab("Summary"):
        summary_ticker = gr.Textbox(label="Enter Yahoo Finance Asset Ticker", value="AAPL")
        summary_button = gr.Button("Generate Summary")
        summary_output = gr.Markdown()
        summary_button.click(fn=summarize_asset, inputs=summary_ticker, outputs=summary_output)

    with gr.Tab("Technical Analysis"):
        with gr.Column():
            ta_ticker = gr.Textbox(label="Enter Yahoo Finance Asset Ticker", value="AAPL")
            ta_period = gr.Dropdown(label="Period", choices=list(PERIODS.keys()), value="1 Year")
            ta_type = gr.Radio(label="Analysis Type", choices=["Candlestick", "Moving Average", "Bollinger Bands", "RSI", "MACD", "ADX"], value="Candlestick")
            ma_length = gr.Number(label="Moving Average Length", value=20, precision=0, visible=False)
            candle_period = gr.Number(label="Candlestick Period", value=1, precision=0, visible=True)
            adx_period = gr.Number(label="ADX Period", value=14, precision=0, visible=False)

            def update_visibility(analysis_type):
                # Show only the parameter input relevant to the selected indicator.
                return (
                    gr.update(visible=analysis_type == "Moving Average"),
                    gr.update(visible=analysis_type == "Candlestick"),
                    gr.update(visible=analysis_type == "ADX"),
                )

            ta_type.change(fn=update_visibility, inputs=ta_type, outputs=[ma_length, candle_period, adx_period])
            ta_button = gr.Button("Plot Technical Analysis")
            ta_output = gr.Plot()
            ta_button.click(fn=plot_technical_analysis, inputs=[ta_ticker, ta_period, ta_type, ma_length, candle_period, adx_period], outputs=ta_output)

    with gr.Tab("Fundamental Analysis"):
        fa_ticker = gr.Textbox(label="Enter Yahoo Finance Asset Ticker", value="AAPL")
        fa_type = gr.Radio(label="Analysis Type", choices=["Financials", "Balance Sheet", "Cash Flow", "Dividends"], value="Financials")
        fa_button = gr.Button("Show Fundamental Data")
        fa_plot_output = gr.Plot(visible=False)
        fa_html_output = gr.HTML(visible=True)
        fa_type.change(fn=update_output_type, inputs=fa_type, outputs=[fa_plot_output, fa_html_output])
        fa_button.click(fn=plot_fundamental_analysis, inputs=[fa_ticker, fa_type], outputs=[fa_plot_output, fa_html_output])

    with gr.Tab("Predictive Model"):
        pm_ticker = gr.Textbox(label="Enter Yahoo Finance Asset Ticker", value="AAPL")
        pm_period = gr.Dropdown(label="Period", choices=list(PERIODS.keys()), value="1 Year")
        epochs = gr.Number(label="Epochs", value=10, precision=0)
        batch_size = gr.Number(label="Batch Size", value=32, precision=0)
        future_days = gr.Number(label="Days to Predict", value=30, precision=0)
        layers = gr.Number(label="Number of LSTM Layers", value=2, precision=0)
        units = gr.Number(label="Number of Units per Layer", value=50, precision=0)
        train_lstm_button = gr.Button("Train LSTM Model")
        lstm_output = gr.Plot()
        train_lstm_button.click(fn=train_lstm_model, inputs=[pm_ticker, pm_period, epochs, batch_size, future_days, layers, units], outputs=lstm_output)

        future_days_lr = gr.Number(label="Days to Predict (Linear Regression)", value=30, precision=0)
        train_lr_button = gr.Button("Train Linear Regression Model")
        lr_output = gr.Plot()
        train_lr_button.click(fn=train_linear_regression_model, inputs=[pm_ticker, pm_period, future_days_lr], outputs=lr_output)

        future_days_rf = gr.Number(label="Days to Predict (Random Forest)", value=30, precision=0)
        n_estimators = gr.Number(label="Number of Estimators", value=100, precision=0)
        train_rf_button = gr.Button("Train Random Forest Model")
        rf_output = gr.Plot()
        train_rf_button.click(fn=train_random_forest_model, inputs=[pm_ticker, pm_period, future_days_rf, n_estimators], outputs=rf_output)

# Launch the interface
demo.launch(debug=True)