Spaces:
Runtime error
Runtime error
Anupam251272
commited on
Commit
•
a77ea34
1
Parent(s):
3ece877
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import yfinance as yf
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch_geometric.nn import GCNConv
|
7 |
+
from transformers import pipeline
|
8 |
+
import gradio as gr
|
9 |
+
import plotly.graph_objects as go
|
10 |
+
from datetime import datetime, timedelta
|
11 |
+
import threading
|
12 |
+
import time
|
13 |
+
|
14 |
+
# Check if GPU is available
|
15 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
+
print(f"Using device: {device}")
|
17 |
+
|
18 |
+
class FinancialGNN(nn.Module):
|
19 |
+
def __init__(self, input_dim, hidden_dim, output_dim):
|
20 |
+
super(FinancialGNN, self).__init__()
|
21 |
+
self.conv1 = GCNConv(input_dim, hidden_dim)
|
22 |
+
self.conv2 = GCNConv(hidden_dim, output_dim)
|
23 |
+
|
24 |
+
def forward(self, x, edge_index):
|
25 |
+
# Add error handling for input dimensions
|
26 |
+
if x.dim() != 2:
|
27 |
+
raise ValueError(f"Expected 2D input tensor, got {x.dim()}D")
|
28 |
+
|
29 |
+
x = self.conv1(x, edge_index)
|
30 |
+
x = torch.relu(x)
|
31 |
+
x = self.conv2(x, edge_index)
|
32 |
+
return x
|
33 |
+
|
34 |
+
class MarketAnalysisSystem:
|
35 |
+
def __init__(self):
|
36 |
+
try:
|
37 |
+
# Initialize sentiment analyzer with error handling
|
38 |
+
self.sentiment_analyzer = pipeline(
|
39 |
+
"sentiment-analysis",
|
40 |
+
model="distilbert-base-uncased-finetuned-sst-2-english",
|
41 |
+
device=0 if torch.cuda.is_available() else -1
|
42 |
+
)
|
43 |
+
except Exception as e:
|
44 |
+
print(f"Error initializing sentiment analyzer: {e}")
|
45 |
+
self.sentiment_analyzer = None
|
46 |
+
|
47 |
+
# Initialize GNN model
|
48 |
+
self.gnn_model = FinancialGNN(input_dim=5, hidden_dim=32, output_dim=1).to(device)
|
49 |
+
|
50 |
+
# Define default stock symbols
|
51 |
+
self.symbols = ['AAPL', 'GOOGL', 'MSFT', 'AMZN', 'META']
|
52 |
+
|
53 |
+
# Initialize data storage
|
54 |
+
self.market_data = {}
|
55 |
+
self.sentiment_scores = {}
|
56 |
+
|
57 |
+
# Initialize monitoring flag and thread
|
58 |
+
self.monitoring = False
|
59 |
+
self.monitor_thread = None
|
60 |
+
|
61 |
+
def collect_market_data(self, symbols=None):
|
62 |
+
if symbols is None:
|
63 |
+
symbols = self.symbols
|
64 |
+
|
65 |
+
if not symbols:
|
66 |
+
raise ValueError("No symbols provided for market data collection")
|
67 |
+
|
68 |
+
end_date = datetime.now()
|
69 |
+
start_date = end_date - timedelta(days=30)
|
70 |
+
|
71 |
+
market_data = {}
|
72 |
+
for symbol in symbols:
|
73 |
+
try:
|
74 |
+
stock = yf.download(symbol, start=start_date, end=end_date, progress=False)
|
75 |
+
if stock.empty:
|
76 |
+
print(f"No data available for {symbol}")
|
77 |
+
continue
|
78 |
+
market_data[symbol] = stock
|
79 |
+
except Exception as e:
|
80 |
+
print(f"Error collecting data for {symbol}: {e}")
|
81 |
+
continue
|
82 |
+
|
83 |
+
if not market_data:
|
84 |
+
raise ValueError("No market data could be collected for any symbol")
|
85 |
+
|
86 |
+
return market_data
|
87 |
+
|
88 |
+
def analyze_sentiment(self, symbol):
|
89 |
+
if self.sentiment_analyzer is None:
|
90 |
+
return 0
|
91 |
+
|
92 |
+
try:
|
93 |
+
# In practice, you should implement real news fetching here
|
94 |
+
sample_news = f"Latest news about {symbol} shows positive market momentum"
|
95 |
+
sentiment = self.sentiment_analyzer(sample_news)[0]
|
96 |
+
return sentiment['score'] if sentiment['label'] == 'POSITIVE' else -sentiment['score']
|
97 |
+
except Exception as e:
|
98 |
+
print(f"Error in sentiment analysis for {symbol}: {e}")
|
99 |
+
return 0
|
100 |
+
|
101 |
+
def prepare_graph_features(self, market_data):
|
102 |
+
if not market_data:
|
103 |
+
raise ValueError("No market data available for feature preparation")
|
104 |
+
|
105 |
+
features = []
|
106 |
+
for symbol in market_data:
|
107 |
+
df = market_data[symbol]
|
108 |
+
if len(df) > 0:
|
109 |
+
try:
|
110 |
+
feature_vector = torch.tensor([
|
111 |
+
df['Close'].pct_change().fillna(0).mean(),
|
112 |
+
df['Close'].pct_change().fillna(0).std(),
|
113 |
+
df['Volume'].fillna(0).mean(),
|
114 |
+
df['High'].max(),
|
115 |
+
df['Low'].min()
|
116 |
+
], dtype=torch.float32)
|
117 |
+
features.append(feature_vector)
|
118 |
+
except Exception as e:
|
119 |
+
print(f"Error preparing features for {symbol}: {e}")
|
120 |
+
continue
|
121 |
+
|
122 |
+
if not features:
|
123 |
+
raise ValueError("Could not prepare features for any symbol")
|
124 |
+
|
125 |
+
return torch.stack(features).to(device)
|
126 |
+
|
127 |
+
def create_correlation_edges(self, market_data):
|
128 |
+
n = len(market_data)
|
129 |
+
if n < 2:
|
130 |
+
raise ValueError("Need at least 2 symbols to create correlation edges")
|
131 |
+
|
132 |
+
edges = []
|
133 |
+
for i in range(n):
|
134 |
+
for j in range(i+1, n):
|
135 |
+
edges.append([i, j])
|
136 |
+
edges.append([j, i])
|
137 |
+
|
138 |
+
return torch.tensor(edges, dtype=torch.long).t().to(device)
|
139 |
+
|
140 |
+
def predict_market_trends(self):
|
141 |
+
try:
|
142 |
+
market_data = self.collect_market_data()
|
143 |
+
features = self.prepare_graph_features(market_data)
|
144 |
+
edge_index = self.create_correlation_edges(market_data)
|
145 |
+
|
146 |
+
with torch.no_grad():
|
147 |
+
predictions = self.gnn_model(features, edge_index)
|
148 |
+
|
149 |
+
return predictions.cpu().numpy()
|
150 |
+
except Exception as e:
|
151 |
+
print(f"Error predicting market trends: {e}")
|
152 |
+
return np.zeros(len(self.symbols))
|
153 |
+
|
154 |
+
def generate_market_visualization(self, market_data):
|
155 |
+
if not market_data:
|
156 |
+
raise ValueError("No market data available for visualization")
|
157 |
+
|
158 |
+
fig = go.Figure()
|
159 |
+
|
160 |
+
for symbol in market_data:
|
161 |
+
df = market_data[symbol]
|
162 |
+
if not df.empty:
|
163 |
+
fig.add_trace(go.Scatter(
|
164 |
+
x=df.index,
|
165 |
+
y=df['Close'],
|
166 |
+
name=symbol,
|
167 |
+
mode='lines'
|
168 |
+
))
|
169 |
+
|
170 |
+
fig.update_layout(
|
171 |
+
title='Market Trends',
|
172 |
+
xaxis_title='Date',
|
173 |
+
yaxis_title='Price',
|
174 |
+
template='plotly_dark'
|
175 |
+
)
|
176 |
+
|
177 |
+
return fig
|
178 |
+
|
179 |
+
def start_monitoring(self):
|
180 |
+
if self.monitoring:
|
181 |
+
return
|
182 |
+
|
183 |
+
self.monitoring = True
|
184 |
+
self.monitor_thread = threading.Thread(target=self._monitoring_loop)
|
185 |
+
self.monitor_thread.daemon = True
|
186 |
+
self.monitor_thread.start()
|
187 |
+
|
188 |
+
def _monitoring_loop(self):
|
189 |
+
while self.monitoring:
|
190 |
+
try:
|
191 |
+
self.market_data = self.collect_market_data()
|
192 |
+
for symbol in self.symbols:
|
193 |
+
self.sentiment_scores[symbol] = self.analyze_sentiment(symbol)
|
194 |
+
time.sleep(300) # Update every 5 minutes
|
195 |
+
except Exception as e:
|
196 |
+
print(f"Error in monitoring loop: {e}")
|
197 |
+
time.sleep(60) # Wait a minute before retrying
|
198 |
+
|
199 |
+
def stop_monitoring(self):
|
200 |
+
self.monitoring = False
|
201 |
+
if self.monitor_thread:
|
202 |
+
self.monitor_thread.join(timeout=1)
|
203 |
+
|
204 |
+
def create_gradio_interface():
|
205 |
+
market_system = MarketAnalysisSystem()
|
206 |
+
|
207 |
+
def analyze_markets(symbols_input):
|
208 |
+
try:
|
209 |
+
# Input validation
|
210 |
+
if not symbols_input.strip():
|
211 |
+
return (
|
212 |
+
None,
|
213 |
+
"Error: Please enter at least one stock symbol",
|
214 |
+
"Error: No symbols provided"
|
215 |
+
)
|
216 |
+
|
217 |
+
symbols = [s.strip() for s in symbols_input.split(',') if s.strip()]
|
218 |
+
|
219 |
+
# Collect and analyze market data
|
220 |
+
try:
|
221 |
+
market_data = market_system.collect_market_data(symbols)
|
222 |
+
except Exception as e:
|
223 |
+
return (
|
224 |
+
None,
|
225 |
+
f"Error collecting market data: {str(e)}",
|
226 |
+
"Unable to analyze trends"
|
227 |
+
)
|
228 |
+
|
229 |
+
# Generate visualization
|
230 |
+
try:
|
231 |
+
fig = market_system.generate_market_visualization(market_data)
|
232 |
+
except Exception as e:
|
233 |
+
fig = None
|
234 |
+
print(f"Error generating visualization: {e}")
|
235 |
+
|
236 |
+
# Get sentiment scores
|
237 |
+
sentiments = {symbol: market_system.analyze_sentiment(symbol) for symbol in symbols}
|
238 |
+
sentiment_text = "\n".join([f"{symbol}: {score:.2f}" for symbol, score in sentiments.items()])
|
239 |
+
|
240 |
+
# Predict trends
|
241 |
+
try:
|
242 |
+
predictions = market_system.predict_market_trends()
|
243 |
+
prediction_text = "\n".join([
|
244 |
+
f"{symbol}: {'Upward' if pred > 0 else 'Downward'} trend"
|
245 |
+
for symbol, pred in zip(symbols, predictions[:len(symbols)])
|
246 |
+
])
|
247 |
+
except Exception as e:
|
248 |
+
prediction_text = f"Error predicting trends: {str(e)}"
|
249 |
+
|
250 |
+
return fig, sentiment_text, prediction_text
|
251 |
+
|
252 |
+
except Exception as e:
|
253 |
+
return None, f"Error: {str(e)}", "Analysis failed"
|
254 |
+
|
255 |
+
interface = gr.Interface(
|
256 |
+
fn=analyze_markets,
|
257 |
+
inputs=gr.Textbox(
|
258 |
+
label="Enter stock symbols (comma-separated)",
|
259 |
+
value="AAPL,GOOGL,MSFT"
|
260 |
+
),
|
261 |
+
outputs=[
|
262 |
+
gr.Plot(label="Market Trends"),
|
263 |
+
gr.Textbox(label="Sentiment Analysis"),
|
264 |
+
gr.Textbox(label="Trend Predictions")
|
265 |
+
],
|
266 |
+
title="Real-Time Market Analysis System",
|
267 |
+
description="Enter stock symbols to analyze market trends, sentiment, and predictions."
|
268 |
+
)
|
269 |
+
|
270 |
+
return interface
|
271 |
+
|
272 |
+
if __name__ == "__main__":
|
273 |
+
# Ensure all required packages are imported
|
274 |
+
required_packages = {
|
275 |
+
'numpy': np,
|
276 |
+
'pandas': pd,
|
277 |
+
'yfinance': yf,
|
278 |
+
'torch': torch,
|
279 |
+
'transformers': pipeline,
|
280 |
+
'gradio': gr,
|
281 |
+
'plotly': go
|
282 |
+
}
|
283 |
+
|
284 |
+
missing_packages = []
|
285 |
+
for package, module in required_packages.items():
|
286 |
+
if module is None:
|
287 |
+
missing_packages.append(package)
|
288 |
+
|
289 |
+
if missing_packages:
|
290 |
+
print(f"Missing required packages: {', '.join(missing_packages)}")
|
291 |
+
print("Please install them using pip:")
|
292 |
+
print(f"pip install {' '.join(missing_packages)}")
|
293 |
+
else:
|
294 |
+
interface = create_gradio_interface()
|
295 |
+
interface.launch(debug=True)
|