# Spaces:
# Build error
# Build error
from typing import Dict, List, Tuple, Optional

import mne
import numpy as np
from scipy import signal
from scipy import stats
class EEGProcessor:
    """Preprocess EEG recordings with MNE and extract per-channel features.

    The public entry points are :meth:`process_file` (load, preprocess and
    extract in one call), :meth:`preprocess` and :meth:`extract_features`.
    """

    def __init__(self):
        self.sfreq = 250  # Default sampling frequency
        # Band name -> (low, high) edges in Hz. Both edges are treated as
        # inclusive by _calculate_band_powers, so a frequency bin that sits
        # exactly on a shared edge contributes to both neighbouring bands.
        self.freq_bands = {
            'delta': (0.5, 4),
            'theta': (4, 8),
            'alpha': (8, 13),
            'beta': (13, 30),
            'gamma': (30, 50),
        }

    def preprocess(self, raw: mne.io.Raw) -> mne.io.Raw:
        """Return a filtered, artifact-reduced copy of ``raw``.

        The input object is left untouched; every step below operates on a
        copy.

        Args:
            raw: Loaded (preloaded) MNE raw recording.

        Returns:
            The processed copy of ``raw``.
        """
        # Copy first so that neither the montage nor the filters leak back
        # into the caller's object.
        raw_processed = raw.copy()

        # Set montage if not present
        if raw_processed.get_montage() is None:
            raw_processed.set_montage('standard_1020')

        # Band-pass to the range covered by self.freq_bands, then notch out
        # 50 Hz power line noise.
        raw_processed.filter(l_freq=0.5, h_freq=50.0)
        raw_processed.notch_filter(freqs=50)

        # Interpolates the channels already marked as bad on the recording;
        # no automatic bad-channel detection happens here.
        raw_processed.interpolate_bads()

        # ICA is only used to find and remove EOG (eye blink) components, and
        # find_bads_eog needs an EOG channel to correlate against. Without
        # one there is nothing to remove, so skip the (expensive) fit.
        if 'eog' in raw_processed.get_channel_types():
            ica = mne.preprocessing.ICA(n_components=0.95, random_state=42)
            ica.fit(raw_processed)
            eog_indices, _eog_scores = ica.find_bads_eog(raw_processed)
            if eog_indices:
                ica.exclude = eog_indices
                ica.apply(raw_processed)

        return raw_processed

    def extract_features(self, raw: mne.io.Raw) -> Dict:
        """Extract relevant features from preprocessed EEG data.

        Args:
            raw: Preprocessed MNE raw recording.

        Returns:
            Dict with the keys ``'band_powers'``, ``'connectivity'`` and
            ``'statistics'`` (see the corresponding ``_calculate_*`` helpers).
        """
        features = {}
        data = raw.get_data()
        sfreq = raw.info['sfreq']

        # Welch segments of 4 s with 50% overlap, but never longer than the
        # recording itself (short files would otherwise be rejected).
        n_fft = min(int(sfreq * 4), data.shape[1])
        n_overlap = n_fft // 2

        # NOTE(review): the PSD call uses MNE's default channel selection
        # while get_data() above returns every channel, so the per-channel
        # arrays under 'band_powers' and 'statistics' may differ in length
        # when non-EEG channels are present -- confirm against callers.
        if hasattr(raw, 'compute_psd'):
            spectrum = raw.compute_psd(
                method='welch',
                fmin=0.5,
                fmax=50.0,
                n_fft=n_fft,
                n_overlap=n_overlap,
            )
            psds, freqs = spectrum.get_data(return_freqs=True)
        else:
            # Older MNE releases only ship the function-style API.
            psds, freqs = mne.time_frequency.psd_welch(
                raw,
                fmin=0.5,
                fmax=50.0,
                n_fft=n_fft,
                n_overlap=n_overlap,
            )

        features['band_powers'] = self._calculate_band_powers(psds, freqs)
        features['connectivity'] = self._calculate_connectivity(data, sfreq=sfreq)
        features['statistics'] = self._calculate_statistics(data)
        return features

    def _calculate_band_powers(self, psds: np.ndarray, freqs: np.ndarray) -> Dict:
        """Calculate mean power per channel in each band of ``self.freq_bands``.

        Args:
            psds: Power spectral densities, indexed as ``[channel, frequency]``.
            freqs: Frequencies (Hz) matching the second axis of ``psds``.

        Returns:
            Dict mapping band name to an array with one mean power per
            channel. A band that contains no frequency bin yields NaN.
        """
        band_powers = {}
        for band_name, (fmin, fmax) in self.freq_bands.items():
            freq_mask = (freqs >= fmin) & (freqs <= fmax)
            band_powers[band_name] = np.mean(psds[:, freq_mask], axis=1)
        return band_powers

    def _calculate_connectivity(self, data: np.ndarray,
                                sfreq: Optional[float] = None) -> Dict:
        """Calculate connectivity metrics between channels.

        Args:
            data: Signal array indexed as ``[channel, sample]``.
            sfreq: Sampling frequency in Hz handed to the coherence
                estimator; falls back to ``self.sfreq`` when omitted.

        Returns:
            Dict with ``'correlation'`` (``np.corrcoef`` of the channels) and
            ``'coherence'`` (symmetric matrix of the frequency-averaged
            coherence of each channel pair; the diagonal is left at zero).
        """
        fs = self.sfreq if sfreq is None else sfreq
        n_channels = data.shape[0]
        coherence = np.zeros((n_channels, n_channels))

        # Coherence is symmetric, so compute the upper triangle and mirror it.
        for i in range(n_channels):
            for j in range(i + 1, n_channels):
                _freqs, coh = signal.coherence(data[i], data[j], fs=fs)
                coherence[i, j] = coherence[j, i] = np.mean(coh)

        return {
            'correlation': np.corrcoef(data),
            'coherence': coherence,
        }

    def _calculate_statistics(self, data: np.ndarray) -> Dict:
        """Calculate statistical features for each channel.

        Args:
            data: Signal array indexed as ``[channel, sample]``.

        Returns:
            Dict with per-channel ``'mean'``, ``'std'``, ``'skewness'`` and
            ``'kurtosis'`` arrays plus the ``'hjorth'`` parameter dict.
        """
        return {
            'mean': np.mean(data, axis=1),
            'std': np.std(data, axis=1),
            'skewness': self._calculate_skewness(data),
            'kurtosis': self._calculate_kurtosis(data),
            'hjorth': self._calculate_hjorth_parameters(data),
        }

    def _calculate_skewness(self, data: np.ndarray) -> np.ndarray:
        """Calculate skewness for each channel (one value per row of ``data``)."""
        # skew lives in scipy.stats; scipy.signal has no such function.
        return np.array([stats.skew(channel) for channel in data])

    def _calculate_kurtosis(self, data: np.ndarray) -> np.ndarray:
        """Calculate kurtosis for each channel (one value per row of ``data``)."""
        # kurtosis lives in scipy.stats; scipy.signal has no such function.
        return np.array([stats.kurtosis(channel) for channel in data])

    def _calculate_hjorth_parameters(self, data: np.ndarray) -> Dict:
        """Calculate Hjorth parameters (activity, mobility, complexity).

        Args:
            data: Signal array indexed as ``[channel, sample]``.

        Returns:
            Dict with per-channel ``'activity'``, ``'mobility'`` and
            ``'complexity'`` arrays. A channel whose signal or first
            difference has zero variance produces inf/NaN entries.
        """
        activity = np.var(data, axis=1)

        # Variance of the first difference, reused for both ratios below.
        diff1 = np.diff(data, axis=1)
        diff1_var = np.var(diff1, axis=1)
        mobility = np.sqrt(diff1_var / activity)

        # Mobility of the first difference relative to that of the signal.
        diff2 = np.diff(diff1, axis=1)
        complexity = np.sqrt(np.var(diff2, axis=1) / diff1_var) / mobility

        return {
            'activity': activity,
            'mobility': mobility,
            'complexity': complexity,
        }

    def process_file(self, file_path: str) -> Dict:
        """Process an EEG file and extract features.

        Args:
            file_path: Path to the EEG file (EDF, BDF, or CNT format). The
                reader is chosen from the file extension; anything that is
                not ``.bdf`` or ``.cnt`` is read as EDF.

        Returns:
            Dict containing the processed recording under ``'raw_data'`` and
            the extracted features under ``'features'``.
        """
        lowered = str(file_path).lower()
        if lowered.endswith('.bdf'):
            reader = mne.io.read_raw_bdf
        elif lowered.endswith('.cnt'):
            reader = mne.io.read_raw_cnt
        else:
            reader = mne.io.read_raw_edf
        raw = reader(file_path, preload=True)

        raw_processed = self.preprocess(raw)
        features = self.extract_features(raw_processed)
        return {
            'raw_data': raw_processed,
            'features': features,
        }