"""EEG preprocessing and feature extraction built on MNE-Python."""

import logging
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import mne
import numpy as np
from scipy import signal
from scipy import stats

logger = logging.getLogger(__name__)


class EEGProcessor:
    """Preprocess EEG recordings and derive spectral, connectivity and
    statistical features from them.

    Attributes:
        sfreq: Fallback sampling frequency in Hz, used only when a caller
            does not supply the recording's own rate.
        freq_bands: Mapping of band name to ``(fmin, fmax)`` in Hz.
    """

    def __init__(self):
        self.sfreq = 250  # Default sampling frequency
        self.freq_bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 50),
        }

    def preprocess(self, raw: mne.io.Raw) -> mne.io.Raw:
        """Return a filtered, artifact-reduced copy of ``raw``.

        The pipeline is: set a standard 10-20 montage if none is present,
        band-pass 0.5-50 Hz, notch at 50 Hz, interpolate channels already
        marked bad, then remove ICA components that correlate with EOG.

        Args:
            raw: Preloaded recording. It is not modified; every step runs on
                a copy.

        Returns:
            The processed copy of ``raw``.
        """
        # Copy first so that neither the montage nor the filters touch the
        # caller's object.
        raw_processed = raw.copy()

        # Set montage if not present
        if raw_processed.get_montage() is None:
            raw_processed.set_montage('standard_1020')

        # Filter data
        raw_processed.filter(l_freq=0.5, h_freq=50.0)
        raw_processed.notch_filter(freqs=50)  # Remove power line noise

        # Interpolate channels listed in info['bads']; no automatic bad
        # channel detection is performed here.
        raw_processed.interpolate_bads()

        # find_bads_eog needs at least one good EOG channel as reference and
        # raises otherwise, so only fit ICA when such a channel exists.
        eog_picks = mne.pick_types(
            raw_processed.info, meg=False, eog=True, exclude='bads'
        )
        if len(eog_picks) == 0:
            logger.warning(
                "No EOG channel available; skipping ICA-based blink removal."
            )
            return raw_processed

        # Apply ICA for artifact removal
        ica = mne.preprocessing.ICA(n_components=0.95, random_state=42)
        ica.fit(raw_processed)

        # Detect and remove eye blinks
        eog_indices, _ = ica.find_bads_eog(raw_processed)
        if eog_indices:
            ica.exclude = eog_indices
            ica.apply(raw_processed)

        return raw_processed

    def extract_features(self, raw: mne.io.Raw) -> Dict:
        """Extract features from preprocessed EEG data.

        Args:
            raw: Preprocessed recording.

        Returns:
            Dict with the keys ``'band_powers'``, ``'connectivity'`` and
            ``'statistics'``.
        """
        sfreq = raw.info['sfreq']

        # NOTE(review): get_data() uses its default channel selection while
        # the PSD uses MNE's default picks, so channel counts may differ
        # between the feature groups - confirm this is intended.
        data = raw.get_data()

        # 4 s Welch windows with 50 % overlap, clamped so that recordings
        # shorter than 4 s still produce a spectrum.
        n_fft = min(int(sfreq * 4), data.shape[1])
        n_overlap = n_fft // 2
        psds, freqs = self._compute_welch_psd(raw, n_fft, n_overlap)

        return {
            'band_powers': self._calculate_band_powers(psds, freqs),
            'connectivity': self._calculate_connectivity(data, sfreq=sfreq),
            'statistics': self._calculate_statistics(data),
        }

    def _compute_welch_psd(
        self, raw: mne.io.Raw, n_fft: int, n_overlap: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the 0.5-50 Hz Welch PSD of ``raw`` as ``(psds, freqs)``."""
        if hasattr(raw, 'compute_psd'):
            # Current MNE API; the module-level psd_welch was removed.
            spectrum = raw.compute_psd(
                method='welch',
                fmin=0.5,
                fmax=50.0,
                n_fft=n_fft,
                n_overlap=n_overlap,
            )
            return spectrum.get_data(return_freqs=True)

        # Older MNE releases only provide the function-based API.
        return mne.time_frequency.psd_welch(
            raw,
            fmin=0.5,
            fmax=50.0,
            n_fft=n_fft,
            n_overlap=n_overlap,
        )

    def _calculate_band_powers(self, psds: np.ndarray, freqs: np.ndarray) -> Dict:
        """Calculate the mean power per channel in each frequency band.

        Band edges are inclusive on both sides, so a frequency bin lying
        exactly on a shared edge contributes to both neighbouring bands.

        Args:
            psds: Power spectral densities, indexed ``[channel, frequency]``.
            freqs: Frequencies in Hz matching the last axis of ``psds``.

        Returns:
            Dict mapping band name to an array with one value per channel.
        """
        band_powers = {}
        for band_name, (fmin, fmax) in self.freq_bands.items():
            # Find frequencies that fall within band
            freq_mask = (freqs >= fmin) & (freqs <= fmax)
            # Calculate average power in band
            band_powers[band_name] = np.mean(psds[:, freq_mask], axis=1)
        return band_powers

    def _calculate_connectivity(
        self, data: np.ndarray, sfreq: Optional[float] = None
    ) -> Dict:
        """Calculate connectivity metrics between channels.

        Args:
            data: Signals indexed ``[channel, sample]``.
            sfreq: Sampling frequency in Hz passed to the coherence
                estimator; defaults to ``self.sfreq`` when omitted.

        Returns:
            Dict with ``'correlation'`` (``np.corrcoef`` of the channels) and
            ``'coherence'`` (symmetric matrix of frequency-averaged
            magnitude-squared coherence; the diagonal is left at zero).
        """
        fs = self.sfreq if sfreq is None else sfreq
        n_channels = data.shape[0]
        coherence = np.zeros((n_channels, n_channels))

        # Coherence is symmetric, so evaluate the upper triangle only and
        # mirror each value.
        for i, j in zip(*np.triu_indices(n_channels, k=1)):
            _, coh = signal.coherence(data[i], data[j], fs=fs)
            coherence[i, j] = coherence[j, i] = np.mean(coh)

        return {
            'correlation': np.corrcoef(data),
            'coherence': coherence,
        }

    def _calculate_statistics(self, data: np.ndarray) -> Dict:
        """Calculate statistical features for each channel.

        Args:
            data: Signals indexed ``[channel, sample]``.

        Returns:
            Dict with ``'mean'``, ``'std'``, ``'skewness'``, ``'kurtosis'``
            and ``'hjorth'`` entries.
        """
        return {
            'mean': np.mean(data, axis=1),
            'std': np.std(data, axis=1),
            'skewness': self._calculate_skewness(data),
            'kurtosis': self._calculate_kurtosis(data),
            'hjorth': self._calculate_hjorth_parameters(data),
        }

    def _calculate_skewness(self, data: np.ndarray) -> np.ndarray:
        """Calculate skewness for each channel (row) of ``data``."""
        # skew lives in scipy.stats; scipy.signal has no such function.
        return np.asarray(stats.skew(data, axis=1))

    def _calculate_kurtosis(self, data: np.ndarray) -> np.ndarray:
        """Calculate kurtosis for each channel (row) of ``data``.

        Uses SciPy's default (Fisher) definition.
        """
        # kurtosis lives in scipy.stats; scipy.signal has no such function.
        return np.asarray(stats.kurtosis(data, axis=1))

    def _calculate_hjorth_parameters(self, data: np.ndarray) -> Dict:
        """Calculate Hjorth parameters (activity, mobility, complexity).

        Args:
            data: Signals indexed ``[channel, sample]``.

        Returns:
            Dict with ``'activity'``, ``'mobility'`` and ``'complexity'``
            arrays, one value per channel. Channels with zero variance yield
            NaN or inf in the ratio-based entries.
        """
        activity = np.var(data, axis=1)

        # Variances of the first and second discrete derivatives
        diff1 = np.diff(data, axis=1)
        diff2 = np.diff(diff1, axis=1)
        var_diff1 = np.var(diff1, axis=1)
        var_diff2 = np.var(diff2, axis=1)

        # Flat channels divide by zero; keep the NaN/inf results but do not
        # emit a floating-point warning for each of them.
        with np.errstate(divide='ignore', invalid='ignore'):
            mobility = np.sqrt(var_diff1 / activity)
            complexity = np.sqrt(var_diff2 / var_diff1) / mobility

        return {
            'activity': activity,
            'mobility': mobility,
            'complexity': complexity,
        }

    def process_file(self, file_path: str) -> Dict:
        """Process an EEG file and extract features.

        Args:
            file_path: Path to the EEG file (EDF, BDF, or CNT format)

        Returns:
            Dict containing processed EEG data and features
        """
        # Choose the MNE reader from the file extension. Unknown extensions
        # fall through to the EDF reader, as before.
        readers = {
            '.edf': mne.io.read_raw_edf,
            '.bdf': mne.io.read_raw_bdf,
            '.cnt': mne.io.read_raw_cnt,
        }
        suffix = Path(file_path).suffix.lower()
        reader = readers.get(suffix, mne.io.read_raw_edf)

        # Load EEG file using MNE
        raw = reader(file_path, preload=True)

        # Preprocess the data
        raw_processed = self.preprocess(raw)

        # Extract features
        features = self.extract_features(raw_processed)

        return {
            'raw_data': raw_processed,
            'features': features,
        }