File size: 8,178 Bytes
baaf109
 
 
 
 
 
 
 
 
 
 
980d5ca
baaf109
1fb8a6d
980d5ca
 
 
 
 
baaf109
980d5ca
baaf109
 
 
1fb8a6d
baaf109
a080e03
 
 
 
 
baaf109
980d5ca
baaf109
 
980d5ca
baaf109
 
 
 
 
 
1fb8a6d
baaf109
 
 
 
 
 
 
980d5ca
baaf109
 
 
 
 
 
1fb8a6d
baaf109
 
1fb8a6d
 
 
 
 
 
 
 
 
 
baaf109
980d5ca
baaf109
 
 
1fb8a6d
 
 
 
 
 
 
baaf109
1fb8a6d
baaf109
 
 
 
 
1fb8a6d
 
baaf109
1fb8a6d
baaf109
1fb8a6d
980d5ca
baaf109
1fb8a6d
baaf109
 
 
 
 
1fb8a6d
baaf109
1fb8a6d
 
 
 
 
baaf109
1fb8a6d
 
 
 
 
 
 
980d5ca
1fb8a6d
baaf109
1fb8a6d
 
baaf109
1fb8a6d
 
baaf109
1fb8a6d
 
 
baaf109
 
 
1fb8a6d
 
 
 
 
 
 
 
 
baaf109
1fb8a6d
 
 
 
 
 
baaf109
 
 
 
 
1fb8a6d
baaf109
 
 
 
a080e03
980d5ca
baaf109
 
 
1fb8a6d
a080e03
 
 
 
 
1fb8a6d
baaf109
1fb8a6d
baaf109
 
980d5ca
baaf109
 
 
 
 
980d5ca
baaf109
1fb8a6d
 
 
 
 
 
 
 
 
 
 
 
 
baaf109
980d5ca
baaf109
 
 
1fb8a6d
baaf109
 
 
 
 
 
a080e03
 
 
 
 
 
 
 
1fb8a6d
a080e03
 
 
 
 
 
 
 
 
 
 
 
 
 
baaf109
 
 
1fb8a6d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import gradio as gr
import numpy as np
import mne
import matplotlib.pyplot as plt
from scipy import signal
from scipy.stats import skew, kurtosis
import pandas as pd
from pathlib import Path
import tempfile
import os
import logging
from datetime import datetime

# Root logging for the whole app: INFO and above, each line prefixed with a
# second-resolution timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)

# Module-level logger used by the processor, plotting and Gradio handler.
logger = logging.getLogger(__name__)

class EEGProcessor:
    """Load, preprocess and extract features from EDF EEG recordings.

    Attributes:
        sampling_rate: Sampling frequency in Hz. Initialised to 250 and
            overwritten by ``load_eeg`` with the loaded file's rate; used by
            ``_calculate_band_power`` when no explicit ``fs`` is passed.
        freq_bands: Band name -> ``(low_hz, high_hz)``. Both edges are
            inclusive when selecting PSD frequency bins.
    """

    def __init__(self):
        self.sampling_rate = 250
        # Frequency bands (Hz) for which band-power features are computed.
        self.freq_bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 50)
        }
        logger.info("EEG Processor initialized")

    def load_eeg(self, file_path):
        """Load an EDF file into memory and record its sampling rate.

        Args:
            file_path: Path to an ``.edf`` file.

        Returns:
            The object returned by ``mne.io.read_raw_edf(..., preload=True)``.

        Raises:
            ValueError: If the file does not exist or MNE fails to read it.
                The underlying exception is chained as ``__cause__``.
        """
        try:
            if not os.path.exists(file_path):
                raise ValueError("File does not exist")

            raw = mne.io.read_raw_edf(file_path, preload=True)
            self.sampling_rate = raw.info['sfreq']
            logger.info(f"Loaded EEG with {len(raw.ch_names)} channels at {self.sampling_rate} Hz")
            return raw

        except Exception as e:
            logger.error(f"Error loading EEG file: {str(e)}")
            raise ValueError(f"Error loading EEG file: {str(e)}") from e

    @staticmethod
    def _nyquist_safe_filter_params(sfreq, h_freq=50., notch_freqs=(50, 60)):
        """Restrict filter frequencies to those below the Nyquist frequency.

        A low-pass cutoff or notch frequency at or above ``sfreq / 2`` is not
        a valid digital-filter frequency, so it is dropped rather than passed
        to MNE. This matters for low-rate recordings (e.g. 100 Hz).

        Args:
            sfreq: Sampling frequency in Hz.
            h_freq: Desired low-pass cutoff in Hz (``None`` for no low-pass).
            notch_freqs: Desired notch (line-noise) frequencies in Hz.

        Returns:
            ``(h_freq, notch_freqs)``. ``h_freq`` is ``None`` when the cutoff
            is not below Nyquist; ``notch_freqs`` is a list, possibly empty.
        """
        nyquist = sfreq / 2.
        safe_h_freq = h_freq if h_freq is not None and h_freq < nyquist else None
        safe_notch = [f for f in notch_freqs if f < nyquist]
        return safe_h_freq, safe_notch

    def preprocess_signal(self, raw):
        """Band-pass and notch filter ``raw``.

        Applies a 0.5-50 Hz FIR band-pass and 50/60 Hz notches. Any of those
        frequencies at or above the recording's Nyquist frequency is skipped.
        The MNE filter methods are called on ``raw`` itself, and the same
        object is returned.

        Raises:
            Exception: Whatever MNE raises; it is logged and re-raised.
        """
        try:
            logger.info("Starting signal preprocessing")
            h_freq, notch_freqs = self._nyquist_safe_filter_params(raw.info['sfreq'])
            raw.filter(l_freq=0.5, h_freq=h_freq, fir_design='firwin')
            if notch_freqs:  # nothing to notch when all are above Nyquist
                raw.notch_filter(freqs=notch_freqs)
            return raw
        except Exception as e:
            logger.error(f"Error in preprocessing: {str(e)}")
            raise

    def _calculate_band_power(self, data, low_freq, high_freq, fs=None):
        """Integrate the Welch PSD of ``data`` over ``[low_freq, high_freq]``.

        Args:
            data: Signal array with time on the last axis. A 2-D
                ``(channels, samples)`` array gives one value per channel;
                a 1-D array gives a scalar.
            low_freq: Lower band edge in Hz (inclusive).
            high_freq: Upper band edge in Hz (inclusive).
            fs: Sampling frequency in Hz; defaults to ``self.sampling_rate``.

        Returns:
            Band power with the time axis integrated out.
        """
        if fs is None:
            fs = self.sampling_rate
        freqs, psd = signal.welch(data, fs=fs)
        in_band = (freqs >= low_freq) & (freqs <= high_freq)
        # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return integrate(psd[..., in_band], freqs[in_band], axis=-1)

    def extract_features(self, raw):
        """Compute per-channel time- and frequency-domain features.

        The sampling rate is taken from ``raw.info['sfreq']``, so the result
        does not depend on which file this processor loaded last.

        Returns:
            dict with keys ``channel_names``, ``n_channels``, ``mean``,
            ``variance``, ``skewness``, ``kurtosis`` (one value per channel)
            and ``band_powers`` (band name -> per-channel band power).

        Raises:
            Exception: Any computation error; it is logged and re-raised.
        """
        try:
            logger.info("Starting feature extraction")
            data = raw.get_data()
            sfreq = raw.info['sfreq']

            features = {
                'channel_names': raw.ch_names,
                'n_channels': data.shape[0]
            }

            # Time-domain statistics along the sample axis (one per channel)
            features['mean'] = np.mean(data, axis=1)
            features['variance'] = np.var(data, axis=1)
            features['skewness'] = skew(data, axis=1)
            features['kurtosis'] = kurtosis(data, axis=1)

            # Frequency-domain: integrated PSD per configured band
            features['band_powers'] = {
                band_name: self._calculate_band_power(data, low_freq, high_freq, fs=sfreq)
                for band_name, (low_freq, high_freq) in self.freq_bands.items()
            }
            logger.info("Feature extraction completed")
            return features

        except Exception as e:
            logger.error(f"Error in feature extraction: {str(e)}")
            raise

def create_visualization(raw, features):
    """Build a three-panel matplotlib figure summarising an EEG recording.

    Panels, top to bottom:
      1. Up to the first 10 s of every channel, vertically offset so the
         traces do not overlap.
      2. Welch power spectral density of every channel on a log y-axis.
      3. Mean band power across channels, with the across-channel standard
         deviation as error bars.

    Args:
        raw: Recording exposing ``get_data()`` (channels x samples) and
            ``info['sfreq']``, e.g. an MNE Raw object.
        features: Dict from ``EEGProcessor.extract_features``. Only
            ``n_channels``, ``channel_names`` and ``band_powers`` are read.

    Returns:
        The populated ``matplotlib.figure.Figure``.

    Raises:
        Exception: Any plotting error is logged and re-raised. A partially
            built figure is closed first so it does not stay open in pyplot.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(15, 20))
        sfreq = raw.info['sfreq']
        data = raw.get_data()
        n_channels = features['n_channels']
        channel_names = features['channel_names']

        # Axes are attached to `fig` explicitly. plt.subplot() would act on
        # pyplot's global "current figure", which another concurrent Gradio
        # request may have changed.

        # Plot 1: Raw EEG signal (first 10 seconds)
        ax1 = fig.add_subplot(3, 1, 1)
        n_samples = min(data.shape[1], int(10 * sfreq))
        times = np.arange(n_samples) / sfreq
        plot_data = data[:, :n_samples]

        # Space traces 3 pooled standard deviations apart to separate channels
        offsets = np.arange(n_channels) * np.std(plot_data) * 3
        for i in range(n_channels):
            ax1.plot(times, plot_data[i] + offsets[i], label=channel_names[i])

        ax1.set_title('Raw EEG Signal (first 10 seconds)')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Channel')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Plot 2: Power spectrum over the full recording
        ax2 = fig.add_subplot(3, 1, 2)
        freqs, psd = signal.welch(data, fs=sfreq)
        for i in range(n_channels):
            ax2.semilogy(freqs, psd[i], label=channel_names[i])

        ax2.set_title('Power Spectrum')
        ax2.set_xlabel('Frequency (Hz)')
        ax2.set_ylabel('Power Spectral Density')
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Plot 3: Average band powers across channels (+/- 1 std)
        ax3 = fig.add_subplot(3, 1, 3)
        band_powers = features['band_powers']
        bands = list(band_powers)
        means = [np.mean(band_powers[band]) for band in bands]
        stds = [np.std(band_powers[band]) for band in bands]

        ax3.bar(bands, means, yerr=stds, capsize=5)
        ax3.set_title('Average Band Powers Across Channels')
        ax3.set_xlabel('Frequency Band')
        ax3.set_ylabel('Power')

        fig.tight_layout()
        return fig

    except Exception as e:
        logger.error(f"Error creating visualization: {str(e)}")
        if fig is not None:
            plt.close(fig)  # don't leave a half-drawn figure registered in pyplot
        raise

def process_eeg(file_obj):
    """Main processing function for Gradio interface"""
    try:
        logger.info("Starting EEG processing")
        
        # Handle file content
        if isinstance(file_obj, bytes):
            file_content = file_obj
        else:
            file_content = file_obj.read()
        
        # Create temporary file
        with tempfile.NamedTemporaryFile(suffix='.edf', delete=False) as tmp_file:
            tmp_file.write(file_content)
            tmp_path = tmp_file.name
        
        # Process EEG
        processor = EEGProcessor()
        raw = processor.load_eeg(tmp_path)
        raw = processor.preprocess_signal(raw)
        features = processor.extract_features(raw)
        
        # Create visualizations
        fig = create_visualization(raw, features)
        
        # Create feature summary table
        feature_summary = pd.DataFrame({
            'Channel': features['channel_names'],
            'Mean': features['mean'],
            'Variance': features['variance'],
            'Skewness': features['skewness'],
            'Kurtosis': features['kurtosis']
        })
        
        # Add band powers to summary
        for band, powers in features['band_powers'].items():
            feature_summary[f'{band}_power'] = powers
        
        # Clean up
        os.unlink(tmp_path)
        logger.info("Processing completed successfully")
        
        return fig, feature_summary.to_string()
        
    except Exception as e:
        logger.error(f"Error in EEG processing: {str(e)}")
        raise gr.Error(f"Error processing EEG: {str(e)}")

# Gradio UI: an upload widget, an "Analyze EEG" button, and two outputs
# (the matplotlib figure and the per-channel feature table as text).
with gr.Blocks(title="EEG Signal Analysis") as iface:
    gr.Markdown("# EEG Signal Analysis Tool")
    gr.Markdown("Upload an EEG file (.edf format) for analysis")

    # Upload control restricted to .edf; type="binary" delivers raw bytes.
    with gr.Row():
        file_input = gr.File(
            type="binary",
            file_types=[".edf"],
            label="Upload EEG File",
        )

    # Trigger button.
    with gr.Row():
        analyze_btn = gr.Button("Analyze EEG")

    # Results laid out side by side in one row.
    with gr.Row():
        plot_output = gr.Plot(label="EEG Analysis Plots")
        feature_output = gr.Textbox(lines=10, label="Feature Analysis")

    # Clicking the button runs process_eeg on the upload and fills both outputs.
    analyze_btn.click(
        process_eeg,
        inputs=[file_input],
        outputs=[plot_output, feature_output],
    )

# Start the server only when run as a script; share=True additionally
# requests a public Gradio share link.
if __name__ == "__main__":
    iface.launch(share=True)