# app.py — header residue from the hosting page, kept as a comment:
# author: Anupam251272; commit "Create app.py" (a77ea34, verified); size 10.4 kB
import numpy as np
import pandas as pd
import yfinance as yf
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from transformers import pipeline
import gradio as gr
import plotly.graph_objects as go
from datetime import datetime, timedelta
import threading
import time
# Choose the compute device once, at import time; the model and feature
# tensors created below are moved onto it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
class FinancialGNN(nn.Module):
    """Two-layer graph convolutional network over per-node feature vectors.

    ``conv1`` maps ``input_dim`` features to ``hidden_dim``; after a ReLU,
    ``conv2`` maps those to ``output_dim`` values per node.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Apply both graph convolutions to node features ``x``.

        Args:
            x: node-feature tensor; must be 2-D, otherwise ValueError is raised.
            edge_index: graph connectivity passed unchanged to both ``GCNConv`` layers.

        Returns:
            The output of ``conv2``.
        """
        # Reject a wrongly shaped input up front with a readable message.
        rank = x.dim()
        if rank != 2:
            raise ValueError(f"Expected 2D input tensor, got {rank}D")
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
class MarketAnalysisSystem:
    """Collects recent prices, scores sentiment and runs a GNN over a symbol graph.

    Attributes:
        sentiment_analyzer: Hugging Face ``sentiment-analysis`` pipeline, or
            ``None`` when it failed to load (sentiment then falls back to 0).
        gnn_model: ``FinancialGNN`` taking 5 features per symbol and emitting 1
            value per symbol. NOTE(review): nothing in this file trains it, so
            its outputs come from the random initialisation — confirm before
            treating them as forecasts.
        symbols: default tickers used when a method is called without symbols.
        market_data, sentiment_scores: latest results written by the background
            monitoring thread.
        monitoring, monitor_thread: state of that background thread.
    """

    def __init__(self):
        try:
            # Loading the model can fail (no network, missing weights); the rest
            # of the system still works, with sentiment falling back to 0.
            self.sentiment_analyzer = pipeline(
                "sentiment-analysis",
                model="distilbert-base-uncased-finetuned-sst-2-english",
                device=0 if torch.cuda.is_available() else -1
            )
        except Exception as e:
            print(f"Error initializing sentiment analyzer: {e}")
            self.sentiment_analyzer = None
        # input_dim=5 matches the per-symbol feature vector built in _symbol_features.
        self.gnn_model = FinancialGNN(input_dim=5, hidden_dim=32, output_dim=1).to(device)
        self.symbols = ['AAPL', 'GOOGL', 'MSFT', 'AMZN', 'META']
        self.market_data = {}
        self.sentiment_scores = {}
        self.monitoring = False
        self.monitor_thread = None
        # Set by stop_monitoring() so the monitoring loop wakes immediately
        # instead of finishing a multi-minute sleep.
        self._stop_event = threading.Event()

    def collect_market_data(self, symbols=None, days=30):
        """Download daily bars for each symbol over the last ``days`` calendar days.

        Args:
            symbols: tickers to fetch; defaults to ``self.symbols``.
            days: length of the look-back window (default 30).

        Returns:
            dict mapping every symbol that returned data to its DataFrame, with
            flat column names such as ``'Close'`` and ``'Volume'``.

        Raises:
            ValueError: if ``symbols`` is empty or no symbol returned any data.
        """
        if symbols is None:
            symbols = self.symbols
        if not symbols:
            raise ValueError("No symbols provided for market data collection")
        end_date = datetime.now()
        start_date = end_date - timedelta(days=days)
        market_data = {}
        for symbol in symbols:
            try:
                stock = yf.download(symbol, start=start_date, end=end_date, progress=False)
            except Exception as e:
                print(f"Error collecting data for {symbol}: {e}")
                continue
            if stock is None or stock.empty:
                print(f"No data available for {symbol}")
                continue
            # Newer yfinance versions return (field, ticker) MultiIndex columns even
            # for a single ticker; flatten so df['Close'] is a Series, not a DataFrame.
            if isinstance(stock.columns, pd.MultiIndex):
                stock.columns = stock.columns.get_level_values(0)
            market_data[symbol] = stock
        if not market_data:
            raise ValueError("No market data could be collected for any symbol")
        return market_data

    def analyze_sentiment(self, symbol):
        """Return a signed sentiment score for ``symbol``; 0 when unavailable.

        NOTE: the text scored is a fixed placeholder sentence, not fetched news,
        so the result does not reflect real market sentiment. A ``POSITIVE``
        label yields +score, any other label -score; errors yield 0.
        """
        if self.sentiment_analyzer is None:
            return 0
        try:
            # In practice, you should implement real news fetching here
            sample_news = f"Latest news about {symbol} shows positive market momentum"
            sentiment = self.sentiment_analyzer(sample_news)[0]
            return sentiment['score'] if sentiment['label'] == 'POSITIVE' else -sentiment['score']
        except Exception as e:
            print(f"Error in sentiment analysis for {symbol}: {e}")
            return 0

    def _symbol_features(self, market_data):
        """Return ``(symbols, vectors)`` for every symbol whose features could be built.

        Each vector holds: mean daily return, std of daily return, mean volume,
        highest high, lowest low. Symbols with empty frames or failing feature
        computation are skipped, and the two lists always stay index-aligned.
        """
        used, vectors = [], []
        for symbol, df in market_data.items():
            if len(df) == 0:
                continue
            try:
                returns = df['Close'].pct_change().fillna(0)
                volatility = returns.std()
                if pd.isna(volatility):
                    # std (ddof=1) of a single observation is NaN; treat it as flat.
                    volatility = 0.0
                vectors.append(torch.tensor([
                    float(returns.mean()),
                    float(volatility),
                    float(df['Volume'].fillna(0).mean()),
                    float(df['High'].max()),
                    float(df['Low'].min())
                ], dtype=torch.float32))
            except Exception as e:
                print(f"Error preparing features for {symbol}: {e}")
                continue
            used.append(symbol)
        return used, vectors

    def prepare_graph_features(self, market_data):
        """Stack per-symbol feature vectors into a (n_symbols, 5) tensor on ``device``.

        Raises:
            ValueError: if ``market_data`` is empty or no symbol yields features.
        """
        if not market_data:
            raise ValueError("No market data available for feature preparation")
        _, vectors = self._symbol_features(market_data)
        if not vectors:
            raise ValueError("Could not prepare features for any symbol")
        return torch.stack(vectors).to(device)

    def create_correlation_edges(self, market_data):
        """Return a fully connected edge_index over ``len(market_data)`` nodes.

        Despite the name, no correlation is computed: every pair i < j gets both
        directed edges [i, j] and [j, i]. Only ``len(market_data)`` is used, so
        any sized collection (e.g. a list of symbols) is accepted.

        Returns:
            LongTensor of shape (2, n * (n - 1)) on ``device``.

        Raises:
            ValueError: if fewer than 2 nodes are requested.
        """
        n = len(market_data)
        if n < 2:
            raise ValueError("Need at least 2 symbols to create correlation edges")
        edges = []
        for i in range(n):
            for j in range(i + 1, n):
                edges.append([i, j])
                edges.append([j, i])
        return torch.tensor(edges, dtype=torch.long).t().to(device)

    def predict_market_trends(self, symbols=None, market_data=None):
        """Run the GNN over the symbol graph and return one output row per symbol.

        Args:
            symbols: tickers to download when ``market_data`` is not given;
                defaults to ``self.symbols``.
            market_data: already-downloaded frames (as returned by
                ``collect_market_data``) to reuse instead of downloading again.

        Returns:
            numpy array of shape (n, 1), one row per symbol that produced
            features, in ``market_data`` order. On any error the error is
            printed and zeros of shape (n_requested, 1) are returned.
        """
        if symbols is None:
            symbols = self.symbols
        n_requested = len(market_data) if market_data is not None else len(symbols)
        try:
            if market_data is None:
                market_data = self.collect_market_data(symbols)
            used, vectors = self._symbol_features(market_data)
            if not vectors:
                raise ValueError("Could not prepare features for any symbol")
            features = torch.stack(vectors).to(device)
            # Build edges only over symbols that produced a feature row, so edge
            # indices can never point past the end of ``features``.
            edge_index = self.create_correlation_edges(used)
            with torch.no_grad():
                predictions = self.gnn_model(features, edge_index)
            return predictions.cpu().numpy()
        except Exception as e:
            print(f"Error predicting market trends: {e}")
            # Same (n, output_dim=1) layout as the success path.
            return np.zeros((n_requested, 1), dtype=np.float32)

    def generate_market_visualization(self, market_data):
        """Return a Plotly line chart of closing prices, one trace per symbol.

        Raises:
            ValueError: if ``market_data`` is empty.
        """
        if not market_data:
            raise ValueError("No market data available for visualization")
        fig = go.Figure()
        for symbol, df in market_data.items():
            if not df.empty:
                fig.add_trace(go.Scatter(
                    x=df.index,
                    y=df['Close'],
                    name=symbol,
                    mode='lines'
                ))
        fig.update_layout(
            title='Market Trends',
            xaxis_title='Date',
            yaxis_title='Price',
            template='plotly_dark'
        )
        return fig

    def start_monitoring(self):
        """Start the daemon thread that periodically refreshes data and sentiment.

        Does nothing if monitoring is already active.
        """
        if self.monitoring:
            return
        self.monitoring = True
        self._stop_event.clear()
        self.monitor_thread = threading.Thread(target=self._monitoring_loop, daemon=True)
        self.monitor_thread.start()

    def _monitoring_loop(self):
        """Refresh ``market_data``/``sentiment_scores`` until ``monitoring`` is cleared.

        Waits 300 s after a successful refresh and 60 s after an error; the wait
        returns early when ``stop_monitoring`` sets the stop event.
        """
        while self.monitoring:
            try:
                self.market_data = self.collect_market_data()
                for symbol in self.symbols:
                    self.sentiment_scores[symbol] = self.analyze_sentiment(symbol)
                delay = 300  # Update every 5 minutes
            except Exception as e:
                print(f"Error in monitoring loop: {e}")
                delay = 60  # Wait a minute before retrying
            self._stop_event.wait(delay)

    def stop_monitoring(self):
        """Signal the monitoring thread to stop and wait up to 1 s for it to exit."""
        self.monitoring = False
        self._stop_event.set()
        if self.monitor_thread:
            self.monitor_thread.join(timeout=1)
def create_gradio_interface():
    """Build the Gradio UI around one shared ``MarketAnalysisSystem``.

    Returns:
        ``gr.Interface`` mapping a comma-separated ticker string to a price
        chart, per-symbol sentiment text and per-symbol trend text.
    """
    market_system = MarketAnalysisSystem()

    def describe_trends(market_data):
        """Run the GNN on already-downloaded frames and label each output row.

        Rows are labelled with the keys of ``market_data`` (the symbols that
        actually returned data), so a ticker without data cannot shift labels
        onto another ticker's output. NOTE(review): the model is not trained
        anywhere in this file, so "Upward"/"Downward" is not a real forecast.
        """
        labels = list(market_data)
        features = market_system.prepare_graph_features(market_data)
        if features.shape[0] != len(labels):
            # prepare_graph_features skips symbols whose features fail; rows,
            # labels and edge indices would then no longer line up.
            raise ValueError("Could not prepare features for every downloaded symbol")
        edge_index = market_system.create_correlation_edges(market_data)
        with torch.no_grad():
            outputs = market_system.gnn_model(features, edge_index).cpu().numpy()
        return "\n".join(
            f"{symbol}: {'Upward' if value > 0 else 'Downward'} trend"
            for symbol, value in zip(labels, outputs[:, 0])
        )

    def analyze_markets(symbols_input):
        """Gradio callback returning ``(figure_or_None, sentiment_text, prediction_text)``."""
        try:
            # " aapl, MSFT,,msft " -> ["AAPL", "MSFT"]: trim, uppercase, drop
            # blanks and duplicates while keeping the user's order.
            symbols = list(dict.fromkeys(
                s.strip().upper() for s in (symbols_input or "").split(',') if s.strip()
            ))
            if not symbols:
                return (
                    None,
                    "Error: Please enter at least one stock symbol",
                    "Error: No symbols provided"
                )
            try:
                market_data = market_system.collect_market_data(symbols)
            except Exception as e:
                return (
                    None,
                    f"Error collecting market data: {str(e)}",
                    "Unable to analyze trends"
                )
            try:
                fig = market_system.generate_market_visualization(market_data)
            except Exception as e:
                fig = None
                print(f"Error generating visualization: {e}")
            sentiments = {symbol: market_system.analyze_sentiment(symbol) for symbol in symbols}
            sentiment_text = "\n".join(f"{symbol}: {score:.2f}" for symbol, score in sentiments.items())
            # Predict on the data just downloaded for the user's symbols (not the
            # system's default list), so each label matches its own prediction.
            try:
                prediction_text = describe_trends(market_data)
            except Exception as e:
                prediction_text = f"Error predicting trends: {str(e)}"
            return fig, sentiment_text, prediction_text
        except Exception as e:
            return None, f"Error: {str(e)}", "Analysis failed"

    interface = gr.Interface(
        fn=analyze_markets,
        inputs=gr.Textbox(
            label="Enter stock symbols (comma-separated)",
            value="AAPL,GOOGL,MSFT"
        ),
        outputs=[
            gr.Plot(label="Market Trends"),
            gr.Textbox(label="Sentiment Analysis"),
            gr.Textbox(label="Trend Predictions")
        ],
        title="Real-Time Market Analysis System",
        description="Enter stock symbols to analyze market trends, sentiment, and predictions."
    )
    return interface
if __name__ == "__main__":
    # Every dependency is imported unconditionally at the top of this module, so a
    # missing package already fails with ImportError before this line runs; the
    # imported module objects can never be None, so no runtime check is needed.
    interface = create_gradio_interface()
    interface.launch(debug=True)