Anupam251272 commited on
Commit
a77ea34
1 Parent(s): 3ece877

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +295 -0
app.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import yfinance as yf
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch_geometric.nn import GCNConv
7
+ from transformers import pipeline
8
+ import gradio as gr
9
+ import plotly.graph_objects as go
10
+ from datetime import datetime, timedelta
11
+ import threading
12
+ import time
13
+
14
+ # Check if GPU is available
15
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
+ print(f"Using device: {device}")
17
+
18
+ class FinancialGNN(nn.Module):
19
+ def __init__(self, input_dim, hidden_dim, output_dim):
20
+ super(FinancialGNN, self).__init__()
21
+ self.conv1 = GCNConv(input_dim, hidden_dim)
22
+ self.conv2 = GCNConv(hidden_dim, output_dim)
23
+
24
+ def forward(self, x, edge_index):
25
+ # Add error handling for input dimensions
26
+ if x.dim() != 2:
27
+ raise ValueError(f"Expected 2D input tensor, got {x.dim()}D")
28
+
29
+ x = self.conv1(x, edge_index)
30
+ x = torch.relu(x)
31
+ x = self.conv2(x, edge_index)
32
+ return x
33
+
34
+ class MarketAnalysisSystem:
35
+ def __init__(self):
36
+ try:
37
+ # Initialize sentiment analyzer with error handling
38
+ self.sentiment_analyzer = pipeline(
39
+ "sentiment-analysis",
40
+ model="distilbert-base-uncased-finetuned-sst-2-english",
41
+ device=0 if torch.cuda.is_available() else -1
42
+ )
43
+ except Exception as e:
44
+ print(f"Error initializing sentiment analyzer: {e}")
45
+ self.sentiment_analyzer = None
46
+
47
+ # Initialize GNN model
48
+ self.gnn_model = FinancialGNN(input_dim=5, hidden_dim=32, output_dim=1).to(device)
49
+
50
+ # Define default stock symbols
51
+ self.symbols = ['AAPL', 'GOOGL', 'MSFT', 'AMZN', 'META']
52
+
53
+ # Initialize data storage
54
+ self.market_data = {}
55
+ self.sentiment_scores = {}
56
+
57
+ # Initialize monitoring flag and thread
58
+ self.monitoring = False
59
+ self.monitor_thread = None
60
+
61
+ def collect_market_data(self, symbols=None):
62
+ if symbols is None:
63
+ symbols = self.symbols
64
+
65
+ if not symbols:
66
+ raise ValueError("No symbols provided for market data collection")
67
+
68
+ end_date = datetime.now()
69
+ start_date = end_date - timedelta(days=30)
70
+
71
+ market_data = {}
72
+ for symbol in symbols:
73
+ try:
74
+ stock = yf.download(symbol, start=start_date, end=end_date, progress=False)
75
+ if stock.empty:
76
+ print(f"No data available for {symbol}")
77
+ continue
78
+ market_data[symbol] = stock
79
+ except Exception as e:
80
+ print(f"Error collecting data for {symbol}: {e}")
81
+ continue
82
+
83
+ if not market_data:
84
+ raise ValueError("No market data could be collected for any symbol")
85
+
86
+ return market_data
87
+
88
+ def analyze_sentiment(self, symbol):
89
+ if self.sentiment_analyzer is None:
90
+ return 0
91
+
92
+ try:
93
+ # In practice, you should implement real news fetching here
94
+ sample_news = f"Latest news about {symbol} shows positive market momentum"
95
+ sentiment = self.sentiment_analyzer(sample_news)[0]
96
+ return sentiment['score'] if sentiment['label'] == 'POSITIVE' else -sentiment['score']
97
+ except Exception as e:
98
+ print(f"Error in sentiment analysis for {symbol}: {e}")
99
+ return 0
100
+
101
+ def prepare_graph_features(self, market_data):
102
+ if not market_data:
103
+ raise ValueError("No market data available for feature preparation")
104
+
105
+ features = []
106
+ for symbol in market_data:
107
+ df = market_data[symbol]
108
+ if len(df) > 0:
109
+ try:
110
+ feature_vector = torch.tensor([
111
+ df['Close'].pct_change().fillna(0).mean(),
112
+ df['Close'].pct_change().fillna(0).std(),
113
+ df['Volume'].fillna(0).mean(),
114
+ df['High'].max(),
115
+ df['Low'].min()
116
+ ], dtype=torch.float32)
117
+ features.append(feature_vector)
118
+ except Exception as e:
119
+ print(f"Error preparing features for {symbol}: {e}")
120
+ continue
121
+
122
+ if not features:
123
+ raise ValueError("Could not prepare features for any symbol")
124
+
125
+ return torch.stack(features).to(device)
126
+
127
+ def create_correlation_edges(self, market_data):
128
+ n = len(market_data)
129
+ if n < 2:
130
+ raise ValueError("Need at least 2 symbols to create correlation edges")
131
+
132
+ edges = []
133
+ for i in range(n):
134
+ for j in range(i+1, n):
135
+ edges.append([i, j])
136
+ edges.append([j, i])
137
+
138
+ return torch.tensor(edges, dtype=torch.long).t().to(device)
139
+
140
+ def predict_market_trends(self):
141
+ try:
142
+ market_data = self.collect_market_data()
143
+ features = self.prepare_graph_features(market_data)
144
+ edge_index = self.create_correlation_edges(market_data)
145
+
146
+ with torch.no_grad():
147
+ predictions = self.gnn_model(features, edge_index)
148
+
149
+ return predictions.cpu().numpy()
150
+ except Exception as e:
151
+ print(f"Error predicting market trends: {e}")
152
+ return np.zeros(len(self.symbols))
153
+
154
+ def generate_market_visualization(self, market_data):
155
+ if not market_data:
156
+ raise ValueError("No market data available for visualization")
157
+
158
+ fig = go.Figure()
159
+
160
+ for symbol in market_data:
161
+ df = market_data[symbol]
162
+ if not df.empty:
163
+ fig.add_trace(go.Scatter(
164
+ x=df.index,
165
+ y=df['Close'],
166
+ name=symbol,
167
+ mode='lines'
168
+ ))
169
+
170
+ fig.update_layout(
171
+ title='Market Trends',
172
+ xaxis_title='Date',
173
+ yaxis_title='Price',
174
+ template='plotly_dark'
175
+ )
176
+
177
+ return fig
178
+
179
+ def start_monitoring(self):
180
+ if self.monitoring:
181
+ return
182
+
183
+ self.monitoring = True
184
+ self.monitor_thread = threading.Thread(target=self._monitoring_loop)
185
+ self.monitor_thread.daemon = True
186
+ self.monitor_thread.start()
187
+
188
+ def _monitoring_loop(self):
189
+ while self.monitoring:
190
+ try:
191
+ self.market_data = self.collect_market_data()
192
+ for symbol in self.symbols:
193
+ self.sentiment_scores[symbol] = self.analyze_sentiment(symbol)
194
+ time.sleep(300) # Update every 5 minutes
195
+ except Exception as e:
196
+ print(f"Error in monitoring loop: {e}")
197
+ time.sleep(60) # Wait a minute before retrying
198
+
199
+ def stop_monitoring(self):
200
+ self.monitoring = False
201
+ if self.monitor_thread:
202
+ self.monitor_thread.join(timeout=1)
203
+
204
+ def create_gradio_interface():
205
+ market_system = MarketAnalysisSystem()
206
+
207
+ def analyze_markets(symbols_input):
208
+ try:
209
+ # Input validation
210
+ if not symbols_input.strip():
211
+ return (
212
+ None,
213
+ "Error: Please enter at least one stock symbol",
214
+ "Error: No symbols provided"
215
+ )
216
+
217
+ symbols = [s.strip() for s in symbols_input.split(',') if s.strip()]
218
+
219
+ # Collect and analyze market data
220
+ try:
221
+ market_data = market_system.collect_market_data(symbols)
222
+ except Exception as e:
223
+ return (
224
+ None,
225
+ f"Error collecting market data: {str(e)}",
226
+ "Unable to analyze trends"
227
+ )
228
+
229
+ # Generate visualization
230
+ try:
231
+ fig = market_system.generate_market_visualization(market_data)
232
+ except Exception as e:
233
+ fig = None
234
+ print(f"Error generating visualization: {e}")
235
+
236
+ # Get sentiment scores
237
+ sentiments = {symbol: market_system.analyze_sentiment(symbol) for symbol in symbols}
238
+ sentiment_text = "\n".join([f"{symbol}: {score:.2f}" for symbol, score in sentiments.items()])
239
+
240
+ # Predict trends
241
+ try:
242
+ predictions = market_system.predict_market_trends()
243
+ prediction_text = "\n".join([
244
+ f"{symbol}: {'Upward' if pred > 0 else 'Downward'} trend"
245
+ for symbol, pred in zip(symbols, predictions[:len(symbols)])
246
+ ])
247
+ except Exception as e:
248
+ prediction_text = f"Error predicting trends: {str(e)}"
249
+
250
+ return fig, sentiment_text, prediction_text
251
+
252
+ except Exception as e:
253
+ return None, f"Error: {str(e)}", "Analysis failed"
254
+
255
+ interface = gr.Interface(
256
+ fn=analyze_markets,
257
+ inputs=gr.Textbox(
258
+ label="Enter stock symbols (comma-separated)",
259
+ value="AAPL,GOOGL,MSFT"
260
+ ),
261
+ outputs=[
262
+ gr.Plot(label="Market Trends"),
263
+ gr.Textbox(label="Sentiment Analysis"),
264
+ gr.Textbox(label="Trend Predictions")
265
+ ],
266
+ title="Real-Time Market Analysis System",
267
+ description="Enter stock symbols to analyze market trends, sentiment, and predictions."
268
+ )
269
+
270
+ return interface
271
+
272
+ if __name__ == "__main__":
273
+ # Ensure all required packages are imported
274
+ required_packages = {
275
+ 'numpy': np,
276
+ 'pandas': pd,
277
+ 'yfinance': yf,
278
+ 'torch': torch,
279
+ 'transformers': pipeline,
280
+ 'gradio': gr,
281
+ 'plotly': go
282
+ }
283
+
284
+ missing_packages = []
285
+ for package, module in required_packages.items():
286
+ if module is None:
287
+ missing_packages.append(package)
288
+
289
+ if missing_packages:
290
+ print(f"Missing required packages: {', '.join(missing_packages)}")
291
+ print("Please install them using pip:")
292
+ print(f"pip install {' '.join(missing_packages)}")
293
+ else:
294
+ interface = create_gradio_interface()
295
+ interface.launch(debug=True)