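# EEG-Network-Viewer: src/preprocess.py
# Author: stardust-coder (commit b11ec91, "[add] app files")
# Signal preprocessing: (time, channel) layout normalization, band-pass
# filtering via MNE, and FFT-based Hilbert amplitude envelope.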
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import mne
@dataclass(frozen=True)
class PreprocessConfig:
    """Immutable parameters for ``preprocess_pipeline``.

    The values are checked on construction, so an inconsistent band is
    rejected up front instead of being handed to MNE.

    Raises:
        ValueError: if ``fs`` is not positive, or if both band edges are
            given and ``f_low`` is not strictly below ``f_high``.
    """

    fs: float  # sampling frequency in Hz; used as ``sfreq`` for mne.create_info
    f_low: float  # lower band edge in Hz; first (l_freq) argument of Raw.filter
    f_high: float  # upper band edge in Hz; second (h_freq) argument of Raw.filter

    def __post_init__(self) -> None:
        # `not fs > 0` also rejects NaN, which `fs <= 0` would let through.
        if not self.fs > 0:
            raise ValueError(f"fs must be positive, got {self.fs!r}")
        # MNE's Raw.filter interprets l_freq > h_freq as a band-stop filter, so a
        # swapped band would silently do the opposite of a band-pass. None is
        # left untouched because MNE uses it to request pure high-/low-pass.
        if (
            self.f_low is not None
            and self.f_high is not None
            and not self.f_low < self.f_high
        ):
            raise ValueError(
                f"f_low must be < f_high for a band-pass, "
                f"got f_low={self.f_low!r}, f_high={self.f_high!r}"
            )
def to_time_channel(x: np.ndarray) -> np.ndarray:
    """Normalize a signal array to (time, channel) layout.

    A 1D array becomes a single-channel column of shape (N, 1). A 2D array
    is transposed when it looks channel-major: it has at most 256 rows and
    more columns than rows. Otherwise it is returned unchanged. The 256
    threshold presumably caps the expected channel count (TODO confirm).
    The result is a view of ``x`` (slicing / transpose), not a copy.

    Raises:
        ValueError: if ``x`` is neither 1D nor 2D.
    """
    if x.ndim not in (1, 2):
        raise ValueError(f"Expected 1D or 2D array, got {x.shape}")
    if x.ndim == 1:
        return x[:, None]

    n_rows, n_cols = x.shape
    looks_channel_major = n_rows <= 256 and n_cols > n_rows
    return x.T if looks_channel_major else x
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Filter a (time, channel) signal with MNE and return (time, channel).

    Each column is wrapped as an EEG channel named ``ch0``, ``ch1``, ... in an
    ``mne.io.RawArray`` sampled at ``cfg.fs``. The array is then filtered with
    ``Raw.filter(l_freq=cfg.f_low, h_freq=cfg.f_high)`` using MNE's default
    filter design.
    """
    channel_names = [f"ch{idx}" for idx in range(x_tc.shape[1])]
    info = mne.create_info(ch_names=channel_names, sfreq=cfg.fs, ch_types="eeg")

    # RawArray expects (channel, time), hence the transpose in and out.
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)
    # Filter a copy: Raw.filter works in place, and `raw` may share memory
    # with the caller's array depending on how RawArray stores it.
    filtered = raw.copy().filter(l_freq=cfg.f_low, h_freq=cfg.f_high, verbose=False)
    return filtered.get_data().T
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the Hilbert amplitude envelope of ``x_tc`` along axis 0 (time).

    The analytic signal is built in the frequency domain, with the same
    construction as ``scipy.signal.hilbert``:
    - the DC bin (and the Nyquist bin for even N) is kept;
    - positive frequencies are doubled;
    - negative frequencies are zeroed.
    The envelope is its magnitude. Because this is FFT-based, the signal is
    treated as periodic, so expect edge effects at the ends.

    Works for 1D input (time,) as well as (time, channel) or higher-rank
    arrays; the transform always runs along axis 0.

    Returns:
        float32 array with the same shape as ``x_tc``. An empty time axis
        yields an empty float32 array.
    """
    x_tc = np.asarray(x_tc)
    n = x_tc.shape[0]
    if n == 0:
        # np.fft.fft rejects zero-length transforms; an empty signal has an
        # empty envelope.
        return np.zeros(x_tc.shape, dtype=np.float32)

    spectrum = np.fft.fft(x_tc, axis=0)

    h = np.zeros(n)
    h[0] = 1.0
    if n % 2 == 0:
        h[n // 2] = 1.0
        h[1:n // 2] = 2.0
    else:
        h[1:(n + 1) // 2] = 2.0
    # Shape h as (n, 1, 1, ...) so it broadcasts along axis 0 for any rank.
    # The old h[:, None] turned a 1D (n,) input into an (n, n) matrix.
    h = h.reshape((n,) + (1,) * (x_tc.ndim - 1))

    env = np.abs(np.fft.ifft(spectrum * h, axis=0))
    return env.astype(np.float32)
def preprocess_pipeline(x: np.ndarray, cfg: PreprocessConfig):
    """Run layout normalization, band-pass filtering and envelope extraction.

    Args:
        x: 1D or 2D signal; see ``to_time_channel`` for how the layout is
            inferred.
        cfg: sampling rate and band edges used by ``bandpass_tc``.

    Returns:
        dict with one entry per stage, in pipeline order:
        ``"raw"`` is the (time, channel) view of ``x``, ``"filtered"`` is the
        band-passed signal, and ``"envelope"`` is the float32 Hilbert
        envelope of the filtered signal.
    """
    stages = {"raw": to_time_channel(x)}
    stages["filtered"] = bandpass_tc(stages["raw"], cfg)
    stages["envelope"] = hilbert_envelope_tc(stages["filtered"])
    return stages