mekosotto Claude Sonnet 4.6 committed on
Commit
c4c7642
·
1 Parent(s): 0d591d4

feat(eeg): add run_pipeline orchestrator + CLI (FIF/EDF → Parquet)

Browse files
src/pipelines/eeg_pipeline.py CHANGED
@@ -12,6 +12,8 @@ a logged WARNING), determinism (seeded ICA + sklearn RNG), traceability
12
  """
13
  from __future__ import annotations
14
 
 
 
15
  import mne
16
  import numpy as np
17
  import pandas as pd
@@ -30,6 +32,11 @@ logger = get_logger(__name__)
30
  _EOG_CORR_THRESHOLD: float = 0.9
31
 
32
 
 
 
 
 
 
33
  def is_valid_epoch(epoch: np.ndarray | None) -> bool:
34
  """Return True iff `epoch` is a non-empty 2-D numeric array with no NaN/inf.
35
 
@@ -390,3 +397,77 @@ def extract_features_from_recording(
390
  100.0 * n_dropped / max(n_total_epochs, 1),
391
  )
392
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  """
13
  from __future__ import annotations
14
 
15
+ from pathlib import Path
16
+
17
  import mne
18
  import numpy as np
19
  import pandas as pd
 
32
  _EOG_CORR_THRESHOLD: float = 0.9
33
 
34
 
35
# Fallback locations used by run_pipeline() when the caller passes no paths.
# Both are relative, so they resolve against the current working directory.
DEFAULT_INPUT = Path("data") / "raw" / "eeg.fif"
DEFAULT_OUTPUT = Path("data") / "processed" / "eeg_features.parquet"
38
+
39
+
40
  def is_valid_epoch(epoch: np.ndarray | None) -> bool:
41
  """Return True iff `epoch` is a non-empty 2-D numeric array with no NaN/inf.
42
 
 
397
  100.0 * n_dropped / max(n_total_epochs, 1),
398
  )
399
  return out
400
+
401
+
402
def run_pipeline(
    input_path: Path = DEFAULT_INPUT,
    output_path: Path = DEFAULT_OUTPUT,
    epoch_duration_s: float = 2.0,
    eog_ch_name: str | None = None,
    n_components: int = 15,
    random_state: int = 97,
) -> None:
    """Run the EEG pipeline end-to-end: raw FIF/EDF -> processed feature Parquet.

    Reads `input_path` via MNE, hands the loaded recording to
    `extract_features_from_recording` (bandpass + ICA + epoching + feature
    extraction), then writes the resulting frame as Parquet at `output_path`
    (preserves float64 dtype; satisfies AGENTS.md §6).

    Both path arguments are validated before any data is read, so a bad
    `output_path` fails immediately instead of after the expensive
    load/ICA/feature steps.

    Args:
        input_path: Path to the raw recording. A ``.edf`` suffix
            (case-insensitive) selects the EDF reader; every other suffix is
            read as FIF.
        output_path: Where to write the processed feature Parquet file.
            Parent directory is created if missing.
        epoch_duration_s: Length of each fixed-duration epoch (seconds).
        eog_ch_name: Name of the EOG channel for ICA-based artifact rejection.
            None disables ICA.
        n_components: Cap on ICA components.
        random_state: Seed for ICA's solver. Required for §4 Determinism.

    Raises:
        FileNotFoundError: if `input_path` does not exist.
        IsADirectoryError: if `output_path` resolves to an existing directory.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    if not input_path.exists():
        raise FileNotFoundError(f"Raw EEG file not found: {input_path}")
    # Fail fast: reject an unusable destination before loading the recording
    # and running feature extraction, not after all that work is done.
    if output_path.is_dir():
        raise IsADirectoryError(
            f"output_path must be a file, got a directory: {output_path}"
        )

    logger.info("Reading raw EEG from %s", input_path)
    if input_path.suffix.lower() == ".edf":
        raw = mne.io.read_raw_edf(input_path, preload=True, verbose="ERROR")
    else:
        # Anything that is not .edf is treated as FIF.
        raw = mne.io.read_raw_fif(input_path, preload=True, verbose="ERROR")
    logger.info(
        "Loaded %d channels, sfreq=%.1f Hz, n_times=%d",
        len(raw.ch_names), raw.info["sfreq"], raw.n_times,
    )

    features = extract_features_from_recording(
        raw,
        epoch_duration_s=epoch_duration_s,
        eog_ch_name=eog_ch_name,
        n_components=n_components,
        random_state=random_state,
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Parquet preserves dtypes (float64 features stay float64) and is
    # byte-deterministic with single-threaded snappy. AGENTS.md §6.
    features.to_parquet(
        output_path, index=False, engine="pyarrow", compression="snappy",
    )
    logger.info(
        "Wrote processed features to %s (rows=%d, cols=%d)",
        output_path, len(features), features.shape[1],
    )
467
+
468
+
469
if __name__ == "__main__":
    # Minimal Day-2 command-line entry point:
    #     python -m src.pipelines.eeg_pipeline
    # There is no argument parsing yet (argparse / click is planned for a
    # later task), so the run always uses the module's default paths, i.e.
    # it reads `data/raw/eeg.fif`.
    run_pipeline(input_path=DEFAULT_INPUT, output_path=DEFAULT_OUTPUT)
tests/pipelines/test_eeg_pipeline.py CHANGED
@@ -1,6 +1,7 @@
1
  """Unit + integration tests for the EEG pipeline."""
2
  from __future__ import annotations
3
 
 
4
  from pathlib import Path
5
 
6
  import mne
@@ -14,6 +15,7 @@ from src.pipelines.eeg_pipeline import (
14
  extract_features_from_recording,
15
  is_valid_epoch,
16
  remove_artifacts_with_ica,
 
17
  )
18
 
19
 
@@ -341,3 +343,87 @@ class TestExtractFeaturesFromRecording:
341
  raw, epoch_duration_s=1e-6, eog_ch_name="EOG061",
342
  n_components=4, random_state=97,
343
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Unit + integration tests for the EEG pipeline."""
2
  from __future__ import annotations
3
 
4
+ import shutil
5
  from pathlib import Path
6
 
7
  import mne
 
15
  extract_features_from_recording,
16
  is_valid_epoch,
17
  remove_artifacts_with_ica,
18
+ run_pipeline,
19
  )
20
 
21
 
 
343
  raw, epoch_duration_s=1e-6, eog_ch_name="EOG061",
344
  n_components=4, random_state=97,
345
  )
346
+
347
+
348
class TestRunPipeline:
    """End-to-end tests for `run_pipeline`: raw recording in, feature Parquet out."""

    # Pipeline settings shared by every test that runs against the fixture
    # recording; kept in one place so the tests cannot drift apart.
    _RUN_KWARGS = {
        "epoch_duration_s": 2.0,
        "eog_ch_name": "EOG061",
        "n_components": 4,
        "random_state": 97,
    }

    @staticmethod
    def _stage_recording(tmp_path: Path) -> tuple[Path, Path]:
        """Copy the fixture recording into a scratch data tree.

        Creates `data/raw` and `data/processed` under `tmp_path`, copies
        FIXTURE to `data/raw/rec.fif`, and returns
        `(input_path, output_path)`. The output path points at
        `data/processed/eeg_features.parquet`; the file itself is not created.
        """
        raw_dir = tmp_path / "data" / "raw"
        proc_dir = tmp_path / "data" / "processed"
        raw_dir.mkdir(parents=True)
        proc_dir.mkdir(parents=True)
        input_path = raw_dir / "rec.fif"
        shutil.copy(FIXTURE, input_path)
        return input_path, proc_dir / "eeg_features.parquet"

    def test_end_to_end_writes_processed_parquet(self, tmp_path: Path) -> None:
        input_path, output_path = self._stage_recording(tmp_path)

        run_pipeline(
            input_path=input_path, output_path=output_path, **self._RUN_KWARGS,
        )

        assert output_path.exists()
        df = pd.read_parquet(output_path)
        assert len(df) == 5
        assert all(c.startswith("feat_") for c in df.columns)

    def test_run_pipeline_preserves_float64_dtype(self, tmp_path: Path) -> None:
        input_path, output_path = self._stage_recording(tmp_path)

        run_pipeline(
            input_path=input_path, output_path=output_path, **self._RUN_KWARGS,
        )
        df = pd.read_parquet(output_path)
        for col in df.columns:
            assert df[col].dtype == np.float64, f"{col} widened to {df[col].dtype}"

    def test_run_pipeline_is_idempotent(self, tmp_path: Path) -> None:
        input_path, output_path = self._stage_recording(tmp_path)

        # Two runs with identical inputs and seed must produce identical bytes.
        run_pipeline(
            input_path=input_path, output_path=output_path, **self._RUN_KWARGS,
        )
        first = output_path.read_bytes()
        run_pipeline(
            input_path=input_path, output_path=output_path, **self._RUN_KWARGS,
        )
        second = output_path.read_bytes()
        assert first == second, "EEG pipeline output must be byte-deterministic"

    def test_run_pipeline_raises_when_input_missing(self, tmp_path: Path) -> None:
        with pytest.raises(FileNotFoundError):
            run_pipeline(
                input_path=tmp_path / "nope.fif",
                output_path=tmp_path / "out.parquet",
            )

    def test_run_pipeline_rejects_directory_as_output(self, tmp_path: Path) -> None:
        input_path, _ = self._stage_recording(tmp_path)
        # An existing directory passed where a file path is expected.
        bad_output = tmp_path / "out_dir"
        bad_output.mkdir()
        with pytest.raises(IsADirectoryError, match="must be a file"):
            run_pipeline(
                input_path=input_path, output_path=bad_output, **self._RUN_KWARGS,
            )