from pathlib import Path
from typing import Dict, List, Tuple, Optional

import mne
import numpy as np
from scipy import signal
from scipy import stats


class EEGProcessor:
    """Preprocess EEG recordings with MNE and extract per-channel features.

    The pipeline is: band-pass + notch filtering, bad-channel interpolation,
    optional ICA-based EOG artifact removal, then extraction of band powers,
    connectivity matrices and descriptive statistics.

    Attributes:
        sfreq: Fallback sampling rate (Hz) used by ``_calculate_connectivity``
            when no explicit rate is supplied.
        freq_bands: Mapping of band name to ``(fmin, fmax)`` in Hz.
        line_freq: Frequency (Hz) passed to the notch filter.
    """

    def __init__(
        self,
        sfreq: float = 250,
        freq_bands: Optional[Dict[str, Tuple[float, float]]] = None,
        line_freq: float = 50,
    ):
        """Create a processor.

        Args:
            sfreq: Fallback sampling rate in Hz (default 250).
            freq_bands: Optional custom band definitions; when omitted the
                delta/theta/alpha/beta/gamma bands below are used.
            line_freq: Frequency in Hz handed to the notch filter (default 50).
        """
        self.sfreq = sfreq
        self.line_freq = line_freq
        if freq_bands is None:
            freq_bands = {
                'delta': (0.5, 4),
                'theta': (4, 8),
                'alpha': (8, 13),
                'beta': (13, 30),
                'gamma': (30, 50),
            }
        # Copy so that later edits to the caller's dict do not leak in.
        self.freq_bands = dict(freq_bands)

    def preprocess(self, raw: "mne.io.Raw") -> "mne.io.Raw":
        """Return a filtered, interpolated and (if possible) EOG-cleaned copy.

        The input object is left untouched; all operations run on a copy.

        Args:
            raw: Recording to clean.

        Returns:
            The processed copy of ``raw``.
        """
        # Copy first so that neither the montage nor the filters mutate the
        # caller's object.
        raw_processed = raw.copy()

        if raw_processed.get_montage() is None:
            raw_processed.set_montage('standard_1020')

        raw_processed.filter(l_freq=0.5, h_freq=50.0)
        raw_processed.notch_filter(freqs=self.line_freq)

        raw_processed.interpolate_bads()

        # find_bads_eog() needs at least one EOG-typed channel and raises
        # otherwise, so the whole ICA step is skipped when there is none
        # (fitting alone does not modify the data).
        eog_picks = mne.pick_types(raw_processed.info, meg=False, eog=True)
        if len(eog_picks) > 0:
            ica = mne.preprocessing.ICA(n_components=0.95, random_state=42)
            ica.fit(raw_processed)

            eog_indices, _eog_scores = ica.find_bads_eog(raw_processed)
            if eog_indices:
                ica.exclude = eog_indices
                ica.apply(raw_processed)

        return raw_processed

    def extract_features(self, raw: "mne.io.Raw") -> Dict:
        """Extract spectral, connectivity and statistical features.

        Args:
            raw: Recording to analyse (normally the output of ``preprocess``).

        Returns:
            Dict with the keys ``'band_powers'``, ``'connectivity'`` and
            ``'statistics'``.
        """
        # NOTE(review): the PSD below uses MNE's default channel selection,
        # whereas ``get_data()`` returns every channel; if the recording holds
        # non-data channels the two feature groups may differ in length.
        data = raw.get_data()
        psds, freqs = self._compute_welch_psd(raw)

        return {
            'band_powers': self._calculate_band_powers(psds, freqs),
            'connectivity': self._calculate_connectivity(
                data, sfreq=raw.info['sfreq']
            ),
            'statistics': self._calculate_statistics(data),
        }

    def _compute_welch_psd(
        self, raw: "mne.io.Raw"
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute the 0.5-50 Hz Welch PSD of ``raw``.

        Uses 4-second windows with 50 % overlap, shortened to the recording
        length when the recording is shorter than one window.

        Returns:
            ``(psds, freqs)`` as returned by MNE.
        """
        n_fft = min(int(raw.info['sfreq'] * 4), raw.n_times)
        welch_kwargs = dict(
            fmin=0.5,
            fmax=50.0,
            n_fft=n_fft,
            # Equals int(sfreq * 2) whenever the window is not shortened.
            n_overlap=n_fft // 2,
        )

        # Newer MNE exposes the PSD through Raw.compute_psd(); keep the
        # function-based API as a fallback for older installations.
        if hasattr(raw, 'compute_psd'):
            spectrum = raw.compute_psd(method='welch', **welch_kwargs)
            psds, freqs = spectrum.get_data(return_freqs=True)
            return psds, freqs
        return mne.time_frequency.psd_welch(raw, **welch_kwargs)

    def _calculate_band_powers(self, psds: np.ndarray, freqs: np.ndarray) -> Dict:
        """Average the PSD inside each configured frequency band.

        Args:
            psds: Array indexed as ``[channel, frequency]``.
            freqs: Frequency of each PSD column, in Hz.

        Returns:
            Dict mapping band name to a per-channel array of mean power.
            Band edges are inclusive on both sides, so a bin lying exactly on
            a shared edge contributes to both neighbouring bands.
        """
        band_powers = {}
        for band_name, (fmin, fmax) in self.freq_bands.items():
            freq_mask = (freqs >= fmin) & (freqs <= fmax)
            band_powers[band_name] = np.mean(psds[:, freq_mask], axis=1)
        return band_powers

    def _calculate_connectivity(
        self, data: np.ndarray, sfreq: Optional[float] = None
    ) -> Dict:
        """Compute channel-by-channel correlation and mean coherence.

        Args:
            data: Array indexed as ``[channel, sample]``.
            sfreq: Sampling rate in Hz; defaults to ``self.sfreq``.

        Returns:
            Dict with ``'correlation'`` (``np.corrcoef`` of the channels) and
            ``'coherence'`` (symmetric matrix of frequency-averaged coherence;
            the diagonal is left at zero).
        """
        fs = self.sfreq if sfreq is None else sfreq
        n_channels = data.shape[0]
        coherence = np.zeros((n_channels, n_channels))

        # Coherence is symmetric, so only the upper triangle is computed.
        for i in range(n_channels):
            for j in range(i + 1, n_channels):
                _freqs, coh = signal.coherence(data[i], data[j], fs=fs)
                coherence[i, j] = coherence[j, i] = np.mean(coh)

        return {
            'correlation': np.corrcoef(data),
            'coherence': coherence,
        }

    def _calculate_statistics(self, data: np.ndarray) -> Dict:
        """Compute per-channel descriptive statistics.

        Args:
            data: Array indexed as ``[channel, sample]``.

        Returns:
            Dict with ``'mean'``, ``'std'``, ``'skewness'``, ``'kurtosis'``
            and ``'hjorth'`` entries.
        """
        return {
            'mean': np.mean(data, axis=1),
            'std': np.std(data, axis=1),
            'skewness': self._calculate_skewness(data),
            'kurtosis': self._calculate_kurtosis(data),
            'hjorth': self._calculate_hjorth_parameters(data),
        }

    def _calculate_skewness(self, data: np.ndarray) -> np.ndarray:
        """Return the skewness of each channel (row) of ``data``."""
        # skew lives in scipy.stats; scipy.signal has no such function.
        return np.asarray(stats.skew(data, axis=1))

    def _calculate_kurtosis(self, data: np.ndarray) -> np.ndarray:
        """Return the kurtosis of each channel (row) of ``data``.

        Uses ``scipy.stats.kurtosis`` with its default settings.
        """
        # kurtosis lives in scipy.stats; scipy.signal has no such function.
        return np.asarray(stats.kurtosis(data, axis=1))

    def _calculate_hjorth_parameters(self, data: np.ndarray) -> Dict:
        """Compute Hjorth activity, mobility and complexity per channel.

        Derivatives are approximated with first differences along the sample
        axis. Channels whose signal or first difference has zero variance
        produce NaN entries because the ratios below divide by that variance.

        Args:
            data: Array indexed as ``[channel, sample]``.

        Returns:
            Dict with ``'activity'``, ``'mobility'`` and ``'complexity'``
            arrays, one value per channel.
        """
        activity = np.var(data, axis=1)

        diff1 = np.diff(data, axis=1)
        var_diff1 = np.var(diff1, axis=1)
        mobility = np.sqrt(var_diff1 / activity)

        diff2 = np.diff(diff1, axis=1)
        # Mobility of the first difference divided by mobility of the signal.
        complexity = np.sqrt(np.var(diff2, axis=1) / var_diff1) / mobility

        return {
            'activity': activity,
            'mobility': mobility,
            'complexity': complexity,
        }

    def process_file(self, file_path: str) -> Dict:
        """Process an EEG file and extract features.

        Args:
            file_path: Path to the EEG file (EDF, BDF, or CNT format). The
                reader is chosen from the file extension; any other extension
                is handed to the EDF reader.

        Returns:
            Dict containing the processed recording under ``'raw_data'`` and
            the extracted features under ``'features'``.
        """
        readers = {
            '.edf': mne.io.read_raw_edf,
            '.bdf': mne.io.read_raw_bdf,
            '.cnt': mne.io.read_raw_cnt,
        }
        suffix = Path(file_path).suffix.lower()
        reader = readers.get(suffix, mne.io.read_raw_edf)
        raw = reader(file_path, preload=True)

        raw_processed = self.preprocess(raw)
        features = self.extract_features(raw_processed)

        return {
            'raw_data': raw_processed,
            'features': features,
        }