invincible-jha commited on
Commit
baaf109
1 Parent(s): 0d41e29

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +284 -0
app.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import required libraries
2
+ import gradio as gr
3
+ import numpy as np
4
+ import mne
5
+ import matplotlib.pyplot as plt
6
+ from scipy import signal
7
+ from scipy.stats import skew, kurtosis
8
+ import pandas as pd
9
+ from pathlib import Path
10
+ import tempfile
11
+ import os
12
+ from huggingface_hub import HfApi, HfFolder
13
+ import warnings
14
+ import logging
15
+
16
+ # Configure logging system
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Initialize Hugging Face API with token from environment variable
21
+ def initialize_hf_api():
22
+ try:
23
+ hf_token = os.getenv('HF_TOKEN')
24
+ if hf_token:
25
+ hf_api = HfApi(token=hf_token)
26
+ HfFolder.save_token(hf_token)
27
+ logger.info("Successfully initialized Hugging Face API")
28
+ return hf_api
29
+ else:
30
+ logger.warning("HF_TOKEN not found in environment variables")
31
+ return None
32
+ except Exception as e:
33
+ logger.error(f"Error initializing Hugging Face API: {str(e)}")
34
+ return None
35
+
36
+ class EEGProcessor:
37
+ """Class for processing and analyzing EEG signals"""
38
+
39
+ def __init__(self):
40
+ """Initialize the EEG processor with default parameters"""
41
+ self.sampling_rate = 250 # Default sampling rate in Hz
42
+ self.freq_bands = {
43
+ 'delta': (0.5, 4), # Delta band (0.5-4 Hz)
44
+ 'theta': (4, 8), # Theta band (4-8 Hz)
45
+ 'alpha': (8, 13), # Alpha band (8-13 Hz)
46
+ 'beta': (13, 30), # Beta band (13-30 Hz)
47
+ 'gamma': (30, 50) # Gamma band (30-50 Hz)
48
+ }
49
+
50
+ def load_eeg(self, file_path):
51
+ """
52
+ Load and validate EEG data from file
53
+ Args:
54
+ file_path: Path to the EEG file
55
+ Returns:
56
+ mne.io.Raw: Loaded EEG data
57
+ """
58
+ try:
59
+ # Validate file existence and size
60
+ if not os.path.exists(file_path):
61
+ raise ValueError("File does not exist")
62
+
63
+ # Check file size (100MB limit)
64
+ if os.path.getsize(file_path) > 100 * 1024 * 1024:
65
+ raise ValueError("File size exceeds 100MB limit")
66
+
67
+ # Load EEG data using MNE
68
+ raw = mne.io.read_raw_edf(file_path, preload=True)
69
+ self.sampling_rate = raw.info['sfreq']
70
+ logger.info(f"Successfully loaded EEG file: {file_path}")
71
+ return raw
72
+
73
+ except Exception as e:
74
+ logger.error(f"Error loading EEG file: {str(e)}")
75
+ raise ValueError(f"Error loading EEG file: {str(e)}")
76
+
77
+ def preprocess_signal(self, raw):
78
+ """
79
+ Apply preprocessing steps to the EEG signal
80
+ Args:
81
+ raw: Raw EEG data
82
+ Returns:
83
+ mne.io.Raw: Preprocessed EEG data
84
+ """
85
+ try:
86
+ logger.info("Starting signal preprocessing")
87
+
88
+ # Apply bandpass filter (0.5-50 Hz)
89
+ raw.filter(l_freq=0.5, h_freq=50., fir_design='firwin')
90
+ logger.info("Applied bandpass filter")
91
+
92
+ # Remove power line interference
93
+ raw.notch_filter(freqs=[50, 60])
94
+ logger.info("Applied notch filter")
95
+
96
+ return raw
97
+
98
+ except Exception as e:
99
+ logger.error(f"Error in signal preprocessing: {str(e)}")
100
+ raise
101
+
102
+ def extract_features(self, raw):
103
+ """
104
+ Extract time and frequency domain features from EEG data
105
+ Args:
106
+ raw: Preprocessed EEG data
107
+ Returns:
108
+ dict: Extracted features
109
+ """
110
+ try:
111
+ logger.info("Starting feature extraction")
112
+ data = raw.get_data()
113
+ features = {}
114
+
115
+ # Calculate time domain features
116
+ features['mean'] = np.mean(data, axis=1)
117
+ features['variance'] = np.var(data, axis=1)
118
+ features['skewness'] = skew(data, axis=1)
119
+ features['kurtosis'] = kurtosis(data, axis=1)
120
+
121
+ # Calculate frequency domain features
122
+ for band_name, (low_freq, high_freq) in self.freq_bands.items():
123
+ band_power = self._calculate_band_power(data, low_freq, high_freq)
124
+ features[f'{band_name}_power'] = band_power
125
+
126
+ logger.info("Feature extraction completed successfully")
127
+ return features
128
+
129
+ except Exception as e:
130
+ logger.error(f"Error in feature extraction: {str(e)}")
131
+ raise
132
+
133
+ def _calculate_band_power(self, data, low_freq, high_freq):
134
+ """
135
+ Calculate power in specific frequency band
136
+ Args:
137
+ data: EEG data
138
+ low_freq: Lower frequency bound
139
+ high_freq: Upper frequency bound
140
+ Returns:
141
+ float: Band power value
142
+ """
143
+ try:
144
+ # Calculate power spectral density
145
+ freqs, psd = signal.welch(data, fs=self.sampling_rate)
146
+
147
+ # Find frequencies within band
148
+ idx = np.logical_and(freqs >= low_freq, freqs <= high_freq)
149
+
150
+ # Calculate band power using trapezoidal integration
151
+ band_power = np.trapz(psd[:, idx], freqs[idx])
152
+ return band_power
153
+
154
+ except Exception as e:
155
+ logger.error(f"Error calculating band power: {str(e)}")
156
+ raise
157
+
158
+ def create_visualization(raw, features):
159
+ """
160
+ Create visualization plots for EEG analysis
161
+ Args:
162
+ raw: EEG data
163
+ features: Extracted features
164
+ Returns:
165
+ matplotlib.figure.Figure: Figure containing plots
166
+ """
167
+ try:
168
+ # Create figure with three subplots
169
+ fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 15))
170
+
171
+ # Plot 1: Raw EEG signal
172
+ data = raw.get_data()
173
+ times = np.arange(data.shape[1]) / raw.info['sfreq']
174
+ ax1.plot(times, data.T)
175
+ ax1.set_title('Raw EEG Signal')
176
+ ax1.set_xlabel('Time (s)')
177
+ ax1.set_ylabel('Amplitude')
178
+
179
+ # Plot 2: Power spectrum
180
+ freqs, psd = signal.welch(data, fs=raw.info['sfreq'])
181
+ ax2.semilogy(freqs, psd.T)
182
+ ax2.set_title('Power Spectrum')
183
+ ax2.set_xlabel('Frequency (Hz)')
184
+ ax2.set_ylabel('Power Spectral Density')
185
+
186
+ # Plot 3: Band powers
187
+ band_powers = {k: v for k, v in features.items() if 'power' in k}
188
+ ax3.bar(band_powers.keys(), band_powers.values())
189
+ ax3.set_title('Band Powers')
190
+ ax3.set_xlabel('Frequency Band')
191
+ ax3.set_ylabel('Power')
192
+
193
+ plt.tight_layout()
194
+ return fig
195
+
196
+ except Exception as e:
197
+ logger.error(f"Error creating visualization: {str(e)}")
198
+ raise
199
+
200
+ def process_eeg(file):
201
+ """
202
+ Main processing function for the Gradio interface
203
+ Args:
204
+ file: Uploaded file object
205
+ Returns:
206
+ tuple: (matplotlib figure, feature analysis string)
207
+ """
208
+ try:
209
+ logger.info("Starting EEG processing")
210
+
211
+ # Create temporary file for processing
212
+ with tempfile.NamedTemporaryFile(suffix='.edf', delete=False) as tmp_file:
213
+ tmp_file.write(file.read())
214
+ tmp_path = tmp_file.name
215
+
216
+ logger.info("Temporary file created")
217
+
218
+ # Initialize processor and process EEG
219
+ processor = EEGProcessor()
220
+ raw = processor.load_eeg(tmp_path)
221
+ raw = processor.preprocess_signal(raw)
222
+ features = processor.extract_features(raw)
223
+
224
+ # Create visualizations and feature summary
225
+ fig = create_visualization(raw, features)
226
+ feature_df = pd.DataFrame({k: [v] for k, v in features.items()})
227
+
228
+ # Clean up temporary file
229
+ os.unlink(tmp_path)
230
+ logger.info("Processing completed successfully")
231
+
232
+ return fig, feature_df.to_string()
233
+
234
+ except Exception as e:
235
+ logger.error(f"Error in EEG processing: {str(e)}")
236
+ raise gr.Error(f"Error processing EEG: {str(e)}")
237
+
238
+ # Create Gradio interface
239
+ def create_interface():
240
+ """Create and configure the Gradio interface"""
241
+ try:
242
+ # Initialize Hugging Face API
243
+ initialize_hf_api()
244
+
245
+ # Create Gradio blocks interface
246
+ with gr.Blocks(title="EEG Signal Analysis") as iface:
247
+ gr.Markdown("# EEG Signal Analysis Tool")
248
+ gr.Markdown("Upload an EEG file (.edf format) for analysis")
249
+
250
+ with gr.Row():
251
+ file_input = gr.File(
252
+ label="Upload EEG File",
253
+ file_types=[".edf"],
254
+ type="binary"
255
+ )
256
+
257
+ with gr.Row():
258
+ analyze_btn = gr.Button("Analyze EEG")
259
+
260
+ with gr.Row():
261
+ plot_output = gr.Plot(label="EEG Analysis Plots")
262
+ feature_output = gr.Textbox(label="Feature Analysis", lines=10)
263
+
264
+ # Set up button click event
265
+ analyze_btn.click(
266
+ fn=process_eeg,
267
+ inputs=[file_input],
268
+ outputs=[plot_output, feature_output]
269
+ )
270
+
271
+ return iface
272
+
273
+ except Exception as e:
274
+ logger.error(f"Error creating interface: {str(e)}")
275
+ raise
276
+
277
+ # Launch the application
278
+ if __name__ == "__main__":
279
+ try:
280
+ iface = create_interface()
281
+ iface.launch()
282
+ except Exception as e:
283
+ logger.error(f"Application startup failed: {str(e)}")
284
+ raise