| """Validate alignment using PPG foot (onset) rather than systolic peak. |
| |
| Foot = minimum between two consecutive systolic peaks. This is the feature |
| that physiologically corresponds to the pulse arrival time. Using it should |
| collapse the bimodal PTT distribution. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import random |
| import re |
| from pathlib import Path |
|
|
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from dotenv import load_dotenv |
| from scipy.signal import butter, filtfilt, find_peaks |
|
|
# Load variables from a local .env file (if any) into os.environ.
load_dotenv()
# Expose HUGGINGFACE_API_KEY under the HF_TOKEN name unless HF_TOKEN is already set.
# NOTE(review): when neither variable exists this sets HF_TOKEN to "" —
# presumably huggingface_hub treats an empty token as "no token"; confirm.
os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", ""))
# Imported only after the environment setup above, presumably so the Hugging Face
# libraries see HF_TOKEN if they read it at import time.
from datasets import load_from_disk
from huggingface_hub import snapshot_download


# Hugging Face dataset repository (downloaded with repo_type="dataset" in main()).
REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg"
# Output root: the "docs" folder next to this script's parent directory.
OUT = Path(__file__).resolve().parent.parent / "docs"
FIG = OUT / "figures"
# Import-time side effect: make sure the figure output directory exists.
FIG.mkdir(parents=True, exist_ok=True)
# Seeded RNG. NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
RNG = random.Random(7)
|
|
|
|
def bandpass(x, fs, lo, hi, order=3):
    """Apply a zero-phase Butterworth band-pass filter.

    Args:
        x: Signal to filter.
        fs: Sampling rate in Hz.
        lo: Lower pass-band edge in Hz.
        hi: Upper pass-band edge in Hz. After normalisation by Nyquist it is
            capped at 0.99 so the design stays valid when ``hi`` is at or
            above Nyquist.
        order: Butterworth design order.

    Returns:
        The forward-backward (``filtfilt``) output using Gustafsson's
        initial conditions, with the same shape as ``x``.
    """
    nyquist = 0.5 * fs
    low_edge = lo / nyquist
    high_edge = min(hi / nyquist, 0.99)
    num, den = butter(order, [low_edge, high_edge], btype="band")
    return filtfilt(num, den, x, method="gust")
|
|
|
|
def r_peaks(ecg, fs):
    """Locate ECG R-peaks with a Pan-Tompkins-style energy detector.

    The lead is band-passed to 5-15 Hz. The squared first difference is then
    smoothed with a ~120 ms moving-average window. Local maxima of that
    envelope above ``mean + 0.5 * std`` and roughly 300 ms apart are taken as
    QRS candidates. Each candidate is finally snapped to the maximum of the
    band-passed signal within about +/-60 ms.

    Args:
        ecg: 1-D ECG samples for a single lead.
        fs: Sampling rate in Hz.

    Returns:
        Integer numpy array of R-peak sample indices. It is an empty int array
        when nothing is detected, matching ``ppg_feet``'s return type.
    """
    x = bandpass(ecg, fs, 5.0, 15.0)
    energy = np.diff(x, prepend=x[:1]) ** 2
    w = max(int(0.12 * fs), 1)
    mwa = np.convolve(energy, np.ones(w) / w, mode="same")
    thr = mwa.mean() + 0.5 * mwa.std()
    # find_peaks rejects distance < 1, which int(0.3 * fs) yields for very low fs.
    p, _ = find_peaks(mwa, height=thr, distance=max(int(0.3 * fs), 1))
    snap = max(int(0.06 * fs), 1)
    n = len(x)
    peaks = []
    for q in p:
        # The envelope peak is smeared by the moving average; refine it to the
        # sharpest point of the filtered QRS nearby.
        start = max(0, q - snap)
        stop = min(n, q + snap)
        peaks.append(start + int(np.argmax(x[start:stop])))
    return np.asarray(peaks, dtype=int)
|
|
|
|
def ppg_feet(ppg, fs):
    """Detect PPG feet as the signal minimum between consecutive systolic peaks.

    The PPG is band-passed to 0.5-8 Hz. Systolic peaks are then located with
    ``find_peaks``: roughly 300 ms apart, height above ``mean + 0.3 * std``
    and prominence at least ``0.1 * std`` of the filtered signal. For each pair
    of consecutive peaks, the index of the filtered-signal minimum in
    ``[previous_peak, next_peak)`` is reported as the foot of the later pulse.
    No foot is reported before the first detected peak.

    Args:
        ppg: 1-D PPG samples.
        fs: Sampling rate in Hz.

    Returns:
        Integer numpy array of foot sample indices in increasing order. It is
        empty when fewer than two peaks are found.
    """
    x = bandpass(ppg, fs, 0.5, 8.0)
    sd = x.std()
    peaks, _ = find_peaks(
        x,
        # find_peaks rejects distance < 1, which int(0.3 * fs) yields for very low fs.
        distance=max(int(0.3 * fs), 1),
        height=x.mean() + 0.3 * sd,
        prominence=0.1 * sd,
    )
    feet = [lo + int(np.argmin(x[lo:hi])) for lo, hi in zip(peaks[:-1], peaks[1:])]
    return np.asarray(feet, dtype=int)
|
|
|
|
def clean_ptts_via_foot(ecg, ecg_fs, ppg, ppg_fs, t0e, t0p):
    """Return R-peak-to-PPG-foot delays (ms) for beats with an unambiguous foot.

    R-peak and foot sample indices are converted to absolute times. Each
    signal's start time (``t0e`` for ECG, ``t0p`` for PPG) and its own
    sampling rate are used for the conversion. A beat contributes a delay only
    when exactly one foot lies between 50 ms and 500 ms (inclusive) after its
    R-peak. Beats with zero or several candidate feet are dropped.

    Returns:
        A list of delays in milliseconds. It is empty when fewer than three
        R-peaks or fewer than three feet were detected.
    """
    r_idx = r_peaks(ecg, ecg_fs)
    foot_idx = ppg_feet(ppg, ppg_fs)
    if min(len(r_idx), len(foot_idx)) < 3:
        return []
    r_times = t0e + r_idx / ecg_fs
    foot_times = t0p + foot_idx / ppg_fs
    lower, upper = 0.050, 0.500
    windows = (
        (r_time, foot_times[(foot_times >= r_time + lower) & (foot_times <= r_time + upper)])
        for r_time in r_times
    )
    return [(hits[0] - r_time) * 1000.0 for r_time, hits in windows if len(hits) == 1]
|
|
|
|
def _segment_ptts(row):
    """Return foot-based PTTs (ms) for one dataset row; [] if it has no ECG lead "II".

    Uses ECG lead "II", the first PPG channel (``ppg[0]``), and the first
    entry of each signal's time axis as its start time.
    """
    names = list(row["ecg_names"])
    if "II" not in names:
        return []
    ecg = np.asarray(row["ecg"], dtype=np.float32)
    ppg = np.asarray(row["ppg"], dtype=np.float32)
    return clean_ptts_via_foot(
        ecg[names.index("II")],
        float(row["ecg_fs"]),
        ppg[0],
        float(row["ppg_fs"]),
        float(row["ecg_time_s"][0]),
        float(row["ppg_time_s"][0]),
    )


def _collect_ptts(root, shards, max_segments=100, rows_per_shard=30, min_beats=5):
    """Pool foot-based PTTs from the first usable segments of the given shards.

    Shards are visited in order, reading at most ``rows_per_shard`` rows from
    each. Collection stops (without loading further shards) once
    ``max_segments`` segments with at least ``min_beats`` clean beats have
    been accepted.

    Returns:
        ``(all_ptts, stds)``: the pooled PTT of every accepted beat in ms, and
        the PTT standard deviation of each accepted segment.
    """
    all_ptts, stds = [], []
    for sidx in shards:
        if len(stds) >= max_segments:
            break
        ds = load_from_disk(str(root / f"shard_{sidx:05d}"))
        for i in range(min(len(ds), rows_per_shard)):
            if len(stds) >= max_segments:
                break
            ptts = _segment_ptts(ds[i])
            if len(ptts) >= min_beats:
                all_ptts.extend(ptts)
                stds.append(float(np.std(ptts)))
    return all_ptts, stds


def _summarize(all_ptts, stds):
    """Build the summary dict that is written to e0_alignment.json.

    Raises:
        RuntimeError: if no segment was accepted. Median and percentiles are
            undefined for empty input; ``np.percentile`` would otherwise fail
            with an opaque IndexError.
    """
    if not all_ptts:
        raise RuntimeError(
            "no segment yielded enough clean R-peak -> PPG-foot beats; "
            "cannot compute PTT statistics"
        )
    return {
        "n_clean_beats": len(all_ptts),
        "n_good_segments": len(stds),
        "ptt_foot_median_ms": float(np.median(all_ptts)),
        "ptt_foot_p5_ms": float(np.percentile(all_ptts, 5)),
        "ptt_foot_p95_ms": float(np.percentile(all_ptts, 95)),
        "within_segment_std_median_ms": float(np.median(stds)),
        "within_segment_std_p90_ms": float(np.percentile(stds, 90)),
    }


def _plot_histogram(all_ptts, n_segments, path):
    """Save a 60-bin PTT histogram with dashed reference lines at 100 ms and 400 ms."""
    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        ax.hist(all_ptts, bins=60, color="#36a", edgecolor="black")
        ax.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms")
        ax.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms")
        ax.set_xlabel("PTT (ECG R-peak → PPG foot) (ms)")
        ax.set_ylabel("count")
        ax.set_title(f"PTT via PPG foot — {len(all_ptts)} beats, {n_segments} segments")
        ax.legend()
        fig.tight_layout()
        fig.savefig(path, dpi=120)
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)


def main():
    """Validate ECG/PPG alignment using the PPG foot and write the results.

    Steps:
        1. Download a subset of the dataset shards.
        2. Measure foot-based PTTs on up to 100 usable segments.
        3. Save a histogram to ``FIG / "ptt_histogram_foot.png"``.
        4. Write summary statistics to ``OUT / "e0_alignment.json"`` and print them.

    Raises:
        RuntimeError: if no segment yields enough clean beats.
    """
    # Only every 20th shard index (0, 20, ..., 400) is downloaded.
    want = list(range(0, 412, 20))
    root = Path(
        snapshot_download(
            REPO,
            repo_type="dataset",
            allow_patterns=[f"shard_{i:05d}/*" for i in want],
            max_workers=12,
        )
    )
    # Keep only shards that were actually materialised on disk.
    shards = [s for s in want if (root / f"shard_{s:05d}" / "dataset_info.json").exists()]
    all_ptts, stds = _collect_ptts(root, shards)
    res = _summarize(all_ptts, stds)
    _plot_histogram(all_ptts, len(stds), FIG / "ptt_histogram_foot.png")
    text = json.dumps(res, indent=2)
    (OUT / "e0_alignment.json").write_text(text, encoding="utf-8")
    print(text)
|
|
|
|
# Run the validation only when executed as a script. Importing this module
# still performs the module-level setup (dotenv loading, FIG.mkdir).
if __name__ == "__main__":
    main()
|
|