invincible-jha's picture
Upload eeg_processor.py
e9b27a2 verified
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import mne
import numpy as np
from scipy import signal
from scipy import stats
class EEGProcessor:
    """Preprocess EEG recordings and extract spectral, connectivity and
    statistical features.

    Attributes:
        sfreq: Fallback sampling frequency in Hz. It is used only by helpers
            that are called without an explicit sampling rate.
        freq_bands: Mapping of band name -> ``(fmin, fmax)`` in Hz. Both edges
            are inclusive when PSD bins are selected.
    """

    def __init__(self, sfreq: float = 250,
                 freq_bands: Optional[Dict[str, Tuple[float, float]]] = None):
        """Create a processor.

        Args:
            sfreq: Fallback sampling frequency in Hz (default 250).
            freq_bands: Optional custom band definitions. Defaults to the
                classic delta/theta/alpha/beta/gamma bands below.
        """
        self.sfreq = sfreq  # Default sampling frequency (fallback only)
        if freq_bands is None:
            freq_bands = {
                'delta': (0.5, 4),
                'theta': (4, 8),
                'alpha': (8, 13),
                'beta': (13, 30),
                'gamma': (30, 50),
            }
        # Copy so later edits on the instance never leak into the caller's dict.
        self.freq_bands = dict(freq_bands)

    def preprocess(self, raw: "mne.io.Raw") -> "mne.io.Raw":
        """Return a cleaned copy of ``raw``. The input object is not modified.

        Pipeline: standard 10-20 montage (only if none is set) -> 0.5-50 Hz
        band-pass -> 50 Hz notch -> interpolation of channels already marked
        in ``info['bads']`` -> ICA removal of EOG-matched components (only
        when the recording has an EOG channel).
        """
        # Copy first: setting the montage on ``raw`` itself would mutate the
        # caller's object.
        raw_processed = raw.copy()
        if raw_processed.get_montage() is None:
            raw_processed.set_montage('standard_1020')

        raw_processed.filter(l_freq=0.5, h_freq=50.0)
        raw_processed.notch_filter(freqs=50)  # Remove power line noise

        # Interpolates channels listed in info['bads']. No bad-channel
        # detection is performed here.
        raw_processed.interpolate_bads()

        # Keep as many ICA components as explain 95 % of the variance.
        # A fixed seed makes the decomposition reproducible.
        ica = mne.preprocessing.ICA(n_components=0.95, random_state=42)
        ica.fit(raw_processed)

        try:
            eog_indices, _eog_scores = ica.find_bads_eog(raw_processed)
        except RuntimeError:
            # MNE raises when the recording has no EOG channel. Skip blink
            # removal instead of aborting the whole pipeline.
            eog_indices = []

        if eog_indices:
            ica.exclude = eog_indices
            ica.apply(raw_processed)  # In place on raw_processed
        return raw_processed

    def extract_features(self, raw: "mne.io.Raw") -> Dict:
        """Extract features from preprocessed EEG data.

        Returns:
            Dict with keys ``'band_powers'`` (see ``_calculate_band_powers``),
            ``'connectivity'`` (see ``_calculate_connectivity``) and
            ``'statistics'`` (see ``_calculate_statistics``).
        """
        sfreq = raw.info['sfreq']
        data = raw.get_data()
        psds, freqs = self._compute_psd(raw, sfreq)
        return {
            'band_powers': self._calculate_band_powers(psds, freqs),
            # Use the recording's real sampling rate, not the 250 Hz fallback.
            'connectivity': self._calculate_connectivity(data, sfreq=sfreq),
            'statistics': self._calculate_statistics(data),
        }

    @staticmethod
    def _compute_psd(raw: "mne.io.Raw", sfreq: float) -> Tuple[np.ndarray, np.ndarray]:
        """Compute a Welch PSD over 0.5-50 Hz with 4 s windows and 2 s overlap.

        Returns:
            ``(psds, freqs)`` as produced by MNE.
        """
        n_fft = int(sfreq * 4)
        n_overlap = int(sfreq * 2)
        if hasattr(raw, 'compute_psd'):
            # MNE >= 1.2: psd_welch is deprecated there and removed in 1.4.
            spectrum = raw.compute_psd(
                method='welch',
                fmin=0.5,
                fmax=50.0,
                n_fft=n_fft,
                n_overlap=n_overlap,
            )
            return spectrum.get_data(return_freqs=True)
        # Older MNE without the Spectrum API.
        return mne.time_frequency.psd_welch(
            raw,
            fmin=0.5,
            fmax=50.0,
            n_fft=n_fft,
            n_overlap=n_overlap,
        )

    def _calculate_band_powers(self, psds: np.ndarray, freqs: np.ndarray) -> Dict:
        """Average PSD power inside each band of ``self.freq_bands``.

        Args:
            psds: PSD array whose last axis is frequency (e.g. channels x freqs).
            freqs: 1-D frequency vector matching the last axis of ``psds``.

        Returns:
            Dict band name -> array of mean power, with the frequency axis
            removed. A band with no bins in ``freqs`` yields NaN.
        """
        band_powers = {}
        for band_name, (fmin, fmax) in self.freq_bands.items():
            freq_mask = (freqs >= fmin) & (freqs <= fmax)
            band_powers[band_name] = np.mean(psds[..., freq_mask], axis=-1)
        return band_powers

    def _calculate_connectivity(self, data: np.ndarray,
                                sfreq: Optional[float] = None) -> Dict:
        """Compute pairwise channel connectivity.

        Args:
            data: Channels x samples array.
            sfreq: Sampling frequency in Hz. Falls back to ``self.sfreq``.

        Returns:
            Dict with ``'correlation'`` (Pearson matrix from ``np.corrcoef``)
            and ``'coherence'`` (symmetric matrix of magnitude-squared
            coherence averaged over all frequency bins; the diagonal is 0).
        """
        fs = self.sfreq if sfreq is None else sfreq
        n_channels = data.shape[0]
        coherence = np.zeros((n_channels, n_channels))
        for i in range(n_channels):
            for j in range(i + 1, n_channels):
                _, coh = signal.coherence(data[i], data[j], fs=fs)
                coherence[i, j] = coherence[j, i] = np.mean(coh)
        return {
            'correlation': np.corrcoef(data),
            'coherence': coherence,
        }

    def _calculate_statistics(self, data: np.ndarray) -> Dict:
        """Compute per-channel statistics of a channels x samples array."""
        summary = {
            'mean': np.mean(data, axis=1),
            'std': np.std(data, axis=1),
            'skewness': self._calculate_skewness(data),
            'kurtosis': self._calculate_kurtosis(data),
            'hjorth': self._calculate_hjorth_parameters(data),
        }
        return summary

    def _calculate_skewness(self, data: np.ndarray) -> np.ndarray:
        """Compute the sample skewness of each channel (row)."""
        # scipy.stats provides skew; scipy.signal has no such function.
        return np.asarray(stats.skew(data, axis=1))

    def _calculate_kurtosis(self, data: np.ndarray) -> np.ndarray:
        """Compute the excess (Fisher) kurtosis of each channel (row)."""
        return np.asarray(stats.kurtosis(data, axis=1))

    def _calculate_hjorth_parameters(self, data: np.ndarray) -> Dict:
        """Compute the Hjorth activity, mobility and complexity per channel.

        activity   = var(x)
        mobility   = sqrt(var(x') / var(x))
        complexity = sqrt(var(x'') / var(x')) / mobility
        with x' and x'' the first and second sample differences. A flat
        channel yields NaN mobility/complexity, and no warning is emitted.
        """
        diff1 = np.diff(data, axis=1)
        diff2 = np.diff(diff1, axis=1)
        activity = np.var(data, axis=1)
        var_diff1 = np.var(diff1, axis=1)
        var_diff2 = np.var(diff2, axis=1)
        # A flat channel produces 0/0; let it become NaN quietly.
        with np.errstate(divide='ignore', invalid='ignore'):
            mobility = np.sqrt(var_diff1 / activity)
            complexity = np.sqrt(var_diff2 / var_diff1) / mobility
        return {
            'activity': activity,
            'mobility': mobility,
            'complexity': complexity,
        }

    def process_file(self, file_path: str) -> Dict:
        """Process an EEG file and extract features.

        Args:
            file_path: Path to the EEG file (EDF, BDF, or CNT format). The
                reader is chosen from the file suffix. Unknown suffixes fall
                back to the EDF reader.

        Returns:
            Dict with ``'raw_data'`` (the preprocessed Raw object) and
            ``'features'`` (the output of ``extract_features``).
        """
        readers = {
            '.edf': mne.io.read_raw_edf,
            '.bdf': mne.io.read_raw_bdf,
            '.cnt': mne.io.read_raw_cnt,
        }
        reader = readers.get(Path(file_path).suffix.lower(), mne.io.read_raw_edf)
        raw = reader(file_path, preload=True)

        raw_processed = self.preprocess(raw)
        features = self.extract_features(raw_processed)
        return {
            'raw_data': raw_processed,
            'features': features,
        }