"""E1 — PPG encoding decision: morphological vs raw patch.

Per the E1 decision rule in EXPERIMENT_TRACKING.md:
    if morphology_extraction_rate < 0.70:       -> raw patches
    elif E1b_linear_probe_AUROC > E1a + 0.02:   -> morphological
    else:                                       -> raw patches

This script implements Stage 1 (extraction rate) directly. If the extraction
rate passes, we would move to Stage 2 (linear probe comparison on AF) — but
that requires AF labels, which are pending. For now we decide Stage 1 and
defer Stage 2 until the AF labels land.

Features extracted by the morphological encoding (Bishop & Ercole / neurokit2):
    PPG_Rate, PPG_Width, PPG_UpstrokeSlope, PPG_Amplitude, PPG_DicroticNotch.
This script does not compute those features itself; Stage 1 only checks
whether neurokit2 can detect a plausible beat train (PPG_Peaks and PPG_Rate)
in each segment.
"""
| from __future__ import annotations |
|
|
| import json |
| import os |
| import random |
| import re |
| import warnings |
| from pathlib import Path |
|
|
| import numpy as np |
| from dotenv import load_dotenv |
| from tqdm import tqdm |
|
|
# Suppress every warning for the whole process while this script runs.
warnings.filterwarnings("ignore")
# Pull variables from a local .env file (if present) into os.environ.
load_dotenv()
# Mirror HUGGINGFACE_API_KEY into HF_TOKEN, but only when a key is actually
# configured: exporting an empty HF_TOKEN would shadow "no token set" with a
# blank credential. An HF_TOKEN that is already set is left untouched.
_hf_api_key = os.environ.get("HUGGINGFACE_API_KEY")
if _hf_api_key:
    os.environ.setdefault("HF_TOKEN", _hf_api_key)
|
|
| from datasets import load_from_disk |
| from huggingface_hub import snapshot_download |
|
|
| import neurokit2 as nk |
|
|
# Hugging Face dataset repository the shards are downloaded from.
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"

# Report directory: the "docs" folder next to this script's parent directory.
_SCRIPT_DIR = Path(__file__).resolve().parent
OUT = _SCRIPT_DIR.parent / "docs"

# Fixed seed so the sampled shard set is the same on every run.
RNG = random.Random(11)
|
|
|
|
# Per-segment acceptance thresholds used by try_morphology.
_MIN_PEAKS = 5              # fewer detected peaks than this -> reject
_MIN_RATE_BPM = 30.0        # plausible median pulse-rate band (inclusive)
_MAX_RATE_BPM = 200.0
_MIN_EXPECTED_BEATS = 3     # segment too short to judge below this
_MIN_BEAT_RATIO = 0.70      # accepted detected/expected band (inclusive)
_MAX_BEAT_RATIO = 1.30


def try_morphology(ppg: np.ndarray, fs: float) -> tuple[bool, int, int]:
    """Check whether neurokit2 finds a plausible beat train in one PPG segment.

    Args:
        ppg: PPG samples for one segment; ``len(ppg)`` is used as the sample
            count when computing the segment duration.
        fs: Sampling rate in Hz. It is rounded to the nearest integer before
            being passed to ``nk.ppg_process``.

    Returns:
        ``(ok, n_detected_beats, n_expected_beats)``. ``ok`` is True only if
        all of the following hold:

        * at least 5 peaks are detected;
        * the median of ``PPG_Rate`` is finite and within 30-200 bpm;
        * the expected beat count, ``int(duration_s * median_rate / 60)``,
          is at least 3;
        * ``detected / expected`` lies in ``[0.70, 1.30]``.

        ``n_expected_beats`` is 0 whenever the check stops before it is
        computed. Note that the expected count is derived from the detected
        rate itself, so this is a self-consistency check of the peak train,
        not a comparison against an independent heart-rate reference.

        The function never raises: an empty signal, a non-finite or
        non-positive sampling rate, or any exception during processing
        yields ``(False, 0, 0)``.
    """
    import logging  # local import so this block does not depend on file-level edits

    try:
        # Reject unusable inputs explicitly instead of relying on neurokit2
        # (or a ZeroDivisionError below) to fail on them.
        if np.size(ppg) == 0 or not np.isfinite(fs) or round(fs) < 1:
            return False, 0, 0

        signals, info = nk.ppg_process(ppg, sampling_rate=int(round(fs)))
        peaks = np.asarray(info.get("PPG_Peaks", []))
        n_detected = len(peaks)
        if n_detected < _MIN_PEAKS:
            return False, n_detected, 0

        median_rate = signals["PPG_Rate"].dropna().median()
        rate_plausible = (
            np.isfinite(median_rate)
            and _MIN_RATE_BPM <= median_rate <= _MAX_RATE_BPM
        )
        if not rate_plausible:
            return False, n_detected, 0

        duration_s = len(ppg) / fs
        expected = int(duration_s * median_rate / 60.0)
        if expected < _MIN_EXPECTED_BEATS:
            return False, n_detected, expected

        beat_ratio = n_detected / expected
        ok = bool(_MIN_BEAT_RATIO <= beat_ratio <= _MAX_BEAT_RATIO)
        return ok, n_detected, expected
    except Exception as exc:
        # Deliberate best-effort: a segment that cannot be processed simply
        # counts as a failed extraction. Log at debug level so a systematic
        # failure (e.g. a library/API problem) can still be diagnosed.
        logging.getLogger(__name__).debug("PPG morphology check failed: %r", exc)
        return False, 0, 0
|
|
|
|
def main() -> None:
    """Run the Stage 1 extraction-rate check and write the JSON report.

    Samples 40 shard indices from ``range(412)`` with the module-level
    ``RNG``, downloads those shards from ``REPO``, and runs
    ``try_morphology`` on segments until 500 have been attempted. The
    extraction rate is the fraction of non-empty segments for which
    ``try_morphology`` reports ``ok``; below 0.70 the Stage 1 decision is
    ``"raw_patches"``, otherwise ``"needs_stage2_probe"``.

    The report is written to ``OUT / "e1_stage1_report.json"`` (``OUT`` is
    created if it does not exist) and the same JSON is printed to stdout.
    """
    # NOTE(review): 412 is presumably the total number of shards in REPO — confirm.
    want = sorted(RNG.sample(range(412), 40))
    root = Path(
        snapshot_download(
            REPO,
            repo_type="dataset",
            allow_patterns=[f"shard_{i:05d}/*" for i in want],
            max_workers=12,
        )
    )
    # Keep only the sampled shards that are actually present on disk.
    shards = [s for s in want if (root / f"shard_{s:05d}" / "dataset_info.json").exists()]

    target = 500  # maximum number of segments to attempt across all shards
    n_attempted = 0
    n_nonempty = 0
    n_ok = 0
    beat_counts: list[int] = []

    for sidx in tqdm(shards, desc="shards"):
        if n_attempted >= target:
            break
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        for row in ds:
            if n_attempted >= target:
                break
            samples = np.asarray(row["ppg"], dtype=np.float32)
            fs = float(row["ppg_fs"])
            n_attempted += 1
            # Check emptiness before indexing: a fully empty record has no
            # first entry, and indexing it would raise IndexError.
            if samples.size == 0:
                continue
            # First entry along the leading axis (presumably the single PPG
            # channel — confirm against the dataset schema).
            ppg = samples[0]
            n_nonempty += 1
            ok, got, _ = try_morphology(ppg, fs)
            beat_counts.append(got)
            if ok:
                n_ok += 1

    # max(..., 1) avoids division by zero when no non-empty segment was seen;
    # the rate is then 0.0.
    extraction_rate = n_ok / max(n_nonempty, 1)
    decision = "raw_patches" if extraction_rate < 0.70 else "needs_stage2_probe"

    report = {
        "n_segments_attempted": n_attempted,
        "n_segments_nonempty": n_nonempty,
        "n_segments_ok": n_ok,
        "extraction_rate": extraction_rate,
        "median_detected_beats_per_segment": (
            float(np.median(beat_counts)) if beat_counts else 0.0
        ),
        "mean_detected_beats_per_segment": (
            float(np.mean(beat_counts)) if beat_counts else 0.0
        ),
        "stage1_decision": decision,
        "rule": (
            "extraction_rate < 0.70 -> raw_patches (stop). "
            "else -> run stage-2 linear-probe comparison after AF labels arrive."
        ),
    }
    report_json = json.dumps(report, indent=2)
    # Create the output directory up front so the run does not fail at the
    # very last step when "docs" is missing.
    OUT.mkdir(parents=True, exist_ok=True)
    (OUT / "e1_stage1_report.json").write_text(report_json, encoding="utf-8")
    print(report_json)
|
|
|
|
# Run the Stage 1 analysis only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|