Spaces:
Sleeping
Sleeping
| # preprocessing.py | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| import numpy as np | |
| import mne | |
@dataclass
class PreprocessConfig:
    """Parameters for the band-pass + envelope preprocessing pipeline.

    Without the ``@dataclass`` decorator this class had no generated
    ``__init__``, so it could not be constructed with field values.
    """

    fs: float  # sampling rate, passed to mne.create_info as ``sfreq`` (Hz)
    f_low: float  # lower pass-band edge, first positional arg of Raw.filter
    f_high: float  # upper pass-band edge, second positional arg of Raw.filter
def to_time_channel(x: np.ndarray, *, max_channels: int = 256) -> np.ndarray:
    """Return *x* as a 2-D ``(time, channel)`` array.

    A 1-D input is treated as a single channel and becomes shape ``(n, 1)``.
    A 2-D input whose first axis is at most ``max_channels`` long and strictly
    shorter than its second axis is assumed to be laid out as
    ``(channel, time)`` and is transposed. Any other 2-D input is returned
    with its layout unchanged.

    Args:
        x: 1-D or 2-D array-like signal. Lists are accepted; ndarray
            subclasses are preserved.
        max_channels: Largest first-axis length that may be interpreted as a
            channel axis. Defaults to 256, the previously hard-coded value.

    Returns:
        The reshaped/transposed array. No copy is forced: the result is the
        input itself or a view of it when *x* is already an ndarray.

    Raises:
        ValueError: If *x* is neither 1-D nor 2-D.
    """
    x = np.asanyarray(x)
    if x.ndim == 1:
        return x[:, None]
    if x.ndim != 2:
        raise ValueError(f"Expected 1D or 2D array, got {x.shape}")
    first, second = x.shape
    # Heuristic: a short leading axis followed by a longer one looks like
    # (channel, time), e.g. (64, 10000); flip it to (time, channel).
    if first <= max_channels and second > first:
        return x.T
    return x
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Band-pass filter a ``(time, channel)`` array using MNE.

    Each column becomes a synthetic EEG channel (``ch0``, ``ch1``, ...) of an
    ``mne.io.RawArray`` sampled at ``cfg.fs``. The array is then filtered with
    ``Raw.filter(cfg.f_low, cfg.f_high)`` using MNE's default filter design.

    Args:
        x_tc: Signal laid out as ``(time, channel)``; must be 2-D.
        cfg: Sampling rate and pass-band edges.

    Returns:
        The filtered signal, transposed back to ``(time, channel)``.
    """
    channel_labels = [f"ch{idx}" for idx in range(x_tc.shape[1])]
    info = mne.create_info(ch_names=channel_labels, sfreq=cfg.fs, ch_types="eeg")
    # MNE stores data as (channel, time), hence the transpose.
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)
    # Raw.filter modifies the object in place; filtering a copy leaves `raw`
    # untouched (and possibly the caller's buffer too, if RawArray did not
    # copy it on construction).
    filtered = raw.copy()
    filtered = filtered.filter(cfg.f_low, cfg.f_high, verbose=False)
    data_ct = filtered.get_data()
    return data_ct.T
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the amplitude envelope of *x_tc* along axis 0 (time).

    Computes ``|analytic signal|`` in the frequency domain. The DC bin (and
    the Nyquist bin for even lengths) is kept, positive-frequency bins are
    doubled, and negative-frequency bins are zeroed. This is the same
    construction ``scipy.signal.hilbert`` uses.

    Args:
        x_tc: Signal with time on axis 0, expected to be real-valued. It is
            usually ``(time, channel)``, but 1-D ``(time,)`` and higher-rank
            inputs are also handled; each trailing index is treated
            independently.

    Returns:
        ``float32`` array with the same shape as *x_tc*.
    """
    Xf = np.fft.fft(x_tc, axis=0)
    N = Xf.shape[0]
    h = np.zeros(N)
    h[0] = 1  # DC component is kept as-is
    if N % 2 == 0:
        h[N // 2] = 1  # Nyquist bin is shared by +/- frequencies
        h[1:N // 2] = 2
    else:
        h[1:(N + 1) // 2] = 2
    # Shape the multiplier as (N, 1, ..., 1) so it broadcasts along axis 0
    # for any input rank. The former h[:, None] turned a 1-D (N,) input into
    # an (N, N) outer product.
    h = h.reshape((N,) + (1,) * (Xf.ndim - 1))
    env = np.abs(np.fft.ifft(Xf * h, axis=0))
    return env.astype(np.float32)
def preprocess_pipeline(x: np.ndarray, cfg: PreprocessConfig):
    """Normalise layout, band-pass filter, then extract the envelope.

    Args:
        x: 1-D or 2-D input signal, passed to ``to_time_channel``.
        cfg: Sampling rate and pass-band edges for ``bandpass_tc``.

    Returns:
        A dict with three entries:

        - ``"raw"``: the input as ``(time, channel)``.
        - ``"filtered"``: the band-passed signal.
        - ``"envelope"``: the amplitude envelope of the filtered signal.
    """
    stages = {}
    stages["raw"] = to_time_channel(x)
    stages["filtered"] = bandpass_tc(stages["raw"], cfg)
    stages["envelope"] = hilbert_envelope_tc(stages["filtered"])
    return stages