Spaces:
Runtime error
Runtime error
import numpy as np | |
import pandas as pd | |
import yfinance as yf | |
import torch | |
import torch.nn as nn | |
from torch_geometric.nn import GCNConv | |
from transformers import pipeline | |
import gradio as gr | |
import plotly.graph_objects as go | |
from datetime import datetime, timedelta | |
import threading | |
import time | |
# Pick the compute device once for the whole module: the CUDA GPU when torch
# reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
class FinancialGNN(nn.Module):
    """Two-layer graph convolutional network: GCNConv -> ReLU -> GCNConv.

    Args:
        input_dim: number of input features per node.
        hidden_dim: width of the hidden GCN layer.
        output_dim: number of output values per node.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        """Run both graph convolutions over the node features.

        Args:
            x: 2-D node-feature tensor (presumably (num_nodes, input_dim)).
            edge_index: graph connectivity, passed unchanged to both layers.

        Returns:
            Output of the second GCN layer.

        Raises:
            ValueError: if ``x`` is not a 2-D tensor.
        """
        rank = x.dim()
        if rank != 2:
            raise ValueError(f"Expected 2D input tensor, got {rank}D")
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
class MarketAnalysisSystem:
    """Collect stock prices, score sentiment and run the GNN over a set of symbols.

    Also owns an optional background thread that periodically refreshes
    ``market_data`` and ``sentiment_scores`` for the default symbols.
    """

    def __init__(self):
        try:
            # Sentiment pipeline; fall back to None so the rest of the system keeps
            # working when the model cannot be loaded.
            self.sentiment_analyzer = pipeline(
                "sentiment-analysis",
                model="distilbert-base-uncased-finetuned-sst-2-english",
                device=0 if torch.cuda.is_available() else -1
            )
        except Exception as e:
            print(f"Error initializing sentiment analyzer: {e}")
            self.sentiment_analyzer = None

        # 5 features per node (see prepare_graph_features) -> 1 score per node.
        # NOTE(review): no training or weight loading is visible in this file, so
        # predictions come from the randomly initialised weights.
        self.gnn_model = FinancialGNN(input_dim=5, hidden_dim=32, output_dim=1).to(device)
        self.gnn_model.eval()  # inference only

        # Default stock symbols
        self.symbols = ['AAPL', 'GOOGL', 'MSFT', 'AMZN', 'META']

        # Latest data written by the monitoring thread
        self.market_data = {}
        self.sentiment_scores = {}

        # Monitoring state; _stop_event is replaced on every start_monitoring() call
        self.monitoring = False
        self.monitor_thread = None
        self._stop_event = None

    @staticmethod
    def _flatten_columns(df):
        """Return ``df`` with single-level columns (the input is not modified).

        Newer yfinance releases may return ``(field, ticker)`` MultiIndex columns
        even for one ticker, which makes ``df['Close']`` a DataFrame. Keeping only
        level 0 (the field, in the default ``group_by='column'`` layout) restores
        ``df['Close']`` etc. as Series.
        """
        if isinstance(df.columns, pd.MultiIndex):
            df = df.copy()
            df.columns = df.columns.get_level_values(0)
        return df

    def collect_market_data(self, symbols=None):
        """Download price history from 30 days ago until now for each symbol.

        Args:
            symbols: iterable of ticker strings; defaults to ``self.symbols``.

        Returns:
            dict mapping each symbol that returned data to its DataFrame.
            Symbols that fail or return no rows are skipped, and a message is printed.

        Raises:
            ValueError: if no symbols are given or no symbol produced any data.
        """
        if symbols is None:
            symbols = self.symbols
        if not symbols:
            raise ValueError("No symbols provided for market data collection")

        end_date = datetime.now()
        start_date = end_date - timedelta(days=30)
        market_data = {}
        for symbol in symbols:
            try:
                stock = yf.download(symbol, start=start_date, end=end_date, progress=False)
            except Exception as e:
                print(f"Error collecting data for {symbol}: {e}")
                continue
            if stock is None or stock.empty:
                print(f"No data available for {symbol}")
                continue
            market_data[symbol] = self._flatten_columns(stock)

        if not market_data:
            raise ValueError("No market data could be collected for any symbol")
        return market_data

    def analyze_sentiment(self, symbol):
        """Return a signed sentiment score for ``symbol``.

        The result is ``+score`` for a POSITIVE label and ``-score`` otherwise.
        It is 0 when the analyzer is unavailable or raises.

        NOTE: the analysed text is a fixed placeholder sentence, not real news,
        so the score does not reflect actual coverage of ``symbol``.
        """
        if self.sentiment_analyzer is None:
            return 0
        try:
            # In practice, you should implement real news fetching here
            sample_news = f"Latest news about {symbol} shows positive market momentum"
            sentiment = self.sentiment_analyzer(sample_news)[0]
            return sentiment['score'] if sentiment['label'] == 'POSITIVE' else -sentiment['score']
        except Exception as e:
            print(f"Error in sentiment analysis for {symbol}: {e}")
            return 0

    def prepare_graph_features(self, market_data):
        """Build one 5-feature row per symbol.

        The features are: mean and std of daily close returns, mean volume,
        max high and min low. Symbols whose features cannot be computed are
        skipped, so the row count may be smaller than ``len(market_data)``.

        Returns:
            float32 tensor of shape (num_rows, 5) on ``device``.

        Raises:
            ValueError: if ``market_data`` is empty or no row could be built.
        """
        if not market_data:
            raise ValueError("No market data available for feature preparation")

        features = []
        for symbol, df in market_data.items():
            if len(df) == 0:
                continue
            try:
                df = self._flatten_columns(df)
                returns = df['Close'].pct_change().fillna(0)
                # With a single row the sample std is NaN; use 0 so NaN does not
                # propagate through the GNN.
                volatility = returns.std() if len(returns) > 1 else 0.0
                feature_vector = torch.tensor([
                    returns.mean(),
                    volatility,
                    df['Volume'].fillna(0).mean(),
                    df['High'].max(),
                    df['Low'].min()
                ], dtype=torch.float32)
            except Exception as e:
                print(f"Error preparing features for {symbol}: {e}")
                continue
            features.append(feature_vector)

        if not features:
            raise ValueError("Could not prepare features for any symbol")
        return torch.stack(features).to(device)

    def create_correlation_edges(self, market_data):
        """Return a fully connected, bidirectional edge index.

        Despite the name, no correlation is computed: every pair of distinct
        nodes is linked in both directions.

        Args:
            market_data: a sized collection whose length is the node count, or
                the node count itself as an int.

        Returns:
            long tensor of shape (2, n * (n - 1)) on ``device``.

        Raises:
            ValueError: if there are fewer than 2 nodes.
        """
        n = market_data if isinstance(market_data, int) else len(market_data)
        if n < 2:
            raise ValueError("Need at least 2 symbols to create correlation edges")
        edges = []
        for i in range(n):
            for j in range(i + 1, n):
                edges.append([i, j])
                edges.append([j, i])
        return torch.tensor(edges, dtype=torch.long).t().to(device)

    def predict_market_trends(self, market_data=None):
        """Run the GNN and return one score row per feature row.

        Args:
            market_data: optional mapping as returned by ``collect_market_data``.
                When omitted, data for ``self.symbols`` is downloaded.

        Returns:
            numpy array of model outputs. On any error it prints the error and
            returns zeros, one per input symbol.
        """
        try:
            if market_data is None:
                market_data = self.collect_market_data()
            features = self.prepare_graph_features(market_data)
            # Size the graph from the feature rows actually produced, so a skipped
            # symbol cannot leave edges pointing at non-existent nodes.
            edge_index = self.create_correlation_edges(features.size(0))
            with torch.no_grad():
                predictions = self.gnn_model(features, edge_index)
            return predictions.cpu().numpy()
        except Exception as e:
            print(f"Error predicting market trends: {e}")
            return np.zeros(len(market_data) if market_data else len(self.symbols))

    def generate_market_visualization(self, market_data):
        """Return a plotly figure with one close-price line per non-empty symbol.

        Raises:
            ValueError: if ``market_data`` is empty.
        """
        if not market_data:
            raise ValueError("No market data available for visualization")
        fig = go.Figure()
        for symbol, df in market_data.items():
            if df.empty:
                continue
            df = self._flatten_columns(df)
            fig.add_trace(go.Scatter(
                x=df.index,
                y=df['Close'],
                name=symbol,
                mode='lines'
            ))
        fig.update_layout(
            title='Market Trends',
            xaxis_title='Date',
            yaxis_title='Price',
            template='plotly_dark'
        )
        return fig

    def start_monitoring(self):
        """Start the daemon refresh thread unless monitoring is already on."""
        if self.monitoring:
            return
        self.monitoring = True
        # A fresh event per run: a thread from an earlier run keeps its own
        # (already set) event and exits, even if monitoring is restarted quickly.
        self._stop_event = threading.Event()
        self.monitor_thread = threading.Thread(
            target=self._monitoring_loop, args=(self._stop_event,), daemon=True
        )
        self.monitor_thread.start()

    def _monitoring_loop(self, stop_event=None):
        """Refresh data every 5 minutes (1 minute after an error) until stopped."""
        if stop_event is None:
            stop_event = threading.Event()  # never set: rely on self.monitoring
        while self.monitoring and not stop_event.is_set():
            try:
                self.market_data = self.collect_market_data()
                for symbol in self.symbols:
                    self.sentiment_scores[symbol] = self.analyze_sentiment(symbol)
                delay = 300  # Update every 5 minutes
            except Exception as e:
                print(f"Error in monitoring loop: {e}")
                delay = 60  # Wait a minute before retrying
            # Unlike time.sleep, Event.wait returns as soon as stop_monitoring()
            # sets the event.
            if stop_event.wait(delay):
                break

    def stop_monitoring(self):
        """Signal the refresh thread to stop and wait up to 1 second for it."""
        self.monitoring = False
        stop_event = getattr(self, '_stop_event', None)
        if stop_event is not None:
            stop_event.set()
        if self.monitor_thread:
            self.monitor_thread.join(timeout=1)
def create_gradio_interface():
    """Build the Gradio UI for ad-hoc market analysis.

    Returns:
        A ``gr.Interface``. Its callback takes a comma-separated symbol string
        and returns (plotly figure or None, sentiment text, trend-prediction text).
    """
    market_system = MarketAnalysisSystem()

    def analyze_markets(symbols_input):
        """Gradio callback: chart, sentiment and GNN trend for the requested symbols."""
        try:
            # Input validation (also guards a None value from a cleared textbox)
            if not symbols_input or not symbols_input.strip():
                return (
                    None,
                    "Error: Please enter at least one stock symbol",
                    "Error: No symbols provided"
                )
            symbols = [s.strip() for s in symbols_input.split(',') if s.strip()]

            # Collect market data for the requested symbols
            try:
                market_data = market_system.collect_market_data(symbols)
            except Exception as e:
                return (
                    None,
                    f"Error collecting market data: {str(e)}",
                    "Unable to analyze trends"
                )

            # Generate visualization
            try:
                fig = market_system.generate_market_visualization(market_data)
            except Exception as e:
                fig = None
                print(f"Error generating visualization: {e}")

            # Get sentiment scores
            sentiments = {symbol: market_system.analyze_sentiment(symbol) for symbol in symbols}
            sentiment_text = "\n".join(f"{symbol}: {score:.2f}" for symbol, score in sentiments.items())

            # Predict trends on the data just collected for the user's symbols.
            # predict_market_trends() with no arguments would download the default
            # symbol list instead.
            try:
                features = market_system.prepare_graph_features(market_data)
                if features.size(0) != len(market_data):
                    # A skipped symbol would shift every later label by one row.
                    raise ValueError("features could not be prepared for every symbol")
                edge_index = market_system.create_correlation_edges(market_data)
                with torch.no_grad():
                    predictions = market_system.gnn_model(features, edge_index)
                scores = predictions.cpu().numpy().reshape(-1)
                lines = [
                    f"{symbol}: {'Upward' if score > 0 else 'Downward'} trend"
                    for symbol, score in zip(market_data, scores)
                ]
                lines.extend(
                    f"{symbol}: No data available"
                    for symbol in dict.fromkeys(symbols)
                    if symbol not in market_data
                )
                prediction_text = "\n".join(lines)
            except Exception as e:
                prediction_text = f"Error predicting trends: {str(e)}"

            return fig, sentiment_text, prediction_text
        except Exception as e:
            return None, f"Error: {str(e)}", "Analysis failed"

    interface = gr.Interface(
        fn=analyze_markets,
        inputs=gr.Textbox(
            label="Enter stock symbols (comma-separated)",
            value="AAPL,GOOGL,MSFT"
        ),
        outputs=[
            gr.Plot(label="Market Trends"),
            gr.Textbox(label="Sentiment Analysis"),
            gr.Textbox(label="Trend Predictions")
        ],
        title="Real-Time Market Analysis System",
        description="Enter stock symbols to analyze market trends, sentiment, and predictions."
    )
    return interface
if __name__ == "__main__":
    # No runtime package check is needed: every dependency is imported
    # unconditionally at the top of this module, so a missing package already
    # fails with ImportError before execution reaches this point.
    interface = create_gradio_interface()
    interface.launch(debug=True)