import gradio as gr
import numpy as np
import mne
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from scipy import signal
from scipy.stats import skew, kurtosis
import pandas as pd
from pathlib import Path
import tempfile
import os
import logging
from datetime import datetime

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# NumPy 2.0 renamed ``trapz`` to ``trapezoid``; pick whichever this NumPy has.
# The ``or`` short-circuits, so ``np.trapz`` is only touched on older NumPy.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


class EEGProcessor:
    """Load an EDF recording, filter it and compute per-channel features.

    Attributes:
        sampling_rate: Sampling frequency in Hz. Starts at 250 and is
            overwritten with the recording's own rate by ``load_eeg`` and
            ``extract_features``.
        freq_bands: Mapping of band name to ``(low_hz, high_hz)`` limits used
            for the band-power features.
    """

    def __init__(self):
        self.sampling_rate = 250
        self.freq_bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 50)
        }
        logger.info("EEG Processor initialized")

    def load_eeg(self, file_path):
        """Load and validate EEG data from an EDF file.

        Args:
            file_path: Path of the ``.edf`` file to read.

        Returns:
            The object returned by ``mne.io.read_raw_edf`` with
            ``preload=True``.

        Raises:
            ValueError: If the file does not exist or cannot be read. The
                original exception is attached as ``__cause__``.
        """
        try:
            if not os.path.exists(file_path):
                raise ValueError("File does not exist")

            raw = mne.io.read_raw_edf(file_path, preload=True)
            self.sampling_rate = raw.info['sfreq']
            logger.info(
                "Loaded EEG with %d channels at %s Hz",
                len(raw.ch_names), self.sampling_rate,
            )
            return raw
        except Exception as e:
            logger.error("Error loading EEG file: %s", e)
            raise ValueError(f"Error loading EEG file: {str(e)}") from e

    def preprocess_signal(self, raw):
        """Apply band-pass and mains-notch filtering to the EEG signal.

        A 0.5 Hz high-pass is always applied. The 50 Hz low-pass and the
        50/60 Hz notch filters are applied only when they fit below the
        recording's Nyquist frequency, so low-rate recordings (for example
        100 Hz) no longer make the filter calls fail.

        Args:
            raw: Recording object exposing ``info['sfreq']``, ``filter`` and
                ``notch_filter`` (as returned by ``load_eeg``).

        Returns:
            The same ``raw`` object that was passed in.
        """
        try:
            logger.info("Starting signal preprocessing")
            nyquist = raw.info['sfreq'] / 2.0

            # A low-pass at or above Nyquist is invalid; there is no content
            # up there anyway, so fall back to high-pass only.
            high_cut = 50.0 if nyquist > 50.0 else None
            raw.filter(l_freq=0.5, h_freq=high_cut, fir_design='firwin')

            # Keep a 1 Hz margin so the notch's transition band stays below
            # Nyquist.
            notch_freqs = [f for f in (50, 60) if f < nyquist - 1.0]
            if notch_freqs:
                raw.notch_filter(freqs=notch_freqs)
            else:
                logger.info(
                    "Skipping notch filter: sampling rate %s Hz is too low",
                    raw.info['sfreq'],
                )
            return raw
        except Exception as e:
            logger.error("Error in preprocessing: %s", e)
            raise

    def _calculate_band_power(self, data, low_freq, high_freq):
        """Calculate power in a specific frequency band for each channel.

        The power spectral density is estimated with ``scipy.signal.welch``
        at ``self.sampling_rate`` and integrated over the inclusive range
        ``[low_freq, high_freq]`` with the trapezoidal rule.

        Args:
            data: Array of samples, one row per channel. A 1-D array is
                treated as a single channel.
            low_freq: Lower band edge in Hz (inclusive).
            high_freq: Upper band edge in Hz (inclusive).

        Returns:
            1-D array holding one band-power value per channel.
        """
        data = np.atleast_2d(data)
        freqs, psd = signal.welch(data, fs=self.sampling_rate)
        in_band = np.logical_and(freqs >= low_freq, freqs <= high_freq)
        return _trapezoid(psd[:, in_band], freqs[in_band], axis=1)

    def extract_features(self, raw):
        """Extract time- and frequency-domain features for every channel.

        Args:
            raw: Recording object exposing ``get_data()``, ``ch_names`` and
                ``info['sfreq']``.

        Returns:
            dict with keys ``'channel_names'``, ``'n_channels'``, ``'mean'``,
            ``'variance'``, ``'skewness'``, ``'kurtosis'`` (each statistic is
            an array with one entry per channel) and ``'band_powers'`` (a dict
            mapping each name in ``self.freq_bands`` to a per-channel array).
        """
        try:
            logger.info("Starting feature extraction")
            data = raw.get_data()
            n_channels = data.shape[0]

            # Use the rate of the recording actually being analysed, in case
            # it was not loaded through ``load_eeg`` on this instance.
            self.sampling_rate = raw.info['sfreq']

            # Initialize feature dictionary with channel names
            features = {
                'channel_names': raw.ch_names,
                'n_channels': n_channels
            }

            # Calculate time domain features for each channel
            features['mean'] = np.mean(data, axis=1)
            features['variance'] = np.var(data, axis=1)
            features['skewness'] = skew(data, axis=1)
            features['kurtosis'] = kurtosis(data, axis=1)

            # Calculate frequency domain features for each channel
            features['band_powers'] = {
                band_name: self._calculate_band_power(data, low_freq, high_freq)
                for band_name, (low_freq, high_freq) in self.freq_bands.items()
            }

            logger.info("Feature extraction completed")
            return features
        except Exception as e:
            logger.error("Error in feature extraction: %s", e)
            raise


def create_visualization(raw, features):
    """Create the three-panel summary figure with multi-channel support.

    Panels, top to bottom: the first 10 seconds of every channel stacked
    with a vertical offset, the Welch power spectrum of every channel, and a
    bar chart of mean band power across channels (error bars show the
    standard deviation across channels).

    The figure is built as a standalone ``matplotlib.figure.Figure`` rather
    than through ``pyplot``, so it is not registered in pyplot's global
    figure list and is released once the caller drops it.

    Args:
        raw: Recording object exposing ``get_data()`` and ``info['sfreq']``.
        features: Feature dict as produced by
            ``EEGProcessor.extract_features``; the keys ``'n_channels'``,
            ``'channel_names'`` and ``'band_powers'`` are read.

    Returns:
        The populated matplotlib ``Figure``.
    """
    try:
        fig = Figure(figsize=(15, 20))
        ax1, ax2, ax3 = fig.subplots(3, 1)

        n_channels = features['n_channels']
        channel_names = features['channel_names']
        sfreq = raw.info['sfreq']
        data = raw.get_data()

        # Plot 1: Raw EEG signal (first 10 seconds)
        n_samples = min(data.shape[1], int(10 * sfreq))
        times = np.arange(n_samples) / sfreq
        plot_data = data[:, :n_samples]

        # Add offset to separate channels
        offsets = np.arange(n_channels) * np.std(plot_data) * 3
        for offset, trace, name in zip(offsets, plot_data, channel_names):
            ax1.plot(times, trace + offset, label=name)

        ax1.set_title('Raw EEG Signal (first 10 seconds)')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('Channel')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Plot 2: Power spectrum
        freqs, psd = signal.welch(data, fs=sfreq)
        for _, channel_psd, name in zip(range(n_channels), psd, channel_names):
            ax2.semilogy(freqs, channel_psd, label=name)

        ax2.set_title('Power Spectrum')
        ax2.set_xlabel('Frequency (Hz)')
        ax2.set_ylabel('Power Spectral Density')
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Plot 3: Average band powers across channels
        band_powers = features['band_powers']
        bands = list(band_powers)
        means = [np.mean(band_powers[band]) for band in bands]
        stds = [np.std(band_powers[band]) for band in bands]

        ax3.bar(bands, means, yerr=stds, capsize=5)
        ax3.set_title('Average Band Powers Across Channels')
        ax3.set_xlabel('Frequency Band')
        ax3.set_ylabel('Power')

        fig.tight_layout()
        return fig
    except Exception as e:
        logger.error("Error creating visualization: %s", e)
        raise


def process_eeg(file_obj):
    """Main processing function for the Gradio interface.

    Writes the uploaded bytes to a temporary ``.edf`` file, runs loading,
    preprocessing, feature extraction and plotting, and always removes the
    temporary file afterwards, including when a step fails.

    Args:
        file_obj: The upload, either as ``bytes`` or as an object with a
            ``read()`` method. ``None`` (nothing uploaded) is reported as an
            error.

    Returns:
        Tuple ``(figure, summary)``: the matplotlib figure from
        ``create_visualization`` and the per-channel feature table rendered
        with ``DataFrame.to_string()``.

    Raises:
        gr.Error: For any failure, with the underlying message appended.
    """
    tmp_path = None
    try:
        logger.info("Starting EEG processing")

        if file_obj is None:
            raise ValueError("No file uploaded")

        # Handle file content
        if isinstance(file_obj, (bytes, bytearray)):
            file_content = file_obj
        else:
            file_content = file_obj.read()

        # Create temporary file; record its path before writing so it is
        # cleaned up even if the write itself fails.
        with tempfile.NamedTemporaryFile(suffix='.edf', delete=False) as tmp_file:
            tmp_path = tmp_file.name
            tmp_file.write(file_content)

        # Process EEG
        processor = EEGProcessor()
        raw = processor.load_eeg(tmp_path)
        raw = processor.preprocess_signal(raw)
        features = processor.extract_features(raw)

        # Create visualizations
        fig = create_visualization(raw, features)

        # Create feature summary table
        feature_summary = pd.DataFrame({
            'Channel': features['channel_names'],
            'Mean': features['mean'],
            'Variance': features['variance'],
            'Skewness': features['skewness'],
            'Kurtosis': features['kurtosis']
        })

        # Add band powers to summary
        for band, powers in features['band_powers'].items():
            feature_summary[f'{band}_power'] = powers

        logger.info("Processing completed successfully")
        return fig, feature_summary.to_string()

    except Exception as e:
        logger.error("Error in EEG processing: %s", e)
        raise gr.Error(f"Error processing EEG: {str(e)}") from e
    finally:
        # Clean up on success and on failure alike.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as cleanup_error:
                logger.warning(
                    "Could not remove temporary file %s: %s",
                    tmp_path, cleanup_error,
                )


# Create Gradio interface
with gr.Blocks(title="EEG Signal Analysis") as iface:
    gr.Markdown("# EEG Signal Analysis Tool")
    gr.Markdown("Upload an EEG file (.edf format) for analysis")

    with gr.Row():
        file_input = gr.File(
            label="Upload EEG File",
            file_types=[".edf"],
            type="binary"
        )

    with gr.Row():
        analyze_btn = gr.Button("Analyze EEG")

    with gr.Row():
        plot_output = gr.Plot(label="EEG Analysis Plots")
        feature_output = gr.Textbox(label="Feature Analysis", lines=10)

    analyze_btn.click(
        fn=process_eeg,
        inputs=[file_input],
        outputs=[plot_output, feature_output]
    )

# Launch the application
if __name__ == "__main__":
    iface.launch(share=True)